diff --git a/README.md b/README.md
index 56c5bc05..c8e065e8 100644
--- a/README.md
+++ b/README.md
@@ -57,6 +57,25 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
---
+## 环境最低要求
+
+想顺利运行本代码,请按照以下的最低要求进行配置:
++ Python版本: >= 3.8.5, < 3.11
++ Cuda版本: >= 11.7, 且能顺利安装Python
+
+如果想要顺利在GPU运行本地模型(int4版本),你至少需要以下的硬件配置:
+
++ chatglm2-6b & LLaMA-7B 最低显存要求: 7GB 推荐显卡: RTX 3060, RTX 2060
++ LLaMA-13B 最低显存要求: 11GB 推荐显卡: RTX 2060 12GB, RTX3060 12GB, RTX3080, RTXA2000
++ Qwen-14B-Chat 最低显存要求: 13GB 推荐显卡: RTX 3090
++ LLaMA-30B 最低显存要求: 22GB 推荐显卡:RTX A5000,RTX 3090,RTX 4090,RTX 6000,Tesla V100,RTX Tesla P40
++ LLaMA-65B 最低显存要求: 40GB 推荐显卡:A100,A40,A6000
+
+如果是int8 则显存x1.5 fp16 x2.5的要求
+如:使用fp16 推理Qwen-7B-Chat 模型 则需要使用16GB显存。
+
+以上仅为估算,实际情况以nvidia-smi占用为准。
+
## 变更日志
参见 [版本更新日志](https://github.com/imClumsyPanda/langchain-ChatGLM/releases)。
@@ -112,27 +131,29 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
- [WizardLM/WizardCoder-15B-V1.0](https://huggingface.co/WizardLM/WizardCoder-15B-V1.0)
- [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
- [internlm/internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b)
-- [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat)
+- [Qwen/Qwen-7B-Chat/Qwen-14B-Chat](https://huggingface.co/Qwen/)
- [HuggingFaceH4/starchat-beta](https://huggingface.co/HuggingFaceH4/starchat-beta)
- [FlagAlpha/Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat) and others
- [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B)
- [all models of OpenOrca](https://huggingface.co/Open-Orca)
- [Spicyboros](https://huggingface.co/jondurbin/spicyboros-7b-2.2?not-for-all-audiences=true) + [airoboros 2.2](https://huggingface.co/jondurbin/airoboros-l2-13b-2.2)
- [VMware's OpenLLaMa OpenInstruct](https://huggingface.co/VMware/open-llama-7b-open-instruct)
+- [baichuan2-7b/baichuan2-13b](https://huggingface.co/baichuan-inc)
- 任何 [EleutherAI](https://huggingface.co/EleutherAI) 的 pythia 模型,如 [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b)
- 在以上模型基础上训练的任何 [Peft](https://github.com/huggingface/peft) 适配器。为了激活,模型路径中必须有 `peft` 。注意:如果加载多个peft模型,你可以通过在任何模型工作器中设置环境变量 `PEFT_SHARE_BASE_WEIGHTS=true` 来使它们共享基础模型的权重。
以上模型支持列表可能随 [FastChat](https://github.com/lm-sys/FastChat) 更新而持续更新,可参考 [FastChat 已支持模型列表](https://github.com/lm-sys/FastChat/blob/main/docs/model_support.md)。
-
除本地模型外,本项目也支持直接接入 OpenAI API、智谱AI等在线模型,具体设置可参考 `configs/model_configs.py.example` 中的 `llm_model_dict` 的配置信息。
-在线 LLM 模型目前已支持:
+在线 LLM 模型目前已支持:
+
- [ChatGPT](https://api.openai.com)
- [智谱AI](http://open.bigmodel.cn)
- [MiniMax](https://api.minimax.chat)
- [讯飞星火](https://xinghuo.xfyun.cn)
- [百度千帆](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
+- [阿里云通义千问](https://dashscope.aliyun.com/)
项目中默认使用的 LLM 类型为 `THUDM/chatglm2-6b`,如需使用其他 LLM 类型,请在 [configs/model_config.py] 中对 `llm_model_dict` 和 `LLM_MODEL` 进行修改。
@@ -157,9 +178,11 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
- [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese)
- [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
- [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh)
+- [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-base-zh)
+- [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-large-zh)
- [OpenAI/text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings)
-项目中默认使用的 Embedding 类型为 `moka-ai/m3e-base`,如需使用其他 Embedding 类型,请在 [configs/model_config.py] 中对 `embedding_model_dict` 和 `EMBEDDING_MODEL` 进行修改。
+项目中默认使用的 Embedding 类型为 `sensenova/piccolo-base-zh`,如需使用其他 Embedding 类型,请在 [configs/model_config.py] 中对 `embedding_model_dict` 和 `EMBEDDING_MODEL` 进行修改。
---
@@ -187,15 +210,27 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
关于如何使用自定义分词器和贡献自己的分词器,可以参考[Text Splitter 贡献说明](docs/splitter.md)。
+## Agent生态
+### 基础的Agent
+在本版本中,我们实现了一个简单的基于OpenAI的React的Agent模型,目前,经过我们测试,仅有以下两个模型支持:
++ OpenAI GPT4
++ ChatGLM2-130B
+
+目前版本的Agent仍然需要对提示词进行大量调试,调试位置
+
+### 构建自己的Agent工具
+
+详见 [自定义Agent说明](docs/自定义Agent.md)
+
## Docker 部署
-🐳 Docker 镜像地址: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.3)`
+🐳 Docker 镜像地址: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5)`
```shell
-docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.3
+docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5
```
-- 该版本镜像大小 `35.3GB`,使用 `v0.2.3`,以 `nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04` 为基础镜像
+- 该版本镜像大小 `35.3GB`,使用 `v0.2.5`,以 `nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04` 为基础镜像
- 该版本内置两个 `embedding` 模型:`m3e-large`,`text2vec-bge-large-chinese`,默认启用后者,内置 `chatglm2-6b-32k`
- 该版本目标为方便一键部署使用,请确保您已经在Linux发行版上安装了NVIDIA驱动程序
- 请注意,您不需要在主机系统上安装CUDA工具包,但需要安装 `NVIDIA Driver` 以及 `NVIDIA Container Toolkit`,请参考[安装指南](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
@@ -391,8 +426,8 @@ CUDA_VISIBLE_DEVICES=0,1 python startup.py -a
- [X] .csv
- [ ] .xlsx
- [ ] 分词及召回
- - [ ] 接入不同类型 TextSplitter
- - [ ] 优化依据中文标点符号设计的 ChineseTextSplitter
+ - [X] 接入不同类型 TextSplitter
+ - [X] 优化依据中文标点符号设计的 ChineseTextSplitter
- [ ] 重新实现上下文拼接召回
- [ ] 本地网页接入
- [ ] SQL 接入
@@ -400,13 +435,17 @@ CUDA_VISIBLE_DEVICES=0,1 python startup.py -a
- [X] 搜索引擎接入
- [X] Bing 搜索
- [X] DuckDuckGo 搜索
- - [ ] Agent 实现
+ - [X] Agent 实现
+ - [X] 基础React形式的Agent实现,包括调用计算器等
+ - [X] Langchain 自带的Agent实现和调用
+ - [ ] 更多模型的Agent支持
+ - [ ] 更多工具
- [X] LLM 模型接入
- [X] 支持通过调用 [FastChat](https://github.com/lm-sys/fastchat) api 调用 llm
- - [ ] 支持 ChatGLM API 等 LLM API 的接入
+ - [X] 支持 ChatGLM API 等 LLM API 的接入
- [X] Embedding 模型接入
- [X] 支持调用 HuggingFace 中各开源 Emebdding 模型
- - [ ] 支持 OpenAI Embedding API 等 Embedding API 的接入
+ - [X] 支持 OpenAI Embedding API 等 Embedding API 的接入
- [X] 基于 FastAPI 的 API 方式调用
- [X] Web UI
- [X] 基于 Streamlit 的 Web UI
@@ -417,4 +456,12 @@ CUDA_VISIBLE_DEVICES=0,1 python startup.py -a
-🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
+🎉 langchain-Chatchat 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
+
+
+## 关注我们
+
+
+🎉 langchain-Chatchat 项目官方公众号,欢迎扫码关注。
+
+
diff --git a/README_en.md b/README_en.md
index c64d7c95..c7771ffc 100644
--- a/README_en.md
+++ b/README_en.md
@@ -56,6 +56,25 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
---
+## Environment Minimum Requirements
+
+To run this code smoothly, please configure it according to the following minimum requirements:
++ Python version: >= 3.8.5, < 3.11
++ Cuda version: >= 11.7, with Python installed.
+
+If you want to run the native model (int4 version) on the GPU without problems, you need at least the following hardware configuration.
+
++ chatglm2-6b & LLaMA-7B Minimum RAM requirement: 7GB Recommended graphics cards: RTX 3060, RTX 2060
++ LLaMA-13B Minimum graphics memory requirement: 11GB Recommended cards: RTX 2060 12GB, RTX3060 12GB, RTX3080, RTXA2000
++ Qwen-14B-Chat Minimum memory requirement: 13GB Recommended graphics card: RTX 3090
++ LLaMA-30B Minimum Memory Requirement: 22GB Recommended Cards: RTX A5000,RTX 3090,RTX 4090,RTX 6000,Tesla V100,RTX Tesla P40
++ Minimum memory requirement for LLaMA-65B: 40GB Recommended cards: A100,A40,A6000
+
+If int8 then memory x1.5 fp16 x2.5 requirement.
+For example: using fp16 to reason about the Qwen-7B-Chat model requires 16GB of video memory.
+
+The above is only an estimate, the actual situation is based on nvidia-smi occupancy.
+
## Change Log
plese refer to [version change log](https://github.com/imClumsyPanda/langchain-ChatGLM/releases)
@@ -105,18 +124,31 @@ The project use [FastChat](https://github.com/lm-sys/FastChat) to provide the AP
- [WizardLM/WizardCoder-15B-V1.0](https://huggingface.co/WizardLM/WizardCoder-15B-V1.0)
- [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
- [internlm/internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b)
-- [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat)
+- [Qwen/Qwen-7B-Chat/Qwen-14B-Chat](https://huggingface.co/Qwen/)
- [HuggingFaceH4/starchat-beta](https://huggingface.co/HuggingFaceH4/starchat-beta)
- [FlagAlpha/Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat) and other models of FlagAlpha
- [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B)
- [all models of OpenOrca](https://huggingface.co/Open-Orca)
- [Spicyboros](https://huggingface.co/jondurbin/spicyboros-7b-2.2?not-for-all-audiences=true) + [airoboros 2.2](https://huggingface.co/jondurbin/airoboros-l2-13b-2.2)
+- [baichuan2-7b/baichuan2-13b](https://huggingface.co/baichuan-inc)
- [VMware's OpenLLaMa OpenInstruct](https://huggingface.co/VMware/open-llama-7b-open-instruct)
* Any [EleutherAI](https://huggingface.co/EleutherAI) pythia model such as [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b)(任何 [EleutherAI](https://huggingface.co/EleutherAI) 的 pythia 模型,如 [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b))
* Any [Peft](https://github.com/huggingface/peft) adapter trained on top of a model above. To activate, must have `peft` in the model path. Note: If loading multiple peft models, you can have them share the base model weights by setting the environment variable `PEFT_SHARE_BASE_WEIGHTS=true` in any model worker.
-Please refer to `llm_model_dict` in `configs.model_configs.py.example` to invoke OpenAI API.
+
+The above model support list may be updated continuously as [FastChat](https://github.com/lm-sys/FastChat) is updated, see [FastChat Supported Models List](https://github.com/lm-sys/FastChat/blob/main /docs/model_support.md).
+In addition to local models, this project also supports direct access to online models such as OpenAI API, Wisdom Spectrum AI, etc. For specific settings, please refer to the configuration information of `llm_model_dict` in `configs/model_configs.py.example`.
+Online LLM models are currently supported:
+
+- [ChatGPT](https://api.openai.com)
+- [Smart Spectrum AI](http://open.bigmodel.cn)
+- [MiniMax](https://api.minimax.chat)
+- [Xunfei Starfire](https://xinghuo.xfyun.cn)
+- [Baidu Qianfan](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
+- [Aliyun Tongyi Qianqian](https://dashscope.aliyun.com/)
+
+The default LLM type used in the project is `THUDM/chatglm2-6b`, if you need to use other LLM types, please modify `llm_model_dict` and `LLM_MODEL` in [configs/model_config.py].
### Supported Embedding models
@@ -129,6 +161,8 @@ Following models are tested by developers with Embedding class of [HuggingFace](
- [BAAI/bge-base-zh](https://huggingface.co/BAAI/bge-base-zh)
- [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh)
- [BAAI/bge-large-zh-noinstruct](https://huggingface.co/BAAI/bge-large-zh-noinstruct)
+- [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-base-zh)
+- [sensenova/piccolo-large-zh](https://huggingface.co/sensenova/piccolo-large-zh)
- [shibing624/text2vec-base-chinese-sentence](https://huggingface.co/shibing624/text2vec-base-chinese-sentence)
- [shibing624/text2vec-base-chinese-paraphrase](https://huggingface.co/shibing624/text2vec-base-chinese-paraphrase)
- [shibing624/text2vec-base-multilingual](https://huggingface.co/shibing624/text2vec-base-multilingual)
@@ -137,16 +171,24 @@ Following models are tested by developers with Embedding class of [HuggingFace](
- [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese)
- [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
- [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh)
+- [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-base-zh)
+- [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-large-zh)
- [OpenAI/text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings)
+The default Embedding type used in the project is `sensenova/piccolo-base-zh`, if you want to use other Embedding types, please modify `embedding_model_dict` and `embedding_model_dict` and `embedding_model_dict` in [configs/model_config.py]. MODEL` in [configs/model_config.py].
+
+### Build your own Agent tool!
+
+See [Custom Agent Instructions](docs/自定义Agent.md) for details.
+
---
## Docker Deployment
-🐳 Docker image path: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.0)`
+🐳 Docker image path: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5)`
```shell
-docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.0
+docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5
```
- The image size of this version is `33.9GB`, using `v0.2.0`, with `nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04` as the base image
@@ -328,9 +370,9 @@ Please refer to [FAQ](docs/FAQ.md)
- [ ] Structured documents
- [X] .csv
- [ ] .xlsx
- - [ ] TextSplitter and Retriever
- - [x] multiple TextSplitter
- - [x] ChineseTextSplitter
+ - [] TextSplitter and Retriever
+ - [X] multiple TextSplitter
+ - [X] ChineseTextSplitter
- [ ] Reconstructed Context Retriever
- [ ] Webpage
- [ ] SQL
@@ -338,7 +380,11 @@ Please refer to [FAQ](docs/FAQ.md)
- [X] Search Engines
- [X] Bing
- [X] DuckDuckGo
- - [ ] Agent
+ - [X] Agent
+ - [X] Agent implementation in the form of basic React, including calls to calculators, etc.
+ - [X] Langchain's own Agent implementation and calls
+ - [ ] More Agent support for models
+ - [ ] More tools
- [X] LLM Models
- [X] [FastChat](https://github.com/lm-sys/fastchat) -based LLM Models
- [ ] Mutiply Remote LLM API
@@ -348,3 +394,16 @@ Please refer to [FAQ](docs/FAQ.md)
- [X] FastAPI-based API
- [X] Web UI
- [X] Streamlit -based Web UI
+
+---
+
+## Wechat Group
+
+
+
+🎉 langchain-Chatchat project WeChat exchange group, if you are also interested in this project, welcome to join the group chat to participate in the discussion and exchange.
+
+## Follow us
+
+
+🎉 langchain-Chatchat project official public number, welcome to scan the code to follow.
\ No newline at end of file
diff --git a/chains/llmchain_with_history.py b/chains/llmchain_with_history.py
index 3d360422..9707c00c 100644
--- a/chains/llmchain_with_history.py
+++ b/chains/llmchain_with_history.py
@@ -1,19 +1,12 @@
-from langchain.chat_models import ChatOpenAI
-from configs.model_config import llm_model_dict, LLM_MODEL
-from langchain import LLMChain
+from server.utils import get_ChatOpenAI
+from configs.model_config import LLM_MODEL, TEMPERATURE
+from langchain.chains import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
-model = ChatOpenAI(
- streaming=True,
- verbose=True,
- # callbacks=[callback],
- openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
- openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
- model_name=LLM_MODEL
-)
+model = get_ChatOpenAI(model_name=LLM_MODEL, temperature=TEMPERATURE)
human_prompt = "{input}"
diff --git a/configs/__init__.py b/configs/__init__.py
index 41169e8b..412e33dd 100644
--- a/configs/__init__.py
+++ b/configs/__init__.py
@@ -1,4 +1,8 @@
+from .basic_config import *
from .model_config import *
+from .kb_config import *
from .server_config import *
+from .prompt_config import *
-VERSION = "v0.2.4"
+
+VERSION = "v0.2.5"
diff --git a/configs/basic_config.py.example b/configs/basic_config.py.example
new file mode 100644
index 00000000..6bd8c8d2
--- /dev/null
+++ b/configs/basic_config.py.example
@@ -0,0 +1,22 @@
+import logging
+import os
+import langchain
+
+# 是否显示详细日志
+log_verbose = False
+langchain.verbose = False
+
+
+# 通常情况下不需要更改以下内容
+
+# 日志格式
+LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+logging.basicConfig(format=LOG_FORMAT)
+
+
+# 日志存储路径
+LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
+if not os.path.exists(LOG_PATH):
+ os.mkdir(LOG_PATH)
diff --git a/configs/kb_config.py.exmaple b/configs/kb_config.py.exmaple
new file mode 100644
index 00000000..af68a74d
--- /dev/null
+++ b/configs/kb_config.py.exmaple
@@ -0,0 +1,99 @@
+import os
+
+
+# 默认向量库类型。可选:faiss, milvus, pg.
+DEFAULT_VS_TYPE = "faiss"
+
+# 缓存向量库数量(针对FAISS)
+CACHED_VS_NUM = 1
+
+# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter)
+CHUNK_SIZE = 250
+
+# 知识库中相邻文本重合长度(不适用MarkdownHeaderTextSplitter)
+OVERLAP_SIZE = 50
+
+# 知识库匹配向量数量
+VECTOR_SEARCH_TOP_K = 3
+
+# 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右
+SCORE_THRESHOLD = 1
+
+# 搜索引擎匹配结题数量
+SEARCH_ENGINE_TOP_K = 3
+
+
+# Bing 搜索必备变量
+# 使用 Bing 搜索需要使用 Bing Subscription Key,需要在azure port中申请试用bing search
+# 具体申请方式请见
+# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource
+# 使用python创建bing api 搜索实例详见:
+# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python
+BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
+# 注意不是bing Webmaster Tools的api key,
+
+# 此外,如果是在服务器上,报Failed to establish a new connection: [Errno 110] Connection timed out
+# 是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG
+BING_SUBSCRIPTION_KEY = ""
+
+# 是否开启中文标题加强,以及标题增强的相关配置
+# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记;
+# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
+ZH_TITLE_ENHANCE = False
+
+
+# 通常情况下不需要更改以下内容
+
+# 知识库默认存储路径
+KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
+if not os.path.exists(KB_ROOT_PATH):
+ os.mkdir(KB_ROOT_PATH)
+
+# 数据库默认存储路径。
+# 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。
+DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
+SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"
+
+# 可选向量库类型及对应配置
+kbs_config = {
+ "faiss": {
+ },
+ "milvus": {
+ "host": "127.0.0.1",
+ "port": "19530",
+ "user": "",
+ "password": "",
+ "secure": False,
+ },
+ "pg": {
+ "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
+ }
+}
+
+# TextSplitter配置项,如果你不明白其中的含义,就不要修改。
+text_splitter_dict = {
+ "ChineseRecursiveTextSplitter": {
+ "source": "huggingface", ## 选择tiktoken则使用openai的方法
+ "tokenizer_name_or_path": "gpt2",
+ },
+ "SpacyTextSplitter": {
+ "source": "huggingface",
+ "tokenizer_name_or_path": "",
+ },
+ "RecursiveCharacterTextSplitter": {
+ "source": "tiktoken",
+ "tokenizer_name_or_path": "cl100k_base",
+ },
+ "MarkdownHeaderTextSplitter": {
+ "headers_to_split_on":
+ [
+ ("#", "head1"),
+ ("##", "head2"),
+ ("###", "head3"),
+ ("####", "head4"),
+ ]
+ },
+}
+
+# TEXT_SPLITTER 名称
+TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter"
diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index ad94fca7..f5e84650 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -1,63 +1,115 @@
import os
-import logging
-# 日志格式
-LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
-logger = logging.getLogger()
-logger.setLevel(logging.INFO)
-logging.basicConfig(format=LOG_FORMAT)
-# 是否显示详细日志
-log_verbose = False
-# 在以下字典中修改属性值,以指定本地embedding模型存储位置
-# 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
-# 此处请写绝对路径
-embedding_model_dict = {
- "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
- "ernie-base": "nghuyong/ernie-3.0-base-zh",
- "text2vec-base": "shibing624/text2vec-base-chinese",
- "text2vec": "GanymedeNil/text2vec-large-chinese",
- "text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase",
- "text2vec-sentence": "shibing624/text2vec-base-chinese-sentence",
- "text2vec-multilingual": "shibing624/text2vec-base-multilingual",
- "text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese",
- "m3e-small": "moka-ai/m3e-small",
- "m3e-base": "moka-ai/m3e-base",
- "m3e-large": "moka-ai/m3e-large",
- "bge-small-zh": "BAAI/bge-small-zh",
- "bge-base-zh": "BAAI/bge-base-zh",
- "bge-large-zh": "BAAI/bge-large-zh",
- "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
- "piccolo-base-zh": "sensenova/piccolo-base-zh",
- "piccolo-large-zh": "sensenova/piccolo-large-zh",
- "text-embedding-ada-002": os.environ.get("OPENAI_API_KEY")
+# 可以指定一个绝对路径,统一存放所有的Embedding和LLM模型。
+# 每个模型可以是一个单独的目录,也可以是某个目录下的二级子目录
+MODEL_ROOT_PATH = ""
+
+# 在以下字典中修改属性值,以指定本地embedding模型存储位置。支持3种设置方法:
+# 1、将对应的值修改为模型绝对路径
+# 2、不修改此处的值(以 text2vec 为例):
+# 2.1 如果{MODEL_ROOT_PATH}下存在如下任一子目录:
+# - text2vec
+# - GanymedeNil/text2vec-large-chinese
+# - text2vec-large-chinese
+# 2.2 如果以上本地路径不存在,则使用huggingface模型
+MODEL_PATH = {
+ "embed_model": {
+ "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
+ "ernie-base": "nghuyong/ernie-3.0-base-zh",
+ "text2vec-base": "shibing624/text2vec-base-chinese",
+ "text2vec": "GanymedeNil/text2vec-large-chinese",
+ "text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase",
+ "text2vec-sentence": "shibing624/text2vec-base-chinese-sentence",
+ "text2vec-multilingual": "shibing624/text2vec-base-multilingual",
+ "text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese",
+ "m3e-small": "moka-ai/m3e-small",
+ "m3e-base": "moka-ai/m3e-base",
+ "m3e-large": "moka-ai/m3e-large",
+ "bge-small-zh": "BAAI/bge-small-zh",
+ "bge-base-zh": "BAAI/bge-base-zh",
+ "bge-large-zh": "BAAI/bge-large-zh",
+ "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
+ "piccolo-base-zh": "sensenova/piccolo-base-zh",
+ "piccolo-large-zh": "sensenova/piccolo-large-zh",
+ "text-embedding-ada-002": "your OPENAI_API_KEY",
+ },
+ # TODO: add all supported llm models
+ "llm_model": {
+ # 以下部分模型并未完全测试,仅根据fastchat和vllm模型的模型列表推定支持
+ "chatglm-6b": "THUDM/chatglm-6b",
+ "chatglm2-6b": "THUDM/chatglm2-6b",
+ "chatglm2-6b-int4": "THUDM/chatglm2-6b-int4",
+ "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
+
+ "baichuan2-13b": "baichuan-inc/Baichuan-13B-Chat",
+ "baichuan2-7b":"baichuan-inc/Baichuan2-7B-Chat",
+
+ "baichuan-7b": "baichuan-inc/Baichuan-7B",
+ "baichuan-13b": "baichuan-inc/Baichuan-13B",
+ 'baichuan-13b-chat':'baichuan-inc/Baichuan-13B-Chat',
+
+ "aquila-7b":"BAAI/Aquila-7B",
+ "aquilachat-7b":"BAAI/AquilaChat-7B",
+
+ "internlm-7b":"internlm/internlm-7b",
+ "internlm-chat-7b":"internlm/internlm-chat-7b",
+
+ "falcon-7b":"tiiuae/falcon-7b",
+ "falcon-40b":"tiiuae/falcon-40b",
+ "falcon-rw-7b":"tiiuae/falcon-rw-7b",
+
+ "gpt2":"gpt2",
+ "gpt2-xl":"gpt2-xl",
+
+ "gpt-j-6b":"EleutherAI/gpt-j-6b",
+ "gpt4all-j":"nomic-ai/gpt4all-j",
+ "gpt-neox-20b":"EleutherAI/gpt-neox-20b",
+ "pythia-12b":"EleutherAI/pythia-12b",
+ "oasst-sft-4-pythia-12b-epoch-3.5":"OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
+ "dolly-v2-12b":"databricks/dolly-v2-12b",
+ "stablelm-tuned-alpha-7b":"stabilityai/stablelm-tuned-alpha-7b",
+
+ "Llama-2-13b-hf":"meta-llama/Llama-2-13b-hf",
+ "Llama-2-70b-hf":"meta-llama/Llama-2-70b-hf",
+ "open_llama_13b":"openlm-research/open_llama_13b",
+ "vicuna-13b-v1.3":"lmsys/vicuna-13b-v1.3",
+ "koala":"young-geng/koala",
+
+ "mpt-7b":"mosaicml/mpt-7b",
+ "mpt-7b-storywriter":"mosaicml/mpt-7b-storywriter",
+ "mpt-30b":"mosaicml/mpt-30b",
+ "opt-66b":"facebook/opt-66b",
+ "opt-iml-max-30b":"facebook/opt-iml-max-30b",
+
+ "Qwen-7B":"Qwen/Qwen-7B",
+ "Qwen-14B":"Qwen/Qwen-14B",
+ "Qwen-7B-Chat":"Qwen/Qwen-7B-Chat",
+ "Qwen-14B-Chat":"Qwen/Qwen-14B-Chat",
+ },
}
# 选用的 Embedding 名称
-EMBEDDING_MODEL = "m3e-base"
+EMBEDDING_MODEL = "m3e-base" # 可以尝试最新的嵌入式sota模型:piccolo-large-zh
# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
EMBEDDING_DEVICE = "auto"
-llm_model_dict = {
- "chatglm-6b": {
- "local_model_path": "THUDM/chatglm-6b",
- "api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
- "api_key": "EMPTY"
- },
+# LLM 名称
+LLM_MODEL = "chatglm2-6b"
- "chatglm2-6b": {
- "local_model_path": "THUDM/chatglm2-6b",
- "api_base_url": "http://localhost:8888/v1", # URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
- "api_key": "EMPTY"
- },
+# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
+LLM_DEVICE = "auto"
- "chatglm2-6b-32k": {
- "local_model_path": "THUDM/chatglm2-6b-32k", # "THUDM/chatglm2-6b-32k",
- "api_base_url": "http://localhost:8888/v1", # "URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
- "api_key": "EMPTY"
- },
+# 历史对话轮数
+HISTORY_LEN = 3
+# LLM通用对话参数
+TEMPERATURE = 0.7
+# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
+
+
+ONLINE_LLM_MODEL = {
# 调用chatgpt时如果报出: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
# Max retries exceeded with url: /v1/chat/completions
# 则需要将urllib3版本修改为1.25.11
@@ -75,28 +127,25 @@ llm_model_dict = {
# 比如: "openai_proxy": 'http://127.0.0.1:4780'
"gpt-3.5-turbo": {
"api_base_url": "https://api.openai.com/v1",
- "api_key": "",
- "openai_proxy": ""
+ "api_key": "your OPENAI_API_KEY",
+ "openai_proxy": "your OPENAI_PROXY",
},
- # 线上模型。当前支持智谱AI。
- # 如果没有设置有效的local_model_path,则认为是在线模型API。
- # 请在server_config中为每个在线API设置不同的端口
+ # 线上模型。请在server_config中为每个在线API设置不同的端口
# 具体注册及api key获取请前往 http://open.bigmodel.cn
"zhipu-api": {
- "api_base_url": "http://127.0.0.1:8888/v1",
"api_key": "",
- "provider": "ChatGLMWorker",
"version": "chatglm_pro", # 可选包括 "chatglm_lite", "chatglm_std", "chatglm_pro"
+ "provider": "ChatGLMWorker",
},
+ # 具体注册及api key获取请前往 https://api.minimax.chat/
"minimax-api": {
- "api_base_url": "http://127.0.0.1:8888/v1",
"group_id": "",
"api_key": "",
"is_pro": False,
"provider": "MiniMaxWorker",
},
+ # 具体注册及api key获取请前往 https://xinghuo.xfyun.cn/
"xinghuo-api": {
- "api_base_url": "http://127.0.0.1:8888/v1",
"APPID": "",
"APISecret": "",
"api_key": "",
@@ -105,140 +154,77 @@ llm_model_dict = {
},
# 百度千帆 API,申请方式请参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf
"qianfan-api": {
- "version": "ernie-bot", # 当前支持 "ernie-bot" 或 "ernie-bot-turbo", 更多的见文档模型支持列表中千帆部分。
- "version_url": "", # 可以不填写version,直接填写在千帆申请模型发布的API地址
- "api_base_url": "http://127.0.0.1:8888/v1",
+ "version": "ernie-bot-turbo", # 当前支持 "ernie-bot" 或 "ernie-bot-turbo", 更多的见官方文档。
+ "version_url": "", # 也可以不填写version,直接填写在千帆申请模型发布的API地址
"api_key": "",
"secret_key": "",
"provider": "QianFanWorker",
- }
-}
-
-# LLM 名称
-LLM_MODEL = "chatglm2-6b"
-
-# 历史对话轮数
-HISTORY_LEN = 3
-
-# LLM通用对话参数
-TEMPERATURE = 0.7
-# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
-
-
-# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
-LLM_DEVICE = "auto"
-
-# TextSplitter
-
-text_splitter_dict = {
- "ChineseRecursiveTextSplitter": {
- "source": "",
- "tokenizer_name_or_path": "",
},
- "SpacyTextSplitter": {
- "source": "huggingface",
- "tokenizer_name_or_path": "gpt2",
+ # 火山方舟 API,文档参考 https://www.volcengine.com/docs/82379
+ "fangzhou-api": {
+ "version": "chatglm-6b-model", # 当前支持 "chatglm-6b-model", 更多的见文档模型支持列表中方舟部分。
+ "version_url": "", # 可以不填写version,直接填写在方舟申请模型发布的API地址
+ "api_key": "",
+ "secret_key": "",
+ "provider": "FangZhouWorker",
},
- "RecursiveCharacterTextSplitter": {
- "source": "tiktoken",
- "tokenizer_name_or_path": "cl100k_base",
- },
-
- "MarkdownHeaderTextSplitter": {
- "headers_to_split_on":
- [
- ("#", "head1"),
- ("##", "head2"),
- ("###", "head3"),
- ("####", "head4"),
- ]
+ # 阿里云通义千问 API,文档参考 https://help.aliyun.com/zh/dashscope/developer-reference/api-details
+ "qwen-api": {
+ "version": "qwen-turbo", # 可选包括 "qwen-turbo", "qwen-plus"
+ "api_key": "", # 请在阿里云控制台模型服务灵积API-KEY管理页面创建
+ "provider": "QwenWorker",
},
}
-# TEXT_SPLITTER 名称
-TEXT_SPLITTER = "ChineseRecursiveTextSplitter"
-# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter)
-CHUNK_SIZE = 250
-
-# 知识库中相邻文本重合长度(不适用MarkdownHeaderTextSplitter)
-OVERLAP_SIZE = 0
-
-
-# 日志存储路径
-LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
-if not os.path.exists(LOG_PATH):
- os.mkdir(LOG_PATH)
-
-# 知识库默认存储路径
-KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
-if not os.path.exists(KB_ROOT_PATH):
- os.mkdir(KB_ROOT_PATH)
-# 数据库默认存储路径。
-# 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。
-DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
-SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"
-
-
-# 可选向量库类型及对应配置
-kbs_config = {
- "faiss": {
- },
- "milvus": {
- "host": "127.0.0.1",
- "port": "19530",
- "user": "",
- "password": "",
- "secure": False,
- },
- "pg": {
- "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
- }
-}
-
-# 默认向量库类型。可选:faiss, milvus, pg.
-DEFAULT_VS_TYPE = "faiss"
-
-# 缓存向量库数量
-CACHED_VS_NUM = 1
-
-# 知识库匹配向量数量
-VECTOR_SEARCH_TOP_K = 3
-
-# 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右
-SCORE_THRESHOLD = 1
-
-# 搜索引擎匹配结题数量
-SEARCH_ENGINE_TOP_K = 3
+# 通常情况下不需要更改以下内容
# nltk 模型存储路径
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
-# 基于本地知识问答的提示词模版(使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号
-PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 指令>
-<已知信息>{{ context }}已知信息>
+VLLM_MODEL_DICT = {
+ "aquila-7b":"BAAI/Aquila-7B",
+ "aquilachat-7b":"BAAI/AquilaChat-7B",
-<问题>{{ question }}问题>"""
+ "baichuan-7b": "baichuan-inc/Baichuan-7B",
+ "baichuan-13b": "baichuan-inc/Baichuan-13B",
+ 'baichuan-13b-chat':'baichuan-inc/Baichuan-13B-Chat',
+ # 注意:bloom系列的tokenizer与model是分离的,因此虽然vllm支持,但与fschat框架不兼容
+ # "bloom":"bigscience/bloom",
+ # "bloomz":"bigscience/bloomz",
+ # "bloomz-560m":"bigscience/bloomz-560m",
+ # "bloomz-7b1":"bigscience/bloomz-7b1",
+ # "bloomz-1b7":"bigscience/bloomz-1b7",
-# API 是否开启跨域,默认为False,如果需要开启,请设置为True
-# is open cross domain
-OPEN_CROSS_DOMAIN = False
+ "internlm-7b":"internlm/internlm-7b",
+ "internlm-chat-7b":"internlm/internlm-chat-7b",
+ "falcon-7b":"tiiuae/falcon-7b",
+ "falcon-40b":"tiiuae/falcon-40b",
+ "falcon-rw-7b":"tiiuae/falcon-rw-7b",
+ "gpt2":"gpt2",
+ "gpt2-xl":"gpt2-xl",
+ "gpt-j-6b":"EleutherAI/gpt-j-6b",
+ "gpt4all-j":"nomic-ai/gpt4all-j",
+ "gpt-neox-20b":"EleutherAI/gpt-neox-20b",
+ "pythia-12b":"EleutherAI/pythia-12b",
+ "oasst-sft-4-pythia-12b-epoch-3.5":"OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
+ "dolly-v2-12b":"databricks/dolly-v2-12b",
+ "stablelm-tuned-alpha-7b":"stabilityai/stablelm-tuned-alpha-7b",
+ "Llama-2-13b-hf":"meta-llama/Llama-2-13b-hf",
+ "Llama-2-70b-hf":"meta-llama/Llama-2-70b-hf",
+ "open_llama_13b":"openlm-research/open_llama_13b",
+ "vicuna-13b-v1.3":"lmsys/vicuna-13b-v1.3",
+ "koala":"young-geng/koala",
+ "mpt-7b":"mosaicml/mpt-7b",
+ "mpt-7b-storywriter":"mosaicml/mpt-7b-storywriter",
+ "mpt-30b":"mosaicml/mpt-30b",
+ "opt-66b":"facebook/opt-66b",
+ "opt-iml-max-30b":"facebook/opt-iml-max-30b",
-# Bing 搜索必备变量
-# 使用 Bing 搜索需要使用 Bing Subscription Key,需要在azure port中申请试用bing search
-# 具体申请方式请见
-# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource
-# 使用python创建bing api 搜索实例详见:
-# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python
-BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
-# 注意不是bing Webmaster Tools的api key,
+ "Qwen-7B":"Qwen/Qwen-7B",
+ "Qwen-14B":"Qwen/Qwen-14B",
+ "Qwen-7B-Chat":"Qwen/Qwen-7B-Chat",
+ "Qwen-14B-Chat":"Qwen/Qwen-14B-Chat",
-# 此外,如果是在服务器上,报Failed to establish a new connection: [Errno 110] Connection timed out
-# 是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG
-BING_SUBSCRIPTION_KEY = ""
-
-# 是否开启中文标题加强,以及标题增强的相关配置
-# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记;
-# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
-ZH_TITLE_ENHANCE = False
+}
\ No newline at end of file
diff --git a/configs/prompt_config.py.example b/configs/prompt_config.py.example
new file mode 100644
index 00000000..013f946f
--- /dev/null
+++ b/configs/prompt_config.py.example
@@ -0,0 +1,23 @@
+# prompt模板使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号
+# 本配置文件支持热加载,修改prompt模板后无需重启服务。
+
+
+# LLM对话支持的变量:
+# - input: 用户输入内容
+
+# 知识库和搜索引擎对话支持的变量:
+# - context: 从检索结果拼接的知识文本
+# - question: 用户提出的问题
+
+
+PROMPT_TEMPLATES = {
+ # LLM对话模板
+ "llm_chat": "{{ input }}",
+
+ # 基于本地知识问答的提示词模
+ "knowledge_base_chat": """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 指令>
+
+<已知信息>{{ context }}已知信息>
+
+<问题>{{ question }}问题>""",
+}
diff --git a/configs/server_config.py.example b/configs/server_config.py.example
index 51f53dc3..12687878 100644
--- a/configs/server_config.py.example
+++ b/configs/server_config.py.example
@@ -1,5 +1,5 @@
-from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE
-import httpx
+import sys
+from configs.model_config import LLM_DEVICE
# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
HTTPX_DEFAULT_TIMEOUT = 300.0
@@ -8,8 +8,8 @@ HTTPX_DEFAULT_TIMEOUT = 300.0
# is open cross domain
OPEN_CROSS_DOMAIN = False
-# 各服务器默认绑定host
-DEFAULT_BIND_HOST = "127.0.0.1"
+# 各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host
+DEFAULT_BIND_HOST = "0.0.0.0"
# webui.py server
WEBUI_SERVER = {
@@ -26,25 +26,27 @@ API_SERVER = {
# fastchat openai_api server
FSCHAT_OPENAI_API = {
"host": DEFAULT_BIND_HOST,
- "port": 8888, # model_config.llm_model_dict中模型配置的api_base_url需要与这里一致。
+ "port": 20000,
}
# fastchat model_worker server
-# 这些模型必须是在model_config.llm_model_dict中正确配置的。
+# 这些模型必须是在model_config.MODEL_PATH或ONLINE_MODEL中正确配置的。
# 在启动startup.py时,可用通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL
FSCHAT_MODEL_WORKERS = {
- # 所有模型共用的默认配置,可在模型专项配置或llm_model_dict中进行覆盖。
+ # 所有模型共用的默认配置,可在模型专项配置中进行覆盖。
"default": {
"host": DEFAULT_BIND_HOST,
"port": 20002,
"device": LLM_DEVICE,
+ # False,'vllm',使用的推理加速框架,使用vllm如果出现HuggingFace通信问题,参见doc/FAQ
+ "infer_turbo": "vllm" if sys.platform.startswith("linux") else False,
- # 多卡加载需要配置的参数
- # "gpus": None, # 使用的GPU,以str的格式指定,如"0,1"
+ # model_worker多卡加载需要配置的参数
+ # "gpus": None, # 使用的GPU,以str的格式指定,如"0,1",如失效请使用CUDA_VISIBLE_DEVICES="0,1"等形式指定
# "num_gpus": 1, # 使用GPU的数量
# "max_gpu_memory": "20GiB", # 每个GPU占用的最大显存
- # 以下为非常用参数,可根据需要配置
+ # 以下为model_worker非常用参数,可根据需要配置
# "load_8bit": False, # 开启8bit量化
# "cpu_offloading": None,
# "gptq_ckpt": None,
@@ -60,21 +62,55 @@ FSCHAT_MODEL_WORKERS = {
# "stream_interval": 2,
# "no_register": False,
# "embed_in_truncate": False,
+
+ # 以下为vllm_woker配置参数,注意使用vllm必须有gpu,仅在Linux测试通过
+
+ # tokenizer = model_path # 如果tokenizer与model_path不一致在此处添加
+ # 'tokenizer_mode':'auto',
+ # 'trust_remote_code':True,
+ # 'download_dir':None,
+ # 'load_format':'auto',
+ # 'dtype':'auto',
+ # 'seed':0,
+ # 'worker_use_ray':False,
+ # 'pipeline_parallel_size':1,
+ # 'tensor_parallel_size':1,
+ # 'block_size':16,
+ # 'swap_space':4 , # GiB
+ # 'gpu_memory_utilization':0.90,
+ # 'max_num_batched_tokens':2560,
+ # 'max_num_seqs':256,
+ # 'disable_log_stats':False,
+ # 'conv_template':None,
+ # 'limit_worker_concurrency':5,
+ # 'no_register':False,
+ # 'num_gpus': 1
+ # 'engine_use_ray': False,
+ # 'disable_log_requests': False
+
},
- "baichuan-7b": { # 使用default中的IP和端口
- "device": "cpu",
+ # 可以如下示例方式更改默认配置
+ # "baichuan-7b": { # 使用default中的IP和端口
+ # "device": "cpu",
+ # },
+
+ "zhipu-api": { # 请为每个要运行的在线API设置不同的端口
+ "port": 21001,
},
- "zhipu-api": { # 请为每个在线API设置不同的端口
- "port": 20003,
+ "minimax-api": {
+ "port": 21002,
},
- "minimax-api": { # 请为每个在线API设置不同的端口
- "port": 20004,
- },
- "xinghuo-api": { # 请为每个在线API设置不同的端口
- "port": 20005,
+ "xinghuo-api": {
+ "port": 21003,
},
"qianfan-api": {
- "port": 20006,
+ "port": 21004,
+ },
+ "fangzhou-api": {
+ "port": 21005,
+ },
+ "qwen-api": {
+ "port": 21006,
},
}
diff --git a/docs/FAQ.md b/docs/FAQ.md
index 490eb25c..ad6683b1 100644
--- a/docs/FAQ.md
+++ b/docs/FAQ.md
@@ -107,7 +107,7 @@ embedding_model_dict = {
Q9: 执行 `python cli_demo.py`过程中,显卡内存爆了,提示 "OutOfMemoryError: CUDA out of memory"
-A9: 将 `VECTOR_SEARCH_TOP_K` 和 `LLM_HISTORY_LEN` 的值调低,比如 `VECTOR_SEARCH_TOP_K = 5` 和 `LLM_HISTORY_LEN = 2`,这样由 `query` 和 `context` 拼接得到的 `prompt` 会变短,会减少内存的占用。或者打开量化,请在 [configs/model_config.py](../configs/model_config.py) 文件中,对`LOAD_IN_8BIT`参数进行修改
+A9: 将 `VECTOR_SEARCH_TOP_K` 和 `LLM_HISTORY_LEN` 的值调低,比如 `VECTOR_SEARCH_TOP_K = 5` 和 `LLM_HISTORY_LEN = 2`,这样由 `query` 和 `context` 拼接得到的 `prompt` 会变短,会减少内存的占用。或者打开量化,请在 [configs/model_config.py](../configs/model_config.py) 文件中,对 `LOAD_IN_8BIT`参数进行修改
---
@@ -171,7 +171,6 @@ Q14: 修改配置中路径后,加载 text2vec-large-chinese 依然提示 `WARN
A14: 尝试更换 embedding,如 text2vec-base-chinese,请在 [configs/model_config.py](../configs/model_config.py) 文件中,修改 `text2vec-base`参数为本地路径,绝对路径或者相对路径均可
-
---
Q15: 使用pg向量库建表报错
@@ -182,4 +181,43 @@ A15: 需要手动安装对应的vector扩展(连接pg执行 CREATE EXTENSION IF
Q16: pymilvus 连接超时
-A16.pymilvus版本需要匹配和milvus对应否则会超时参考pymilvus==2.1.3
\ No newline at end of file
+A16.pymilvus版本需要匹配和milvus对应否则会超时参考pymilvus==2.1.3
+
+Q16: 使用vllm推理加速框架时,已经下载了模型但出现HuggingFace通信问题
+
+A16: 参照如下代码修改python环境下/site-packages/vllm/model_executor/weight_utils.py文件的prepare_hf_model_weights函数如下对应代码:
+
+```python
+
+ if not is_local:
+ # Use file lock to prevent multiple processes from
+ # downloading the same model weights at the same time.
+ model_path_temp = os.path.join(
+ os.getenv("HOME"),
+ ".cache/huggingface/hub",
+ "models--" + model_name_or_path.replace("/", "--"),
+ "snapshots/",
+ )
+ downloaded = False
+ if os.path.exists(model_path_temp):
+ temp_last_dir = os.listdir(model_path_temp)[-1]
+ model_path_temp = os.path.join(model_path_temp, temp_last_dir)
+ base_pattern = os.path.join(model_path_temp, "pytorch_model*.bin")
+ files = glob.glob(base_pattern)
+ if len(files) > 0:
+ downloaded = True
+
+ if downloaded:
+ hf_folder = model_path_temp
+ else:
+ with get_lock(model_name_or_path, cache_dir):
+ hf_folder = snapshot_download(model_name_or_path,
+ allow_patterns=allow_patterns,
+ cache_dir=cache_dir,
+ tqdm_class=Disabledtqdm)
+ else:
+ hf_folder = model_name_or_path
+
+
+
+```
diff --git a/docs/自定义Agent.md b/docs/自定义Agent.md
new file mode 100644
index 00000000..940acf2f
--- /dev/null
+++ b/docs/自定义Agent.md
@@ -0,0 +1,80 @@
+## 自定义属于自己的Agent
+### 1. 创建自己的Agent工具
++ 开发者在```server/agent```文件中创建一个自己的文件,并将其添加到```tools.py```中。这样就完成了Tools的设定。
+
++ 当您创建了一个```custom_agent.py```文件,其中包含一个```work```函数,那么您需要在```tools.py```中添加如下代码:
+```python
+from custom_agent import work
+Tool.from_function(
+ func=work,
+ name="该函数的名字",
+ description=""
+ )
+```
++ 请注意,如果你确定在某一个工程中不会使用到某个工具,可以将其从Tools中移除,降低模型分类错误导致使用错误工具的风险。
+
+### 2. 修改 custom_template.py文件
+开发者需要根据自己选择的大模型设定适合该模型的Agent Prompt和自自定义返回格式。
+在我们的代码中,提供了默认的两种方式,一种是适配于GPT和Qwen的提示词:
+```python
+"""
+ Answer the following questions as best you can. You have access to the following tools:
+
+ {tools}
+ Use the following format:
+
+ Question: the input question you must answer
+ Thought: you should always think about what to do
+ Action: the action to take, should be one of [{tool_names}]
+ Action Input: the input to the action
+ Observation: the result of the action
+ ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
+ Thought: I now know the final answer
+ Final Answer: the final answer to the original input question
+
+ Begin!
+
+ history:
+ {history}
+
+ Question: {input}
+ Thought: {agent_scratchpad}
+"""
+```
+
+另一种是适配于GLM-130B的提示词:
+```python
+"""
+尽可能地回答以下问题。你可以使用以下工具:{tools}
+请按照以下格式进行:
+Question: 需要你回答的输入问题
+Thought: 你应该总是思考该做什么
+Action: 需要使用的工具,应该是[{tool_names}]中的一个
+Action Input: 传入工具的内容
+Observation: 行动的结果
+ ... (这个Thought/Action/Action Input/Observation可以重复N次)
+Thought: 我现在知道最后的答案
+Final Answer: 对原始输入问题的最终答案
+
+现在开始!
+
+之前的对话:
+{history}
+
+New question: {input}
+Thought: {agent_scratchpad}
+"""
+```
+
+### 3. 局限性
+1. 在我们的实验中,小于70B级别的模型,若不经过微调,很难达到较好的效果。因此,我们建议开发者使用大于70B级别的模型进行微调,以达到更好的效果。
+2. 由于Agent的脆弱性,temperture参数的设置对于模型的效果有很大的影响。我们建议开发者在使用自定义Agent时,对于不同的模型,将其设置成0.1以下,以达到更好的效果。
+3. 即使使用了大于70B级别的模型,开发者也应该在Prompt上进行深度优化,以让模型能成功的选择工具并完成任务。
+
+
+### 4. 我们已经支持的Agent
+我们为开发者编写了三个运用大模型执行的Agent,分别是:
+1. 翻译工具,实现对输入的任意语言翻译。
+2. 数学工具,使用LLMMathChain 实现数学计算。
+3. 天气工具,使用自定义的LLMWetherChain实现天气查询,调用和风天气API。
+4. 我们支持Langchain支持的Agent工具,在代码中,我们已经提供了Shell和Google Search两个工具的实现。
\ No newline at end of file
diff --git a/img/official_account.png b/img/official_account.png
new file mode 100644
index 00000000..8c0998b0
Binary files /dev/null and b/img/official_account.png differ
diff --git a/img/qr_code_61.jpg b/img/qr_code_61.jpg
deleted file mode 100644
index 3d7d163b..00000000
Binary files a/img/qr_code_61.jpg and /dev/null differ
diff --git a/img/qr_code_62.jpg b/img/qr_code_62.jpg
deleted file mode 100644
index e90b353c..00000000
Binary files a/img/qr_code_62.jpg and /dev/null differ
diff --git a/init_database.py b/init_database.py
index 42a18c53..9e807a8e 100644
--- a/init_database.py
+++ b/init_database.py
@@ -1,44 +1,92 @@
-from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, recreate_all_vs, list_kbs_from_folder
+from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files
from configs.model_config import NLTK_DATA_PATH
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
-from startup import dump_server_info
from datetime import datetime
+import sys
if __name__ == "__main__":
import argparse
- parser = argparse.ArgumentParser()
- parser.formatter_class = argparse.RawTextHelpFormatter
+ parser = argparse.ArgumentParser(description="please specify only one operate method once time.")
parser.add_argument(
+ "-r",
"--recreate-vs",
action="store_true",
help=('''
- recreate all vector store.
+ recreate vector store.
use this option if you have copied document files to the content folder, but vector store has not been populated or DEFAUL_VS_TYPE/EMBEDDING_MODEL changed.
- if your vector store is ready with the configs, just skip this option to fill info to database only.
'''
)
)
- args = parser.parse_args()
+ parser.add_argument(
+ "-u",
+ "--update-in-db",
+ action="store_true",
+ help=('''
+ update vector store for files exist in database.
+ use this option if you want to recreate vectors for files exist in db and skip files exist in local folder only.
+ '''
+ )
+ )
+ parser.add_argument(
+ "-i",
+ "--increament",
+ action="store_true",
+ help=('''
+ update vector store for files exist in local folder and not exist in database.
+ use this option if you want to create vectors increamentally.
+ '''
+ )
+ )
+ parser.add_argument(
+ "--prune-db",
+ action="store_true",
+ help=('''
+ delete docs in database that not existed in local folder.
+ it is used to delete database docs after user deleted some doc files in file browser
+ '''
+ )
+ )
+ parser.add_argument(
+ "--prune-folder",
+ action="store_true",
+ help=('''
+ delete doc files in local folder that not existed in database.
+ is is used to free local disk space by delete unused doc files.
+ '''
+ )
+ )
+ parser.add_argument(
+ "--kb-name",
+ type=str,
+ nargs="+",
+ default=[],
+ help=("specify knowledge base names to operate on. default is all folders exist in KB_ROOT_PATH.")
+ )
- dump_server_info()
-
- start_time = datetime.now()
-
- if args.recreate_vs:
- reset_tables()
- print("database talbes reseted")
- print("recreating all vector stores")
- recreate_all_vs()
+ if len(sys.argv) <= 1:
+ parser.print_help()
else:
- create_tables()
- print("database talbes created")
- print("filling kb infos to database")
- for kb in list_kbs_from_folder():
- folder2db(kb, "fill_info_only")
+ args = parser.parse_args()
+ start_time = datetime.now()
- end_time = datetime.now()
- print(f"总计用时: {end_time-start_time}")
+ create_tables() # confirm tables exist
+ if args.recreate_vs:
+ reset_tables()
+ print("database talbes reseted")
+ print("recreating all vector stores")
+ folder2db(kb_names=args.kb_name, mode="recreate_vs")
+ elif args.update_in_db:
+ folder2db(kb_names=args.kb_name, mode="update_in_db")
+ elif args.increament:
+ folder2db(kb_names=args.kb_name, mode="increament")
+ elif args.prune_db:
+ prune_db_docs(args.kb_name)
+ elif args.prune_folder:
+ prune_folder_files(args.kb_name)
+
+ end_time = datetime.now()
+ print(f"总计用时: {end_time-start_time}")
diff --git a/requirements.txt b/requirements.txt
index 5699d46d..a1d164b9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,12 @@
-langchain==0.0.287
-fschat[model_worker]==0.2.28
+langchain>=0.0.302
+fschat[model_worker]==0.2.29
openai
sentence_transformers
-transformers>=4.31.0
-torch~=2.0.0
-fastapi~=0.99.1
+transformers>=4.33.0
+torch>=2.0.1
+torchvision
+torchaudio
+fastapi>=0.103.1
nltk~=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0
@@ -23,6 +25,12 @@ pathlib
pytest
scikit-learn
numexpr
+vllm==0.1.7; sys_platform == "linux"
+# online api libs
+# zhipuai
+# dashscope>=1.10.0 # qwen
+# qianfan
+# volcengine>=1.0.106 # fangzhou
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
@@ -34,9 +42,13 @@ pandas~=2.0.3
streamlit>=1.26.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
-streamlit-chatbox >=1.1.6, <=1.1.7
+streamlit-chatbox>=1.1.9
streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
watchdog
tqdm
websockets
+tiktoken
+einops
+scipy
+transformers_stream_generator==0.0.4
diff --git a/requirements_api.txt b/requirements_api.txt
index c56c07bf..e195d74f 100644
--- a/requirements_api.txt
+++ b/requirements_api.txt
@@ -1,10 +1,12 @@
-langchain==0.0.287
-fschat[model_worker]==0.2.28
+langchain>=0.0.302
+fschat[model_worker]==0.2.29
openai
sentence_transformers
-transformers>=4.31.0
-torch~=2.0.0
-fastapi~=0.99.1
+transformers>=4.33.0
+torch >=2.0.1
+torchvision
+torchaudio
+fastapi>=0.103.1
nltk~=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0
@@ -17,12 +19,19 @@ accelerate
spacy
PyMuPDF==1.22.5
rapidocr_onnxruntime>=1.3.2
-
requests
pathlib
pytest
scikit-learn
numexpr
+vllm==0.1.7; sys_platform == "linux"
+
+
+# online api libs
+# zhipuai
+# dashscope>=1.10.0 # qwen
+# qianfan
+# volcengine>=1.0.106 # fangzhou
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
diff --git a/requirements_webui.txt b/requirements_webui.txt
index f67f968f..9caf085a 100644
--- a/requirements_webui.txt
+++ b/requirements_webui.txt
@@ -3,7 +3,7 @@ pandas~=2.0.3
streamlit>=1.26.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
-streamlit-chatbox >=1.1.6, <=1.1.7
+streamlit-chatbox>=1.1.9
streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
nltk
diff --git a/server/agent/callbacks.py b/server/agent/callbacks.py
new file mode 100644
index 00000000..394d2279
--- /dev/null
+++ b/server/agent/callbacks.py
@@ -0,0 +1,109 @@
+from uuid import UUID
+from langchain.callbacks import AsyncIteratorCallbackHandler
+import json
+import asyncio
+from typing import Any, Dict, List, Optional
+
+from langchain.schema import AgentFinish, AgentAction
+from langchain.schema.output import LLMResult
+
+
+def dumps(obj: Dict) -> str:
+ return json.dumps(obj, ensure_ascii=False)
+
+
+class Status:
+ start: int = 1
+ running: int = 2
+ complete: int = 3
+ agent_action: int = 4
+ agent_finish: int = 5
+ error: int = 6
+ make_tool: int = 7
+
+
+class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
+ def __init__(self):
+ super().__init__()
+ self.queue = asyncio.Queue()
+ self.done = asyncio.Event()
+ self.cur_tool = {}
+ self.out = True
+
+ async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
+ parent_run_id: UUID | None = None, tags: List[str] | None = None,
+ metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
+ self.cur_tool = {
+ "tool_name": serialized["name"],
+ "input_str": input_str,
+ "output_str": "",
+ "status": Status.agent_action,
+ "run_id": run_id.hex,
+ "llm_token": "",
+ "final_answer": "",
+ "error": "",
+ }
+ self.queue.put_nowait(dumps(self.cur_tool))
+
+ async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
+ tags: List[str] | None = None, **kwargs: Any) -> None:
+ self.out = True
+ self.cur_tool.update(
+ status=Status.agent_finish,
+ output_str=output.replace("Answer:", ""),
+ )
+ self.queue.put_nowait(dumps(self.cur_tool))
+
+ async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
+ parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
+ self.cur_tool.update(
+ status=Status.error,
+ error=str(error),
+ )
+ self.queue.put_nowait(dumps(self.cur_tool))
+
+ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+ if token:
+ if "Action" in token:
+ self.out = False
+ self.cur_tool.update(
+ status=Status.running,
+ llm_token="\n\n",
+ )
+ self.queue.put_nowait(dumps(self.cur_tool))
+ if self.out:
+ self.cur_tool.update(
+ status=Status.running,
+ llm_token=token,
+ )
+ self.queue.put_nowait(dumps(self.cur_tool))
+
+ async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
+ self.cur_tool.update(
+ status=Status.start,
+ llm_token="",
+ )
+ self.queue.put_nowait(dumps(self.cur_tool))
+
+ async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+ self.out = True
+ self.cur_tool.update(
+ status=Status.complete,
+ llm_token="",
+ )
+ self.queue.put_nowait(dumps(self.cur_tool))
+
+ async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
+ self.out = True
+ self.cur_tool.update(
+ status=Status.error,
+ error=str(error),
+ )
+ self.queue.put_nowait(dumps(self.cur_tool))
+
+ async def on_agent_finish(
+ self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
+ tags: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> None:
+ self.cur_tool = {}
diff --git a/server/agent/custom_template.py b/server/agent/custom_template.py
new file mode 100644
index 00000000..25697a6e
--- /dev/null
+++ b/server/agent/custom_template.py
@@ -0,0 +1,64 @@
+from langchain.agents import Tool, AgentOutputParser
+from langchain.prompts import StringPromptTemplate
+from typing import List, Union
+from langchain.schema import AgentAction, AgentFinish
+import re
+
+class CustomPromptTemplate(StringPromptTemplate):
+ # The template to use
+ template: str
+ # The list of tools available
+ tools: List[Tool]
+
+ def format(self, **kwargs) -> str:
+ # Get the intermediate steps (AgentAction, Observation tuples)
+ # Format them in a particular way
+ intermediate_steps = kwargs.pop("intermediate_steps")
+ thoughts = ""
+ for action, observation in intermediate_steps:
+ thoughts += action.log
+ thoughts += f"\nObservation: {observation}\nThought: "
+ # Set the agent_scratchpad variable to that value
+ kwargs["agent_scratchpad"] = thoughts
+ # Create a tools variable from the list of tools provided
+ kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
+ # Create a list of tool names for the tools provided
+ kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
+ return self.template.format(**kwargs)
+class CustomOutputParser(AgentOutputParser):
+
+ def parse(self, llm_output: str) -> AgentFinish | AgentAction | str:
+ # Check if agent should finish
+ if "Final Answer:" in llm_output:
+ return AgentFinish(
+ # Return values is generally always a dictionary with a single `output` key
+ # It is not recommended to try anything else at the moment :)
+ return_values={"output": llm_output.replace("Final Answer:", "").strip()},
+ log=llm_output,
+ )
+ # Parse out the action and action input
+ regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
+ match = re.search(regex, llm_output, re.DOTALL)
+ if not match:
+ return AgentFinish(
+ return_values={"output": f"调用agent失败: `{llm_output}`"},
+ log=llm_output,
+ )
+ action = match.group(1).strip()
+ action_input = match.group(2)
+ # Return the action and action input
+ try:
+ ans = AgentAction(
+ tool=action,
+ tool_input=action_input.strip(" ").strip('"'),
+ log=llm_output
+ )
+ return ans
+ except:
+ return AgentFinish(
+ return_values={"output": f"调用agent失败: `{llm_output}`"},
+ log=llm_output,
+ )
+
+
+
diff --git a/server/agent/google_search.py b/server/agent/google_search.py
new file mode 100644
index 00000000..979d478c
--- /dev/null
+++ b/server/agent/google_search.py
@@ -0,0 +1,8 @@
+import os
+os.environ["GOOGLE_CSE_ID"] = ""
+os.environ["GOOGLE_API_KEY"] = ""
+
+from langchain.tools import GoogleSearchResults
+def google_search(query: str):
+ tool = GoogleSearchResults()
+ return tool.run(tool_input=query)
\ No newline at end of file
diff --git a/server/agent/math.py b/server/agent/math.py
new file mode 100644
index 00000000..a00667af
--- /dev/null
+++ b/server/agent/math.py
@@ -0,0 +1,70 @@
+from langchain.prompts import PromptTemplate
+from langchain.chains import LLMMathChain
+from server.utils import wrap_done, get_ChatOpenAI
+from configs.model_config import LLM_MODEL, TEMPERATURE
+from langchain.chat_models import ChatOpenAI
+from langchain.callbacks.manager import CallbackManagerForToolRun
+
+_PROMPT_TEMPLATE = """将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。
+问题: ${{包含数学问题的问题。}}
+```text
+${{解决问题的单行数学表达式}}
+```
+...numexpr.evaluate(query)...
+```output
+${{运行代码的输出}}
+```
+答案: ${{答案}}
+
+这是两个例子:
+
+问题: 37593 * 67是多少?
+```text
+37593 * 67
+```
+...numexpr.evaluate("37593 * 67")...
+```output
+2518731
+
+答案: 2518731
+
+问题: 37593的五次方根是多少?
+```text
+37593**(1/5)
+```
+...numexpr.evaluate("37593**(1/5)")...
+```output
+8.222831614237718
+
+答案: 8.222831614237718
+
+
+问题: 2的平方是多少?
+```text
+2 ** 2
+```
+...numexpr.evaluate("2 ** 2")...
+```output
+4
+
+答案: 4
+
+
+现在,这是我的问题:
+问题: {question}
+"""
+PROMPT = PromptTemplate(
+ input_variables=["question"],
+ template=_PROMPT_TEMPLATE,
+)
+
+
+def calculate(query: str):
+ model = get_ChatOpenAI(
+ streaming=False,
+ model_name=LLM_MODEL,
+ temperature=TEMPERATURE,
+ )
+ llm_math = LLMMathChain.from_llm(model, verbose=True, prompt=PROMPT)
+ ans = llm_math.run(query)
+ return ans
diff --git a/server/agent/shell.py b/server/agent/shell.py
new file mode 100644
index 00000000..4dfee0bb
--- /dev/null
+++ b/server/agent/shell.py
@@ -0,0 +1,5 @@
+from langchain.tools import ShellTool
+def shell(query: str):
+ tool = ShellTool()
+ return tool.run(tool_input=query)
+
diff --git a/server/agent/tools.py b/server/agent/tools.py
new file mode 100644
index 00000000..7c6793f1
--- /dev/null
+++ b/server/agent/tools.py
@@ -0,0 +1,40 @@
+import sys
+import os
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+
+from server.agent.math import calculate
+from server.agent.translator import translate
+from server.agent.weather import weathercheck
+from server.agent.shell import shell
+from server.agent.google_search import google_search
+from langchain.agents import Tool
+
+tools = [
+ Tool.from_function(
+ func=calculate,
+ name="计算器工具",
+ description="进行简单的数学运算"
+ ),
+ Tool.from_function(
+ func=translate,
+ name="翻译工具",
+ description="翻译各种语言"
+ ),
+ Tool.from_function(
+ func=weathercheck,
+ name="天气查询工具",
+ description="查询天气",
+ ),
+ Tool.from_function(
+ func=shell,
+ name="shell工具",
+ description="使用命令行工具输出",
+ ),
+ Tool.from_function(
+ func=google_search,
+ name="谷歌搜索工具",
+ description="使用谷歌搜索",
+ )
+]
+tool_names = [tool.name for tool in tools]
diff --git a/server/agent/translator.py b/server/agent/translator.py
new file mode 100644
index 00000000..8466cf83
--- /dev/null
+++ b/server/agent/translator.py
@@ -0,0 +1,55 @@
+from langchain.prompts import PromptTemplate
+from langchain.chains import LLMChain
+import sys
+import os
+
+from server.utils import get_ChatOpenAI
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+from langchain.chains.llm_math.prompt import PROMPT
+from configs.model_config import LLM_MODEL,TEMPERATURE
+
+_PROMPT_TEMPLATE = '''
+# 指令
+接下来,作为一个专业的翻译专家,当我给出句子或段落时,你将提供通顺且具有可读性的对应语言的翻译。注意:
+1. 确保翻译结果流畅且易于理解
+2. 无论提供的是陈述句或疑问句,只进行翻译
+3. 不添加与原文无关的内容
+
+原文: ${{用户需要翻译的原文和目标语言}}
+{question}
+```output
+${{翻译结果}}
+```
+答案: ${{答案}}
+
+以下是两个例子
+问题: 翻译13成英语
+```text
+13 英语
+```output
+thirteen
+以下是两个例子
+问题: 翻译 我爱你 成法语
+```text
+13 法语
+```output
+Je t'aime.
+'''
+
+PROMPT = PromptTemplate(
+ input_variables=["question"],
+ template=_PROMPT_TEMPLATE,
+)
+
+
+def translate(query: str):
+ model = get_ChatOpenAI(
+ streaming=False,
+ model_name=LLM_MODEL,
+ temperature=TEMPERATURE,
+ )
+ llm_translate = LLMChain(llm=model, prompt=PROMPT)
+ ans = llm_translate.run(query)
+
+ return ans
diff --git a/server/agent/weather.py b/server/agent/weather.py
new file mode 100644
index 00000000..3e5a37bf
--- /dev/null
+++ b/server/agent/weather.py
@@ -0,0 +1,365 @@
+## 使用和风天气API查询天气
+from __future__ import annotations
+
+## 单独运行的时候需要添加
+import sys
+import os
+# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+
+
+from server.utils import get_ChatOpenAI
+
+
+import re
+import warnings
+from typing import Dict
+
+from langchain.callbacks.manager import (
+ AsyncCallbackManagerForChainRun,
+ CallbackManagerForChainRun,
+)
+from langchain.chains.base import Chain
+from langchain.chains.llm import LLMChain
+from langchain.pydantic_v1 import Extra, root_validator
+from langchain.schema import BasePromptTemplate
+from langchain.schema.language_model import BaseLanguageModel
+import requests
+from typing import List, Any, Optional
+from configs.model_config import LLM_MODEL, TEMPERATURE
+
+## 使用和风天气API查询天气
+KEY = ""
+
+def get_city_info(location, adm, key):
+ base_url = 'https://geoapi.qweather.com/v2/city/lookup?'
+ params = {'location': location, 'adm': adm, 'key': key}
+ response = requests.get(base_url, params=params)
+ data = response.json()
+ return data
+
+
+from datetime import datetime
+
+
+def format_weather_data(data):
+ hourly_forecast = data['hourly']
+ formatted_data = ''
+ for forecast in hourly_forecast:
+ # 将预报时间转换为datetime对象
+ forecast_time = datetime.strptime(forecast['fxTime'], '%Y-%m-%dT%H:%M%z')
+ # 获取预报时间的时区
+ forecast_tz = forecast_time.tzinfo
+ # 获取当前时间(使用预报时间的时区)
+ now = datetime.now(forecast_tz)
+ # 计算预报日期与当前日期的差值
+ days_diff = (forecast_time.date() - now.date()).days
+ if days_diff == 0:
+ forecast_date_str = '今天'
+ elif days_diff == 1:
+ forecast_date_str = '明天'
+ elif days_diff == 2:
+ forecast_date_str = '后天'
+ else:
+ forecast_date_str = str(days_diff) + '天后'
+ forecast_time_str = forecast_date_str + ' ' + forecast_time.strftime('%H:%M')
+ # 计算预报时间与当前时间的差值
+ time_diff = forecast_time - now
+ # 将差值转换为小时
+ hours_diff = time_diff.total_seconds() // 3600
+ if hours_diff < 1:
+ hours_diff_str = '1小时后'
+ elif hours_diff >= 24:
+ # 如果超过24小时,转换为天数
+ days_diff = hours_diff // 24
+ hours_diff_str = str(int(days_diff)) + '天后'
+ else:
+ hours_diff_str = str(int(hours_diff)) + '小时后'
+ # 将预报时间和当前时间的差值添加到输出中
+ formatted_data += '预报时间: ' + hours_diff_str + '\n'
+ formatted_data += '具体时间: ' + forecast_time_str + '\n'
+ formatted_data += '温度: ' + forecast['temp'] + '°C\n'
+ formatted_data += '天气: ' + forecast['text'] + '\n'
+ formatted_data += '风向: ' + forecast['windDir'] + '\n'
+ formatted_data += '风速: ' + forecast['windSpeed'] + '级\n'
+ formatted_data += '湿度: ' + forecast['humidity'] + '%\n'
+ formatted_data += '降水概率: ' + forecast['pop'] + '%\n'
+ # formatted_data += '降水量: ' + forecast['precip'] + 'mm\n'
+ formatted_data += '\n\n'
+ return formatted_data
+
+
+def get_weather(key, location_id, time: str = "24"):
+ if time:
+ url = "https://devapi.qweather.com/v7/weather/" + time + "h?"
+ else:
+ time = "3" # 免费订阅只能查看3天的天气
+ url = "https://devapi.qweather.com/v7/weather/" + time + "d?"
+ params = {
+ 'location': location_id,
+ 'key': key,
+ }
+ response = requests.get(url, params=params)
+ data = response.json()
+ return format_weather_data(data)
+
+
+def split_query(query):
+ parts = query.split()
+ location = parts[0] if parts[0] != 'None' else parts[1]
+ adm = parts[1]
+ time = parts[2]
+ return location, adm, time
+
+
+def weather(query):
+ location, adm, time = split_query(query)
+ key = KEY
+ if time != "None" and int(time) > 24:
+ return "只能查看24小时内的天气,无法回答"
+ if time == "None":
+ time = "24" # 免费的版本只能24小时内的天气
+ if key == "":
+ return "请先在代码中填入和风天气API Key"
+ city_info = get_city_info(location=location, adm=adm, key=key)
+ location_id = city_info['location'][0]['id']
+ weather_data = get_weather(key=key, location_id=location_id, time=time)
+ return weather_data
+
+
+class LLMWeatherChain(Chain):
+ llm_chain: LLMChain
+ llm: Optional[BaseLanguageModel] = None
+ """[Deprecated] LLM wrapper to use."""
+ prompt: BasePromptTemplate
+ """[Deprecated] Prompt to use to translate to python if necessary."""
+ input_key: str = "question" #: :meta private:
+ output_key: str = "answer" #: :meta private:
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+ arbitrary_types_allowed = True
+
+ @root_validator(pre=True)
+ def raise_deprecation(cls, values: Dict) -> Dict:
+ if "llm" in values:
+ warnings.warn(
+ "Directly instantiating an LLMWeatherChain with an llm is deprecated. "
+ "Please instantiate with llm_chain argument or using the from_llm "
+ "class method."
+ )
+ if "llm_chain" not in values and values["llm"] is not None:
+ prompt = values.get("prompt", PROMPT)
+ values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
+ return values
+
+ @property
+ def input_keys(self) -> List[str]:
+ """Expect input key.
+
+ :meta private:
+ """
+ return [self.input_key]
+
+ @property
+ def output_keys(self) -> List[str]:
+ """Expect output key.
+
+ :meta private:
+ """
+ return [self.output_key]
+
+ def _evaluate_expression(self, expression: str) -> str:
+ try:
+ output = weather(expression)
+ except Exception as e:
+ output = "输入的信息有误,请再次尝试"
+ # raise ValueError(f"错误: {expression},输入的信息不对")
+
+ return output
+
+ def _process_llm_result(
+ self, llm_output: str, run_manager: CallbackManagerForChainRun
+ ) -> Dict[str, str]:
+
+ run_manager.on_text(llm_output, color="green", verbose=self.verbose)
+
+ llm_output = llm_output.strip()
+ text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
+ if text_match:
+ expression = text_match.group(1)
+ output = self._evaluate_expression(expression)
+ run_manager.on_text("\nAnswer: ", verbose=self.verbose)
+ run_manager.on_text(output, color="yellow", verbose=self.verbose)
+ answer = "Answer: " + output
+ elif llm_output.startswith("Answer:"):
+ answer = llm_output
+ elif "Answer:" in llm_output:
+ answer = "Answer: " + llm_output.split("Answer:")[-1]
+ else:
+ raise ValueError(f"unknown format from LLM: {llm_output}")
+ return {self.output_key: answer}
+
+ async def _aprocess_llm_result(
+ self,
+ llm_output: str,
+ run_manager: AsyncCallbackManagerForChainRun,
+ ) -> Dict[str, str]:
+ await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
+ llm_output = llm_output.strip()
+ text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
+ if text_match:
+ expression = text_match.group(1)
+ output = self._evaluate_expression(expression)
+ await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
+ await run_manager.on_text(output, color="yellow", verbose=self.verbose)
+ answer = "Answer: " + output
+ elif llm_output.startswith("Answer:"):
+ answer = llm_output
+ elif "Answer:" in llm_output:
+ answer = "Answer: " + llm_output.split("Answer:")[-1]
+ else:
+ raise ValueError(f"unknown format from LLM: {llm_output}")
+ return {self.output_key: answer}
+
+ def _call(
+ self,
+ inputs: Dict[str, str],
+ run_manager: Optional[CallbackManagerForChainRun] = None,
+ ) -> Dict[str, str]:
+ _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
+ _run_manager.on_text(inputs[self.input_key])
+ llm_output = self.llm_chain.predict(
+ question=inputs[self.input_key],
+ stop=["```output"],
+ callbacks=_run_manager.get_child(),
+ )
+ return self._process_llm_result(llm_output, _run_manager)
+
+ async def _acall(
+ self,
+ inputs: Dict[str, str],
+ run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+ ) -> Dict[str, str]:
+ _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
+ await _run_manager.on_text(inputs[self.input_key])
+ llm_output = await self.llm_chain.apredict(
+ question=inputs[self.input_key],
+ stop=["```output"],
+ callbacks=_run_manager.get_child(),
+ )
+ return await self._aprocess_llm_result(llm_output, _run_manager)
+
+ @property
+ def _chain_type(self) -> str:
+ return "llm_weather_chain"
+
+ @classmethod
+ def from_llm(
+ cls,
+ llm: BaseLanguageModel,
+ prompt: BasePromptTemplate,
+ **kwargs: Any,
+ ) -> LLMWeatherChain:
+ llm_chain = LLMChain(llm=llm, prompt=prompt)
+ return cls(llm_chain=llm_chain, **kwargs)
+
+
+from langchain.prompts import PromptTemplate
+
+_PROMPT_TEMPLATE = """用户将会向您咨询天气问题,您不需要自己回答天气问题,而是将用户提问的信息提取出来区,市和时间三个元素后使用我为你编写好的工具进行查询并返回结果,格式为 区+市+时间 每个元素用空格隔开。如果缺少信息,则用 None 代替。
+问题: ${{用户的问题}}
+
+```text
+
+${{拆分的区,市和时间}}
+```
+
+... weather(提取后的关键字,用空格隔开)...
+```output
+
+${{提取后的答案}}
+```
+答案: ${{答案}}
+
+这是两个例子:
+问题: 上海浦东未来1小时天气情况?
+
+```text
+浦东 上海 1
+```
+...weather(浦东 上海 1)...
+
+```output
+
+预报时间: 1小时后
+具体时间: 今天 18:00
+温度: 24°C
+天气: 多云
+风向: 西南风
+风速: 7级
+湿度: 88%
+降水概率: 16%
+
+Answer:
+预报时间: 1小时后
+具体时间: 今天 18:00
+温度: 24°C
+天气: 多云
+风向: 西南风
+风速: 7级
+湿度: 88%
+降水概率: 16%
+
+问题: 北京市朝阳区未来24小时天气如何?
+```text
+
+朝阳 北京 24
+```
+...weather(朝阳 北京 24)...
+```output
+预报时间: 23小时后
+具体时间: 明天 17:00
+温度: 26°C
+天气: 霾
+风向: 西南风
+风速: 11级
+湿度: 65%
+降水概率: 20%
+Answer:
+预报时间: 23小时后
+具体时间: 明天 17:00
+温度: 26°C
+天气: 霾
+风向: 西南风
+风速: 11级
+湿度: 65%
+降水概率: 20%
+
+现在,这是我的问题:
+问题: {question}
+"""
+PROMPT = PromptTemplate(
+ input_variables=["question"],
+ template=_PROMPT_TEMPLATE,
+)
+
+
+def weathercheck(query: str):
+ model = get_ChatOpenAI(
+ streaming=False,
+ model_name=LLM_MODEL,
+ temperature=TEMPERATURE,
+ )
+ llm_weather = LLMWeatherChain.from_llm(model, verbose=True, prompt=PROMPT)
+ ans = llm_weather.run(query)
+ return ans
+
+if __name__ == '__main__':
+
+ ## 检测api是否能正确返回
+ query = "上海浦东未来1小时天气情况"
+ # ans = weathercheck(query)
+ ans = weather("浦东 上海 1")
+ print(ans)
\ No newline at end of file
diff --git a/server/api.py b/server/api.py
index 357a0678..ea098d68 100644
--- a/server/api.py
+++ b/server/api.py
@@ -12,12 +12,12 @@ import uvicorn
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
- search_engine_chat)
+ search_engine_chat, agent_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
update_docs, download_doc, recreate_vector_store,
search_docs, DocumentWithScore)
-from server.llm_api import list_llm_models, change_llm_model, stop_llm_model
+from server.llm_api import list_running_models, list_config_models, change_llm_model, stop_llm_model
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
from typing import List
@@ -67,6 +67,10 @@ def create_app():
tags=["Chat"],
summary="与搜索引擎对话")(search_engine_chat)
+ app.post("/chat/agent_chat",
+ tags=["Chat"],
+ summary="与agent对话")(agent_chat)
+
# Tag: Knowledge Base Management
app.get("/knowledge_base/list_knowledge_bases",
tags=["Knowledge Base Management"],
@@ -125,20 +129,25 @@ def create_app():
)(recreate_vector_store)
# LLM模型相关接口
- app.post("/llm_model/list_models",
- tags=["LLM Model Management"],
- summary="列出当前已加载的模型",
- )(list_llm_models)
+ app.post("/llm_model/list_running_models",
+ tags=["LLM Model Management"],
+ summary="列出当前已加载的模型",
+ )(list_running_models)
+
+ app.post("/llm_model/list_config_models",
+ tags=["LLM Model Management"],
+ summary="列出configs已配置的模型",
+ )(list_config_models)
app.post("/llm_model/stop",
- tags=["LLM Model Management"],
- summary="停止指定的LLM模型(Model Worker)",
- )(stop_llm_model)
+ tags=["LLM Model Management"],
+ summary="停止指定的LLM模型(Model Worker)",
+ )(stop_llm_model)
app.post("/llm_model/change",
- tags=["LLM Model Management"],
- summary="切换指定的LLM模型(Model Worker)",
- )(change_llm_model)
+ tags=["LLM Model Management"],
+ summary="切换指定的LLM模型(Model Worker)",
+ )(change_llm_model)
return app
diff --git a/server/chat/__init__.py b/server/chat/__init__.py
index 136bad64..62fe430c 100644
--- a/server/chat/__init__.py
+++ b/server/chat/__init__.py
@@ -2,3 +2,4 @@ from .chat import chat
from .knowledge_base_chat import knowledge_base_chat
from .openai_chat import openai_chat
from .search_engine_chat import search_engine_chat
+from .agent_chat import agent_chat
\ No newline at end of file
diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py
new file mode 100644
index 00000000..c08ae81c
--- /dev/null
+++ b/server/chat/agent_chat.py
@@ -0,0 +1,126 @@
+from langchain.memory import ConversationBufferWindowMemory
+from server.agent.tools import tools, tool_names
+from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status, dumps
+from langchain.agents import AgentExecutor, LLMSingleActionAgent
+from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
+from fastapi import Body
+from fastapi.responses import StreamingResponse
+from configs.model_config import LLM_MODEL, TEMPERATURE, HISTORY_LEN
+from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
+from langchain.chains import LLMChain
+from typing import AsyncIterable, Optional
+import asyncio
+from typing import List
+from server.chat.utils import History
+import json
+
+
+async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
+ history: List[History] = Body([],
+ description="历史对话",
+ examples=[[
+ {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
+ {"role": "assistant", "content": "虎头虎脑"}]]
+ ),
+ stream: bool = Body(False, description="流式输出"),
+ model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
+ temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+ prompt_name: str = Body("agent_chat",
+ description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
+ # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
+ ):
+ history = [History.from_data(h) for h in history]
+
+ async def agent_chat_iterator(
+ query: str,
+ history: Optional[List[History]],
+ model_name: str = LLM_MODEL,
+ prompt_name: str = prompt_name,
+ ) -> AsyncIterable[str]:
+ callback = CustomAsyncIteratorCallbackHandler()
+ model = get_ChatOpenAI(
+ model_name=model_name,
+ temperature=temperature,
+ )
+
+ prompt_template = CustomPromptTemplate(
+ template=get_prompt_template(prompt_name),
+ tools=tools,
+ input_variables=["input", "intermediate_steps", "history"]
+ )
+ output_parser = CustomOutputParser()
+ llm_chain = LLMChain(llm=model, prompt=prompt_template)
+ agent = LLMSingleActionAgent(
+ llm_chain=llm_chain,
+ output_parser=output_parser,
+ stop=["Observation:", "Observation:\n", "<|im_end|>"], # Qwen模型中使用这个
+ # stop=["Observation:", "Observation:\n"], # 其他模型,注意模板
+ allowed_tools=tool_names,
+ )
+ # 把history转成agent的memory
+ memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
+
+ for message in history:
+ # 检查消息的角色
+ if message.role == 'user':
+ # 添加用户消息
+ memory.chat_memory.add_user_message(message.content)
+ else:
+ # 添加AI消息
+ memory.chat_memory.add_ai_message(message.content)
+ agent_executor = AgentExecutor.from_agent_and_tools(agent=agent,
+ tools=tools,
+ verbose=True,
+ memory=memory,
+ )
+ input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
+ task = asyncio.create_task(wrap_done(
+ agent_executor.acall(query, callbacks=[callback], include_run_info=True),
+ callback.done),
+ )
+ if stream:
+ async for chunk in callback.aiter():
+ tools_use = []
+ # Use server-sent-events to stream the response
+ data = json.loads(chunk)
+ if data["status"] == Status.error:
+ tools_use.append("工具调用失败:\n" + data["error"])
+ yield json.dumps({"tools": tools_use}, ensure_ascii=False)
+ yield json.dumps({"answer": "(工具调用失败,请查看工具栏报错) \n\n"}, ensure_ascii=False)
+ if data["status"] == Status.start or data["status"] == Status.complete:
+ continue
+ if data["status"] == Status.agent_action:
+ yield json.dumps({"answer": "(正在使用工具,请注意工具栏变化) \n\n"}, ensure_ascii=False)
+ if data["status"] == Status.agent_finish:
+ tools_use.append("工具名称: " + data["tool_name"])
+ tools_use.append("工具输入: " + data["input_str"])
+ tools_use.append("工具输出: " + data["output_str"])
+ yield json.dumps({"tools": tools_use}, ensure_ascii=False)
+ yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)
+
+ else:
+ pass
+ # agent必须要steram=True,这部分暂时没有完成
+ # result = []
+ # async for chunk in callback.aiter():
+ # data = json.loads(chunk)
+ # status = data["status"]
+ # if status == Status.start:
+ # result.append(chunk)
+ # elif status == Status.running:
+ # result[-1]["llm_token"] += chunk["llm_token"]
+ # elif status == Status.complete:
+ # result[-1]["status"] = Status.complete
+ # elif status == Status.agent_finish:
+ # result.append(chunk)
+ # elif status == Status.agent_finish:
+ # pass
+ # yield dumps(result)
+
+ await task
+
+ return StreamingResponse(agent_chat_iterator(query=query,
+ history=history,
+ model_name=model_name,
+ prompt_name=prompt_name),
+ media_type="text/event-stream")
diff --git a/server/chat/chat.py b/server/chat/chat.py
index c025c3c2..6d3c9ce5 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -1,15 +1,15 @@
from fastapi import Body
from fastapi.responses import StreamingResponse
-from configs.model_config import llm_model_dict, LLM_MODEL, TEMPERATURE
-from server.chat.utils import wrap_done
-from langchain.chat_models import ChatOpenAI
-from langchain import LLMChain
+from configs import LLM_MODEL, TEMPERATURE
+from server.utils import wrap_done, get_ChatOpenAI
+from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from typing import List
from server.chat.utils import History
+from server.utils import get_prompt_template
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
@@ -21,29 +21,26 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
- temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
+ temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
+ prompt_name: str = Body("llm_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
history = [History.from_data(h) for h in history]
async def chat_iterator(query: str,
history: List[History] = [],
model_name: str = LLM_MODEL,
+ prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
-
- model = ChatOpenAI(
- streaming=True,
- verbose=True,
- callbacks=[callback],
- openai_api_key=llm_model_dict[model_name]["api_key"],
- openai_api_base=llm_model_dict[model_name]["api_base_url"],
+ model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
- openai_proxy=llm_model_dict[model_name].get("openai_proxy")
+ callbacks=[callback],
)
- input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
+ prompt_template = get_prompt_template(prompt_name)
+ input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)
@@ -66,5 +63,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
await task
- return StreamingResponse(chat_iterator(query, history, model_name),
+ return StreamingResponse(chat_iterator(query=query,
+ history=history,
+ model_name=model_name,
+ prompt_name=prompt_name),
media_type="text/event-stream")
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index b26f2cbc..9c70ee59 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -1,12 +1,9 @@
from fastapi import Body, Request
from fastapi.responses import StreamingResponse
-from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
- VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
- TEMPERATURE)
-from server.chat.utils import wrap_done
-from server.utils import BaseResponse
-from langchain.chat_models import ChatOpenAI
-from langchain import LLMChain
+from configs import (LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
+from server.utils import wrap_done, get_ChatOpenAI
+from server.utils import BaseResponse, get_prompt_template
+from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional
import asyncio
@@ -33,7 +30,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
- temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
+ temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+ prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
request: Request = None,
):
@@ -44,27 +42,22 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
history = [History.from_data(h) for h in history]
async def knowledge_base_chat_iterator(query: str,
- kb: KBService,
top_k: int,
history: Optional[List[History]],
model_name: str = LLM_MODEL,
+ prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
-
- model = ChatOpenAI(
- streaming=True,
- verbose=True,
- callbacks=[callback],
- openai_api_key=llm_model_dict[model_name]["api_key"],
- openai_api_base=llm_model_dict[model_name]["api_base_url"],
+ model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
- openai_proxy=llm_model_dict[model_name].get("openai_proxy")
+ callbacks=[callback],
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
context = "\n".join([doc.page_content for doc in docs])
- input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
+ prompt_template = get_prompt_template(prompt_name)
+ input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
@@ -102,5 +95,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
await task
- return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history, model_name),
+ return StreamingResponse(knowledge_base_chat_iterator(query=query,
+ top_k=top_k,
+ history=history,
+ model_name=model_name,
+ prompt_name=prompt_name),
media_type="text/event-stream")
diff --git a/server/chat/openai_chat.py b/server/chat/openai_chat.py
index 857ac979..4a46ddd9 100644
--- a/server/chat/openai_chat.py
+++ b/server/chat/openai_chat.py
@@ -1,7 +1,8 @@
from fastapi.responses import StreamingResponse
from typing import List
import openai
-from configs.model_config import llm_model_dict, LLM_MODEL, logger, log_verbose
+from configs import LLM_MODEL, logger, log_verbose
+from server.utils import get_model_worker_config, fschat_openai_api_address
from pydantic import BaseModel
@@ -23,9 +24,10 @@ class OpenAiChatMsgIn(BaseModel):
async def openai_chat(msg: OpenAiChatMsgIn):
- openai.api_key = llm_model_dict[LLM_MODEL]["api_key"]
+ config = get_model_worker_config(msg.model)
+ openai.api_key = config.get("api_key", "EMPTY")
print(f"{openai.api_key=}")
- openai.api_base = llm_model_dict[LLM_MODEL]["api_base_url"]
+ openai.api_base = config.get("api_base_url", fschat_openai_api_address())
print(f"{openai.api_base=}")
print(msg)
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index f8e4ebe9..00708b72 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -1,14 +1,12 @@
from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
-from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
+from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
+ LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE)
from fastapi import Body
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
-from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K,
- PROMPT_TEMPLATE, TEMPERATURE)
-from server.chat.utils import wrap_done
-from server.utils import BaseResponse
-from langchain.chat_models import ChatOpenAI
-from langchain import LLMChain
+from server.utils import wrap_done, get_ChatOpenAI
+from server.utils import BaseResponse, get_prompt_template
+from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
@@ -73,7 +71,8 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
- temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
+ temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+ prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -88,23 +87,20 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
top_k: int,
history: Optional[List[History]],
model_name: str = LLM_MODEL,
+ prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
- model = ChatOpenAI(
- streaming=True,
- verbose=True,
- callbacks=[callback],
- openai_api_key=llm_model_dict[model_name]["api_key"],
- openai_api_base=llm_model_dict[model_name]["api_base_url"],
+ model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
- openai_proxy=llm_model_dict[model_name].get("openai_proxy")
+ callbacks=[callback],
)
docs = await lookup_search_engine(query, search_engine_name, top_k)
context = "\n".join([doc.page_content for doc in docs])
- input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
+ prompt_template = get_prompt_template(prompt_name)
+ input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
@@ -135,5 +131,10 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
ensure_ascii=False)
await task
- return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history, model_name),
+ return StreamingResponse(search_engine_chat_iterator(query=query,
+ search_engine_name=search_engine_name,
+ top_k=top_k,
+ history=history,
+ model_name=model_name,
+ prompt_name=prompt_name),
media_type="text/event-stream")
diff --git a/server/chat/utils.py b/server/chat/utils.py
index a80648b1..dd3c3332 100644
--- a/server/chat/utils.py
+++ b/server/chat/utils.py
@@ -1,22 +1,7 @@
-import asyncio
-from typing import Awaitable, List, Tuple, Dict, Union
from pydantic import BaseModel, Field
from langchain.prompts.chat import ChatMessagePromptTemplate
from configs import logger, log_verbose
-
-
-async def wrap_done(fn: Awaitable, event: asyncio.Event):
- """Wrap an awaitable with a event to signal when it's done or an exception is raised."""
- try:
- await fn
- except Exception as e:
- # TODO: handle exception
- msg = f"Caught exception: {e}"
- logger.error(f'{e.__class__.__name__}: {msg}',
- exc_info=e if log_verbose else None)
- finally:
- # Signal the aiter to stop.
- event.set()
+from typing import List, Tuple, Dict, Union
class History(BaseModel):
diff --git a/server/db/base.py b/server/db/base.py
index 1d911c05..ae42ac09 100644
--- a/server/db/base.py
+++ b/server/db/base.py
@@ -2,7 +2,7 @@ from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
-from configs.model_config import SQLALCHEMY_DATABASE_URI
+from configs import SQLALCHEMY_DATABASE_URI
import json
diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py
index c7b703e9..f50d8a73 100644
--- a/server/knowledge_base/kb_api.py
+++ b/server/knowledge_base/kb_api.py
@@ -3,7 +3,7 @@ from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_base_repository import list_kbs_from_db
-from configs.model_config import EMBEDDING_MODEL, logger, log_verbose
+from configs import EMBEDDING_MODEL, logger, log_verbose
from fastapi import Body
diff --git a/server/knowledge_base/kb_cache/base.py b/server/knowledge_base/kb_cache/base.py
index f3e6d654..59426fa6 100644
--- a/server/knowledge_base/kb_cache/base.py
+++ b/server/knowledge_base/kb_cache/base.py
@@ -4,9 +4,9 @@ from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
import threading
-from configs.model_config import (CACHED_VS_NUM, EMBEDDING_MODEL, CHUNK_SIZE,
- embedding_model_dict, logger, log_verbose)
-from server.utils import embedding_device
+from configs import (EMBEDDING_MODEL, CHUNK_SIZE, CACHED_VS_NUM,
+ logger, log_verbose)
+from server.utils import embedding_device, get_model_path
from contextlib import contextmanager
from collections import OrderedDict
from typing import List, Any, Union, Tuple
@@ -22,7 +22,11 @@ class ThreadSafeObject:
def __repr__(self) -> str:
cls = type(self).__name__
- return f"<{cls}: key: {self._key}, obj: {self._obj}>"
+ return f"<{cls}: key: {self.key}, obj: {self._obj}>"
+
+ @property
+ def key(self):
+ return self._key
@contextmanager
def acquire(self, owner: str = "", msg: str = ""):
@@ -30,13 +34,13 @@ class ThreadSafeObject:
try:
self._lock.acquire()
if self._pool is not None:
- self._pool._cache.move_to_end(self._key)
+ self._pool._cache.move_to_end(self.key)
if log_verbose:
- logger.info(f"{owner} 开始操作:{self._key}。{msg}")
+ logger.info(f"{owner} 开始操作:{self.key}。{msg}")
yield self._obj
finally:
if log_verbose:
- logger.info(f"{owner} 结束操作:{self._key}。{msg}")
+ logger.info(f"{owner} 结束操作:{self.key}。{msg}")
self._lock.release()
def start_loading(self):
@@ -118,15 +122,24 @@ class EmbeddingsPool(CachePool):
with item.acquire(msg="初始化"):
self.atomic.release()
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
- embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
+ embeddings = OpenAIEmbeddings(openai_api_key=get_model_path(model), chunk_size=CHUNK_SIZE)
elif 'bge-' in model:
+ if 'zh' in model:
+ # for chinese model
+ query_instruction = "为这个句子生成表示以用于检索相关文章:"
+ elif 'en' in model:
+ # for english model
+ query_instruction = "Represent this sentence for searching relevant passages:"
+ else:
+ # maybe ReRanker or else, just use empty string instead
+ query_instruction = ""
embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
- model_kwargs={'device': device},
- query_instruction="为这个句子生成表示以用于检索相关文章:")
+ model_kwargs={'device': device},
+ query_instruction=query_instruction)
if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
embeddings.query_instruction = ""
else:
- embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
+ embeddings = HuggingFaceEmbeddings(model_name=get_model_path(model), model_kwargs={'device': device})
item.obj = embeddings
item.finish_loading()
else:
diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py
index 325c7bb1..801e4a6b 100644
--- a/server/knowledge_base/kb_cache/faiss_cache.py
+++ b/server/knowledge_base/kb_cache/faiss_cache.py
@@ -7,7 +7,7 @@ import os
class ThreadSafeFaiss(ThreadSafeObject):
def __repr__(self) -> str:
cls = type(self).__name__
- return f"<{cls}: key: {self._key}, obj: {self._obj}, docs_count: {self.docs_count()}>"
+ return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>"
def docs_count(self) -> int:
return len(self._obj.docstore._dict)
@@ -17,7 +17,7 @@ class ThreadSafeFaiss(ThreadSafeObject):
if not os.path.isdir(path) and create_path:
os.makedirs(path)
ret = self._obj.save_local(path)
- logger.info(f"已将向量库 {self._key} 保存到磁盘")
+ logger.info(f"已将向量库 {self.key} 保存到磁盘")
return ret
def clear(self):
@@ -27,7 +27,7 @@ class ThreadSafeFaiss(ThreadSafeObject):
if ids:
ret = self._obj.delete(ids)
assert len(self._obj.docstore._dict) == 0
- logger.info(f"已将向量库 {self._key} 清空")
+ logger.info(f"已将向量库 {self.key} 清空")
return ret
@@ -58,21 +58,22 @@ class _FaissPool(CachePool):
class KBFaissPool(_FaissPool):
def load_vector_store(
- self,
- kb_name: str,
- create: bool = True,
- embed_model: str = EMBEDDING_MODEL,
- embed_device: str = embedding_device(),
+ self,
+ kb_name: str,
+ vector_name: str = "vector_store",
+ create: bool = True,
+ embed_model: str = EMBEDDING_MODEL,
+ embed_device: str = embedding_device(),
) -> ThreadSafeFaiss:
self.atomic.acquire()
- cache = self.get(kb_name)
+ cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
if cache is None:
- item = ThreadSafeFaiss(kb_name, pool=self)
- self.set(kb_name, item)
+ item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
+ self.set((kb_name, vector_name), item)
with item.acquire(msg="初始化"):
self.atomic.release()
- logger.info(f"loading vector store in '{kb_name}' from disk.")
- vs_path = get_vs_path(kb_name)
+ logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
+ vs_path = get_vs_path(kb_name, vector_name)
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
@@ -89,7 +90,7 @@ class KBFaissPool(_FaissPool):
item.finish_loading()
else:
self.atomic.release()
- return self.get(kb_name)
+ return self.get((kb_name, vector_name))
class MemoFaissPool(_FaissPool):
@@ -144,7 +145,7 @@ if __name__ == "__main__":
if r == 3: # delete docs
logger.warning(f"清除 {vs_name} by {name}")
kb_faiss_pool.get(vs_name).clear()
-
+
threads = []
for n in range(1, 30):
t = threading.Thread(target=worker,
@@ -152,6 +153,6 @@ if __name__ == "__main__":
daemon=True)
t.start()
threads.append(t)
-
+
for t in threads:
t.join()
diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py
index 02ad222c..75d93305 100644
--- a/server/knowledge_base/kb_doc_api.py
+++ b/server/knowledge_base/kb_doc_api.py
@@ -1,10 +1,10 @@
import os
import urllib
from fastapi import File, Form, Body, Query, UploadFile
-from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
- VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
- CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
- logger, log_verbose,)
+from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
+ VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
+ CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
+ logger, log_verbose,)
from server.utils import BaseResponse, ListResponse, run_in_thread_pool
from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder,get_file_path,
files2docs_in_thread, KnowledgeFile)
@@ -122,10 +122,10 @@ def upload_docs(files: List[UploadFile] = File(..., description="上传文件,
knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
override: bool = Form(False, description="覆盖已有文件"),
to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
- chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
- chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
- zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
- docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
+ chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
+ chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
+ zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
+ docs: Json = Form({}, description="自定义的docs,需要转为json字符串", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
'''
@@ -205,12 +205,12 @@ def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
def update_docs(
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
- file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]),
+ file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]),
chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
- docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
+ docs: Json = Body({}, description="自定义的docs,需要转为json字符串", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
'''
@@ -323,6 +323,7 @@ def recreate_vector_store(
chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
+ not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
):
'''
recreate vector store from the content.
@@ -366,5 +367,7 @@ def recreate_vector_store(
"msg": msg,
})
i += 1
+ if not not_refresh_vs_cache:
+ kb.save_vector_store()
return StreamingResponse(output(), media_type="text/event-stream")
diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py
index c97f8cce..a725a78e 100644
--- a/server/knowledge_base/kb_service/base.py
+++ b/server/knowledge_base/kb_service/base.py
@@ -18,8 +18,8 @@ from server.db.repository.knowledge_file_repository import (
list_docs_from_db,
)
-from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
- EMBEDDING_MODEL)
+from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
+ EMBEDDING_MODEL)
from server.knowledge_base.utils import (
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
list_kbs_from_folder, list_files_from_folder,
diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py
index 6e20acf6..e37c93eb 100644
--- a/server/knowledge_base/kb_service/faiss_kb_service.py
+++ b/server/knowledge_base/kb_service/faiss_kb_service.py
@@ -1,7 +1,7 @@
import os
import shutil
-from configs.model_config import (
+from configs import (
KB_ROOT_PATH,
SCORE_THRESHOLD,
logger, log_verbose,
@@ -18,18 +18,21 @@ from server.utils import torch_gc
class FaissKBService(KBService):
vs_path: str
kb_path: str
+ vector_name: str = "vector_store"
def vs_type(self) -> str:
return SupportedVSType.FAISS
def get_vs_path(self):
- return os.path.join(self.get_kb_path(), "vector_store")
+ return os.path.join(self.get_kb_path(), self.vector_name)
def get_kb_path(self):
return os.path.join(KB_ROOT_PATH, self.kb_name)
def load_vector_store(self) -> ThreadSafeFaiss:
- return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, embed_model=self.embed_model)
+ return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
+ vector_name=self.vector_name,
+ embed_model=self.embed_model)
def save_vector_store(self):
self.load_vector_store().save(self.vs_path)
diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py
index 5ca425b5..80eac4e1 100644
--- a/server/knowledge_base/kb_service/milvus_kb_service.py
+++ b/server/knowledge_base/kb_service/milvus_kb_service.py
@@ -7,7 +7,7 @@ from langchain.schema import Document
from langchain.vectorstores import Milvus
from sklearn.preprocessing import normalize
-from configs.model_config import SCORE_THRESHOLD, kbs_config
+from configs import SCORE_THRESHOLD, kbs_config
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
score_threshold_process
@@ -22,9 +22,9 @@ class MilvusKBService(KBService):
from pymilvus import Collection
return Collection(milvus_name)
- def save_vector_store(self):
- if self.milvus.col:
- self.milvus.col.flush()
+ # def save_vector_store(self):
+ # if self.milvus.col:
+ # self.milvus.col.flush()
def get_doc_by_id(self, id: str) -> Optional[Document]:
if self.milvus.col:
diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py
index fa832ab5..9c17e80d 100644
--- a/server/knowledge_base/kb_service/pg_kb_service.py
+++ b/server/knowledge_base/kb_service/pg_kb_service.py
@@ -7,7 +7,7 @@ from langchain.vectorstores import PGVector
from langchain.vectorstores.pgvector import DistanceStrategy
from sqlalchemy import text
-from configs.model_config import EMBEDDING_DEVICE, kbs_config
+from configs import kbs_config
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
score_threshold_process
diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py
index b2073a10..2ca144af 100644
--- a/server/knowledge_base/migrate.py
+++ b/server/knowledge_base/migrate.py
@@ -1,9 +1,10 @@
-from configs.model_config import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
- logger, log_verbose)
+from configs import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
+ CHUNK_SIZE, OVERLAP_SIZE,
+ logger, log_verbose)
from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
list_files_from_folder,files2docs_in_thread,
KnowledgeFile,)
-from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType
+from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_file_repository import add_file_to_db
from server.db.base import Base, engine
import os
@@ -33,33 +34,23 @@ def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
def folder2db(
- kb_name: str,
- mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"],
+ kb_names: List[str],
+ mode: Literal["recreate_vs", "update_in_db", "increament"],
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
- chunk_size: int = -1,
- chunk_overlap: int = -1,
+ chunk_size: int = CHUNK_SIZE,
+ chunk_overlap: int = CHUNK_SIZE,
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
):
'''
use existed files in local folder to populate database and/or vector store.
set parameter `mode` to:
recreate_vs: recreate all vector store and fill info to database using existed files in local folder
- fill_info_only: do not create vector store, fill info to db using existed files only
+ fill_info_only(disabled): do not create vector store, fill info to db using existed files only
update_in_db: update vector store and database info using local files that existed in database only
increament: create vector store and database info for local files that not existed in database only
'''
- kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model)
- kb.create_kb()
-
- if mode == "recreate_vs":
- files_count = kb.count_files()
- print(f"知识库 {kb_name} 中共有 {files_count} 个文档。\n即将清除向量库。")
- kb.clear_vs()
- files_count = kb.count_files()
- print(f"清理后,知识库 {kb_name} 中共有 {files_count} 个文档。")
-
- kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
+ def files2vs(kb_name: str, kb_files: List[KnowledgeFile]):
for success, result in files2docs_in_thread(kb_files,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
@@ -68,84 +59,77 @@ def folder2db(
_, filename, docs = result
print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档")
kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
- kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
+ kb_file.splited_docs = docs
+ kb.add_doc(kb_file=kb_file, not_refresh_vs_cache=True)
else:
print(result)
- kb.save_vector_store()
- elif mode == "fill_info_only":
- files = list_files_from_folder(kb_name)
- kb_files = file_to_kbfile(kb_name, files)
- for kb_file in kb_files:
- add_file_to_db(kb_file)
- print(f"已将 {kb_name}/{kb_file.filename} 添加到数据库")
- elif mode == "update_in_db":
- files = kb.list_files()
- kb_files = file_to_kbfile(kb_name, files)
+ kb_names = kb_names or list_kbs_from_folder()
+ for kb_name in kb_names:
+ kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model)
+ kb.create_kb()
- for kb_file in kb_files:
- kb.update_doc(kb_file, not_refresh_vs_cache=True)
- kb.save_vector_store()
- elif mode == "increament":
- db_files = kb.list_files()
- folder_files = list_files_from_folder(kb_name)
- files = list(set(folder_files) - set(db_files))
- kb_files = file_to_kbfile(kb_name, files)
-
- for success, result in files2docs_in_thread(kb_files,
- chunk_size=chunk_size,
- chunk_overlap=chunk_overlap,
- zh_title_enhance=zh_title_enhance):
- if success:
- _, filename, docs = result
- print(f"正在将 {kb_name}/{filename} 添加到向量库")
- kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
- kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
- else:
- print(result)
- kb.save_vector_store()
- else:
- print(f"unspported migrate mode: {mode}")
+ # 清除向量库,从本地文件重建
+ if mode == "recreate_vs":
+ kb.clear_vs()
+ kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
+ files2vs(kb_name, kb_files)
+ kb.save_vector_store()
+ # # 不做文件内容的向量化,仅将文件元信息存到数据库
+ # # 由于现在数据库存了很多与文本切分相关的信息,单纯存储文件信息意义不大,该功能取消。
+ # elif mode == "fill_info_only":
+ # files = list_files_from_folder(kb_name)
+ # kb_files = file_to_kbfile(kb_name, files)
+ # for kb_file in kb_files:
+ # add_file_to_db(kb_file)
+ # print(f"已将 {kb_name}/{kb_file.filename} 添加到数据库")
+ # 以数据库中文件列表为基准,利用本地文件更新向量库
+ elif mode == "update_in_db":
+ files = kb.list_files()
+ kb_files = file_to_kbfile(kb_name, files)
+ files2vs(kb_name, kb_files)
+ kb.save_vector_store()
+ # 对比本地目录与数据库中的文件列表,进行增量向量化
+ elif mode == "increament":
+ db_files = kb.list_files()
+ folder_files = list_files_from_folder(kb_name)
+ files = list(set(folder_files) - set(db_files))
+ kb_files = file_to_kbfile(kb_name, files)
+ files2vs(kb_name, kb_files)
+ kb.save_vector_store()
+ else:
+ print(f"unspported migrate mode: {mode}")
-def recreate_all_vs(
- vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
- embed_mode: str = EMBEDDING_MODEL,
- **kwargs: Any,
-):
+def prune_db_docs(kb_names: List[str]):
'''
- used to recreate a vector store or change current vector store to another type or embed_model
+ delete docs in database that not existed in local folder.
+ it is used to delete database docs after user deleted some doc files in file browser
'''
- for kb_name in list_kbs_from_folder():
- folder2db(kb_name, "recreate_vs", vs_type, embed_mode, **kwargs)
+ for kb_name in kb_names:
+ kb = KBServiceFactory.get_service_by_name(kb_name)
+ if kb and kb.exists():
+ files_in_db = kb.list_files()
+ files_in_folder = list_files_from_folder(kb_name)
+ files = list(set(files_in_db) - set(files_in_folder))
+ kb_files = file_to_kbfile(kb_name, files)
+ for kb_file in kb_files:
+ kb.delete_doc(kb_file, not_refresh_vs_cache=True)
+ print(f"success to delete docs for file: {kb_name}/{kb_file.filename}")
+ kb.save_vector_store()
-def prune_db_files(kb_name: str):
- '''
- delete files in database that not existed in local folder.
- it is used to delete database files after user deleted some doc files in file browser
- '''
- kb = KBServiceFactory.get_service_by_name(kb_name)
- if kb.exists():
- files_in_db = kb.list_files()
- files_in_folder = list_files_from_folder(kb_name)
- files = list(set(files_in_db) - set(files_in_folder))
- kb_files = file_to_kbfile(kb_name, files)
- for kb_file in kb_files:
- kb.delete_doc(kb_file, not_refresh_vs_cache=True)
- kb.save_vector_store()
- return kb_files
-
-def prune_folder_files(kb_name: str):
+def prune_folder_files(kb_names: List[str]):
'''
delete doc files in local folder that not existed in database.
is is used to free local disk space by delete unused doc files.
'''
- kb = KBServiceFactory.get_service_by_name(kb_name)
- if kb.exists():
- files_in_db = kb.list_files()
- files_in_folder = list_files_from_folder(kb_name)
- files = list(set(files_in_folder) - set(files_in_db))
- for file in files:
- os.remove(get_file_path(kb_name, file))
- return files
+ for kb_name in kb_names:
+ kb = KBServiceFactory.get_service_by_name(kb_name)
+ if kb and kb.exists():
+ files_in_db = kb.list_files()
+ files_in_folder = list_files_from_folder(kb_name)
+ files = list(set(files_in_folder) - set(files_in_db))
+ for file in files:
+ os.remove(get_file_path(kb_name, file))
+ print(f"success to delete file: {kb_name}/{file}")
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index 8eeb5049..02212c88 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -2,18 +2,17 @@ import os
from transformers import AutoTokenizer
-from configs.model_config import (
+from configs import (
EMBEDDING_MODEL,
KB_ROOT_PATH,
CHUNK_SIZE,
OVERLAP_SIZE,
ZH_TITLE_ENHANCE,
- logger,
- log_verbose,
- text_splitter_dict,
- llm_model_dict,
- LLM_MODEL,
- TEXT_SPLITTER
+ logger,
+ log_verbose,
+ text_splitter_dict,
+ LLM_MODEL,
+ TEXT_SPLITTER_NAME,
)
import importlib
from text_splitter import zh_title_enhance as func_zh_title_enhance
@@ -23,7 +22,7 @@ from langchain.text_splitter import TextSplitter
from pathlib import Path
import json
from concurrent.futures import ThreadPoolExecutor
-from server.utils import run_in_thread_pool, embedding_device
+from server.utils import run_in_thread_pool, embedding_device, get_model_worker_config
import io
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
import chardet
@@ -44,8 +43,8 @@ def get_doc_path(knowledge_base_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "content")
-def get_vs_path(knowledge_base_name: str):
- return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
+def get_vs_path(knowledge_base_name: str, vector_name: str):
+ return os.path.join(get_kb_path(knowledge_base_name), vector_name)
def get_file_path(knowledge_base_name: str, doc_name: str):
@@ -190,9 +189,10 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
def make_text_splitter(
- splitter_name: str = TEXT_SPLITTER,
+ splitter_name: str = TEXT_SPLITTER_NAME,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
+ llm_model: str = LLM_MODEL,
):
"""
根据参数获取特定的分词器
@@ -228,8 +228,9 @@ def make_text_splitter(
)
elif text_splitter_dict[splitter_name]["source"] == "huggingface": ## 从huggingface加载
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
+ config = get_model_worker_config(llm_model)
text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
- llm_model_dict[LLM_MODEL]["local_model_path"]
+ config.get("model_path")
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
from transformers import GPT2TokenizerFast
@@ -281,7 +282,7 @@ class KnowledgeFile:
self.docs = None
self.splited_docs = None
self.document_loader_name = get_LoaderClass(self.ext)
- self.text_splitter_name = TEXT_SPLITTER
+ self.text_splitter_name = TEXT_SPLITTER_NAME
def file2docs(self, refresh: bool=False):
if self.docs is None or refresh:
@@ -372,18 +373,23 @@ def files2docs_in_thread(
kwargs_list = []
for i, file in enumerate(files):
kwargs = {}
- if isinstance(file, tuple) and len(file) >= 2:
- file = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
- elif isinstance(file, dict):
- filename = file.pop("filename")
- kb_name = file.pop("kb_name")
- kwargs = file
- file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
- kwargs["file"] = file
- kwargs["chunk_size"] = chunk_size
- kwargs["chunk_overlap"] = chunk_overlap
- kwargs["zh_title_enhance"] = zh_title_enhance
- kwargs_list.append(kwargs)
+ try:
+ if isinstance(file, tuple) and len(file) >= 2:
+ filename=file[0]
+ kb_name=file[1]
+ file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
+ elif isinstance(file, dict):
+ filename = file.pop("filename")
+ kb_name = file.pop("kb_name")
+ kwargs.update(file)
+ file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
+ kwargs["file"] = file
+ kwargs["chunk_size"] = chunk_size
+ kwargs["chunk_overlap"] = chunk_overlap
+ kwargs["zh_title_enhance"] = zh_title_enhance
+ kwargs_list.append(kwargs)
+ except Exception as e:
+ yield False, (kb_name, filename, str(e))
for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
yield result
@@ -398,4 +404,4 @@ if __name__ == "__main__":
pprint(docs[-1])
docs = kb_file.file2text()
- pprint(docs[-1])
\ No newline at end of file
+ pprint(docs[-1])
diff --git a/server/llm_api.py b/server/llm_api.py
index 5843e89a..b028747b 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -1,10 +1,10 @@
from fastapi import Body
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
-from server.utils import BaseResponse, fschat_controller_address
-import httpx
+from server.utils import BaseResponse, fschat_controller_address, list_llm_models, get_httpx_client
-def list_llm_models(
+
+def list_running_models(
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]),
placeholder: str = Body(None, description="该参数未使用,占位用"),
) -> BaseResponse:
@@ -13,8 +13,9 @@ def list_llm_models(
'''
try:
controller_address = controller_address or fschat_controller_address()
- r = httpx.post(controller_address + "/list_models")
- return BaseResponse(data=r.json()["models"])
+ with get_httpx_client() as client:
+ r = client.post(controller_address + "/list_models")
+ return BaseResponse(data=r.json()["models"])
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
@@ -24,6 +25,13 @@ def list_llm_models(
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
+def list_config_models() -> BaseResponse:
+ '''
+ 从本地获取configs中配置的模型列表
+ '''
+ return BaseResponse(data=list_llm_models())
+
+
def stop_llm_model(
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
@@ -34,11 +42,12 @@ def stop_llm_model(
'''
try:
controller_address = controller_address or fschat_controller_address()
- r = httpx.post(
- controller_address + "/release_worker",
- json={"model_name": model_name},
- )
- return r.json()
+ with get_httpx_client() as client:
+ r = client.post(
+ controller_address + "/release_worker",
+ json={"model_name": model_name},
+ )
+ return r.json()
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
@@ -57,12 +66,13 @@ def change_llm_model(
'''
try:
controller_address = controller_address or fschat_controller_address()
- r = httpx.post(
- controller_address + "/release_worker",
- json={"model_name": model_name, "new_model_name": new_model_name},
- timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
- )
- return r.json()
+ with get_httpx_client() as client:
+ r = client.post(
+ controller_address + "/release_worker",
+ json={"model_name": model_name, "new_model_name": new_model_name},
+ timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
+ )
+ return r.json()
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
diff --git a/server/model_workers/__init__.py b/server/model_workers/__init__.py
index a3a162f3..c1a824bd 100644
--- a/server/model_workers/__init__.py
+++ b/server/model_workers/__init__.py
@@ -2,3 +2,5 @@ from .zhipu import ChatGLMWorker
from .minimax import MiniMaxWorker
from .xinghuo import XingHuoWorker
from .qianfan import QianFanWorker
+from .fangzhou import FangZhouWorker
+from .qwen import QwenWorker
diff --git a/server/model_workers/base.py b/server/model_workers/base.py
index df5fbfcc..515c5db9 100644
--- a/server/model_workers/base.py
+++ b/server/model_workers/base.py
@@ -1,4 +1,4 @@
-from configs.model_config import LOG_PATH
+from configs.basic_config import LOG_PATH
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import BaseModelWorker
@@ -92,5 +92,5 @@ class ApiModelWorker(BaseModelWorker):
if content := msg[len(ai_start):].strip():
result.append({"role": ai_role, "content": content})
else:
- raise RuntimeError(f"unknow role in msg: {msg}")
+ raise RuntimeError(f"unknown role in msg: {msg}")
return result
diff --git a/server/model_workers/fangzhou.py b/server/model_workers/fangzhou.py
new file mode 100644
index 00000000..33a6b7da
--- /dev/null
+++ b/server/model_workers/fangzhou.py
@@ -0,0 +1,122 @@
+from server.model_workers.base import ApiModelWorker
+from configs.model_config import TEMPERATURE
+from fastchat import conversation as conv
+import sys
+import json
+from pprint import pprint
+from server.utils import get_model_worker_config
+from typing import List, Literal, Dict
+
+
+def request_volc_api(
+ messages: List[Dict],
+ model_name: str = "fangzhou-api",
+ version: str = "chatglm-6b-model",
+ temperature: float = TEMPERATURE,
+ api_key: str = None,
+ secret_key: str = None,
+):
+ from volcengine.maas import MaasService, MaasException, ChatRole
+
+ maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
+ config = get_model_worker_config(model_name)
+ version = version or config.get("version")
+ version_url = config.get("version_url")
+ api_key = api_key or config.get("api_key")
+ secret_key = secret_key or config.get("secret_key")
+
+ maas.set_ak(api_key)
+ maas.set_sk(secret_key)
+
+ # document: "https://www.volcengine.com/docs/82379/1099475"
+ req = {
+ "model": {
+ "name": version,
+ },
+ "parameters": {
+ # 这里的参数仅为示例,具体可用的参数请参考具体模型的 API 说明
+ "max_new_tokens": 1000,
+ "temperature": temperature,
+ },
+ "messages": messages,
+ }
+
+ try:
+ resps = maas.stream_chat(req)
+ for resp in resps:
+ yield resp
+ except MaasException as e:
+ print(e)
+
+
+class FangZhouWorker(ApiModelWorker):
+ """
+ 火山方舟
+ """
+ SUPPORT_MODELS = ["chatglm-6b-model"]
+
+ def __init__(
+ self,
+ *,
+ version: Literal["chatglm-6b-model"] = "chatglm-6b-model",
+ model_names: List[str] = ["fangzhou-api"],
+ controller_addr: str,
+ worker_addr: str,
+ **kwargs,
+ ):
+ kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
+ kwargs.setdefault("context_len", 16384) # TODO: 不同的模型有不同的大小
+
+ super().__init__(**kwargs)
+
+ config = self.get_config()
+ self.version = version
+ self.api_key = config.get("api_key")
+ self.secret_key = config.get("secret_key")
+
+ self.conv = conv.Conversation(
+ name=self.model_names[0],
+ system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
+ messages=[],
+ roles=["user", "assistant", "system"],
+ sep="\n### ",
+ stop_str="###",
+ )
+
+ def generate_stream_gate(self, params):
+ super().generate_stream_gate(params)
+
+ messages = self.prompt_to_messages(params["prompt"])
+ text = ""
+
+ for resp in request_volc_api(messages=messages,
+ model_name=self.model_names[0],
+ version=self.version,
+ temperature=params.get("temperature", TEMPERATURE),
+ ):
+ error = resp.error
+ if error.code_n > 0:
+ data = {"error_code": error.code_n, "text": error.message}
+ elif chunk := resp.choice.message.content:
+ text += chunk
+ data = {"error_code": 0, "text": text}
+ yield json.dumps(data, ensure_ascii=False).encode() + b"\0"
+
+ def get_embeddings(self, params):
+ # TODO: 支持embeddings
+ print("embedding")
+ print(params)
+
+
+if __name__ == "__main__":
+ import uvicorn
+ from server.utils import MakeFastAPIOffline
+ from fastchat.serve.model_worker import app
+
+ worker = FangZhouWorker(
+ controller_addr="http://127.0.0.1:20001",
+ worker_addr="http://127.0.0.1:21005",
+ )
+ sys.modules["fastchat.serve.model_worker"].worker = worker
+ MakeFastAPIOffline(app)
+ uvicorn.run(app, port=21005)
diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py
index c772c0dc..9079ea44 100644
--- a/server/model_workers/minimax.py
+++ b/server/model_workers/minimax.py
@@ -2,7 +2,7 @@ from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
import json
-import httpx
+from server.utils import get_httpx_client
from pprint import pprint
from typing import List, Dict
@@ -63,22 +63,23 @@ class MiniMaxWorker(ApiModelWorker):
}
print("request data sent to minimax:")
pprint(data)
- response = httpx.stream("POST",
- self.BASE_URL.format(pro=pro, group_id=group_id),
- headers=headers,
- json=data)
- with response as r:
- text = ""
- for e in r.iter_text():
- if e.startswith("data: "): # 真是优秀的返回
- data = json.loads(e[6:])
- if not data.get("usage"):
- if choices := data.get("choices"):
- chunk = choices[0].get("delta", "").strip()
- if chunk:
- print(chunk)
- text += chunk
- yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0"
+ with get_httpx_client() as client:
+ response = client.stream("POST",
+ self.BASE_URL.format(pro=pro, group_id=group_id),
+ headers=headers,
+ json=data)
+ with response as r:
+ text = ""
+ for e in r.iter_text():
+ if e.startswith("data: "): # 真是优秀的返回
+ data = json.loads(e[6:])
+ if not data.get("usage"):
+ if choices := data.get("choices"):
+ chunk = choices[0].get("delta", "").strip()
+ if chunk:
+ print(chunk)
+ text += chunk
+ yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0"
def get_embeddings(self, params):
# TODO: 支持embeddings
@@ -93,8 +94,8 @@ if __name__ == "__main__":
worker = MiniMaxWorker(
controller_addr="http://127.0.0.1:20001",
- worker_addr="http://127.0.0.1:20004",
+ worker_addr="http://127.0.0.1:21002",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
- uvicorn.run(app, port=20003)
+ uvicorn.run(app, port=21002)
diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py
index 8a593a7e..5eefd407 100644
--- a/server/model_workers/qianfan.py
+++ b/server/model_workers/qianfan.py
@@ -5,7 +5,7 @@ import sys
import json
import httpx
from cachetools import cached, TTLCache
-from server.utils import get_model_worker_config
+from server.utils import get_model_worker_config, get_httpx_client
from typing import List, Literal, Dict
@@ -54,7 +54,8 @@ def get_baidu_access_token(api_key: str, secret_key: str) -> str:
url = "https://aip.baidubce.com/oauth/2.0/token"
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
try:
- return httpx.get(url, params=params).json().get("access_token")
+ with get_httpx_client() as client:
+ return client.get(url, params=params).json().get("access_token")
except Exception as e:
print(f"failed to get token from baidu: {e}")
@@ -72,7 +73,10 @@ def request_qianfan_api(
version_url = config.get("version_url")
access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key"))
if not access_token:
- raise RuntimeError(f"failed to get access token. have you set the correct api_key and secret key?")
+ yield {
+ "error_code": 403,
+ "error_msg": f"failed to get access token. have you set the correct api_key and secret key?",
+ }
url = BASE_URL.format(
model_version=version_url or MODEL_VERSIONS[version],
@@ -88,14 +92,15 @@ def request_qianfan_api(
'Accept': 'application/json',
}
- with httpx.stream("POST", url, headers=headers, json=payload) as response:
- for line in response.iter_lines():
- if not line.strip():
- continue
- if line.startswith("data: "):
- line = line[6:]
- resp = json.loads(line)
- yield resp
+ with get_httpx_client() as client:
+ with client.stream("POST", url, headers=headers, json=payload) as response:
+ for line in response.iter_lines():
+ if not line.strip():
+ continue
+ if line.startswith("data: "):
+ line = line[6:]
+ resp = json.loads(line)
+ yield resp
class QianFanWorker(ApiModelWorker):
@@ -165,8 +170,8 @@ if __name__ == "__main__":
worker = QianFanWorker(
controller_addr="http://127.0.0.1:20001",
- worker_addr="http://127.0.0.1:20006",
+ worker_addr="http://127.0.0.1:21004"
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
- uvicorn.run(app, port=20006)
\ No newline at end of file
+ uvicorn.run(app, port=21004)
\ No newline at end of file
diff --git a/server/model_workers/qwen.py b/server/model_workers/qwen.py
new file mode 100644
index 00000000..32d87574
--- /dev/null
+++ b/server/model_workers/qwen.py
@@ -0,0 +1,123 @@
+import json
+import sys
+from configs import TEMPERATURE
+from http import HTTPStatus
+from typing import List, Literal, Dict
+
+from fastchat import conversation as conv
+
+from server.model_workers.base import ApiModelWorker
+from server.utils import get_model_worker_config
+
+
+def request_qwen_api(
+ messages: List[Dict[str, str]],
+ api_key: str = None,
+ version: str = "qwen-turbo",
+ temperature: float = TEMPERATURE,
+ model_name: str = "qwen-api",
+):
+ import dashscope
+
+ config = get_model_worker_config(model_name)
+ api_key = api_key or config.get("api_key")
+ version = version or config.get("version")
+
+ gen = dashscope.Generation()
+ responses = gen.call(
+ model=version,
+ temperature=temperature,
+ api_key=api_key,
+ messages=messages,
+ result_format='message', # set the result is message format.
+ stream=True,
+ )
+
+ text = ""
+ for resp in responses:
+ if resp.status_code != HTTPStatus.OK:
+ yield {
+ "code": resp.status_code,
+ "text": "api not response correctly",
+ }
+
+ if resp["status_code"] == 200:
+ if choices := resp["output"]["choices"]:
+ yield {
+ "code": 200,
+ "text": choices[0]["message"]["content"],
+ }
+ else:
+ yield {
+ "code": resp["status_code"],
+ "text": resp["message"],
+ }
+
+
+class QwenWorker(ApiModelWorker):
+ def __init__(
+ self,
+ *,
+ version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo",
+ model_names: List[str] = ["qwen-api"],
+ controller_addr: str,
+ worker_addr: str,
+ **kwargs,
+ ):
+ kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
+ kwargs.setdefault("context_len", 16384)
+ super().__init__(**kwargs)
+
+ # TODO: 确认模板是否需要修改
+ self.conv = conv.Conversation(
+ name=self.model_names[0],
+ system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
+ messages=[],
+ roles=["user", "assistant", "system"],
+ sep="\n### ",
+ stop_str="###",
+ )
+ config = self.get_config()
+ self.api_key = config.get("api_key")
+ self.version = version
+
+ def generate_stream_gate(self, params):
+ messages = self.prompt_to_messages(params["prompt"])
+
+ for resp in request_qwen_api(messages=messages,
+ api_key=self.api_key,
+ version=self.version,
+ temperature=params.get("temperature")):
+ if resp["code"] == 200:
+ yield json.dumps({
+ "error_code": 0,
+ "text": resp["text"]
+ },
+ ensure_ascii=False
+ ).encode() + b"\0"
+ else:
+ yield json.dumps({
+ "error_code": resp["code"],
+ "text": resp["text"]
+ },
+ ensure_ascii=False
+ ).encode() + b"\0"
+
+ def get_embeddings(self, params):
+ # TODO: 支持embeddings
+ print("embedding")
+ print(params)
+
+
+if __name__ == "__main__":
+ import uvicorn
+ from server.utils import MakeFastAPIOffline
+ from fastchat.serve.model_worker import app
+
+ worker = QwenWorker(
+ controller_addr="http://127.0.0.1:20001",
+ worker_addr="http://127.0.0.1:20007",
+ )
+ sys.modules["fastchat.serve.model_worker"].worker = worker
+ MakeFastAPIOffline(app)
+ uvicorn.run(app, port=20007)
diff --git a/server/model_workers/xinghuo.py b/server/model_workers/xinghuo.py
index 499e8bc1..bc98a9cf 100644
--- a/server/model_workers/xinghuo.py
+++ b/server/model_workers/xinghuo.py
@@ -94,8 +94,8 @@ if __name__ == "__main__":
worker = XingHuoWorker(
controller_addr="http://127.0.0.1:20001",
- worker_addr="http://127.0.0.1:20005",
+ worker_addr="http://127.0.0.1:21003",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
- uvicorn.run(app, port=20005)
+ uvicorn.run(app, port=21003)
diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py
index f835ac06..18cec5b5 100644
--- a/server/model_workers/zhipu.py
+++ b/server/model_workers/zhipu.py
@@ -67,8 +67,8 @@ if __name__ == "__main__":
worker = ChatGLMWorker(
controller_addr="http://127.0.0.1:20001",
- worker_addr="http://127.0.0.1:20003",
+ worker_addr="http://127.0.0.1:21001",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
- uvicorn.run(app, port=20003)
+ uvicorn.run(app, port=21001)
diff --git a/server/utils.py b/server/utils.py
index 516eaee3..b6a3945b 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -4,16 +4,57 @@ from typing import List
from fastapi import FastAPI
from pathlib import Path
import asyncio
-from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE, logger, log_verbose
-from configs.server_config import FSCHAT_MODEL_WORKERS
+from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
+ MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
+ logger, log_verbose,
+ FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Literal, Optional, Callable, Generator, Dict, Any
+from langchain.chat_models import ChatOpenAI
+import httpx
+from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
thread_pool = ThreadPoolExecutor(os.cpu_count())
+async def wrap_done(fn: Awaitable, event: asyncio.Event):
+ """Wrap an awaitable with a event to signal when it's done or an exception is raised."""
+ try:
+ await fn
+ except Exception as e:
+ # TODO: handle exception
+ msg = f"Caught exception: {e}"
+ logger.error(f'{e.__class__.__name__}: {msg}',
+ exc_info=e if log_verbose else None)
+ finally:
+ # Signal the aiter to stop.
+ event.set()
+
+
+def get_ChatOpenAI(
+ model_name: str,
+ temperature: float,
+ streaming: bool = True,
+ callbacks: List[Callable] = [],
+ verbose: bool = True,
+ **kwargs: Any,
+) -> ChatOpenAI:
+ config = get_model_worker_config(model_name)
+ model = ChatOpenAI(
+ streaming=streaming,
+ verbose=verbose,
+ callbacks=callbacks,
+ openai_api_key=config.get("api_key", "EMPTY"),
+ openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
+ model_name=model_name,
+ temperature=temperature,
+ openai_proxy=config.get("openai_proxy"),
+ **kwargs
+ )
+ return model
+
+
class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="API status code")
msg: str = pydantic.Field("success", description="API status message")
@@ -197,22 +238,71 @@ def MakeFastAPIOffline(
)
+# 从model_config中获取模型信息
+def list_embed_models() -> List[str]:
+ '''
+ get names of configured embedding models
+ '''
+ return list(MODEL_PATH["embed_model"])
+
+def list_llm_models() -> Dict[str, List[str]]:
+ '''
+ get names of configured llm models with different types.
+ return [(model_name, config_type), ...]
+ '''
+ workers = list(FSCHAT_MODEL_WORKERS)
+ if "default" in workers:
+ workers.remove("default")
+ return {
+ "local": list(MODEL_PATH["llm_model"]),
+ "online": list(ONLINE_LLM_MODEL),
+ "worker": workers,
+ }
+
+
+def get_model_path(model_name: str, type: str = None) -> Optional[str]:
+ if type in MODEL_PATH:
+ paths = MODEL_PATH[type]
+ else:
+ paths = {}
+ for v in MODEL_PATH.values():
+ paths.update(v)
+
+ if path_str := paths.get(model_name): # 以 "chatglm-6b": "THUDM/chatglm-6b-new" 为例,以下都是支持的路径
+ path = Path(path_str)
+ if path.is_dir(): # 任意绝对路径
+ return str(path)
+
+ root_path = Path(MODEL_ROOT_PATH)
+ if root_path.is_dir():
+ path = root_path / model_name
+ if path.is_dir(): # use key, {MODEL_ROOT_PATH}/chatglm-6b
+ return str(path)
+ path = root_path / path_str
+ if path.is_dir(): # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
+ return str(path)
+ path = root_path / path_str.split("/")[-1]
+ if path.is_dir(): # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
+ return str(path)
+ return path_str # THUDM/chatglm06b
+
+
# 从server_config中获取服务信息
-def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
+def get_model_worker_config(model_name: str = None) -> dict:
'''
加载model worker的配置项。
- 优先级:FSCHAT_MODEL_WORKERS[model_name] > llm_model_dict[model_name] > FSCHAT_MODEL_WORKERS["default"]
+ 优先级:FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
'''
+ from configs.model_config import ONLINE_LLM_MODEL
from configs.server_config import FSCHAT_MODEL_WORKERS
from server import model_workers
- from configs.model_config import llm_model_dict
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
- config.update(llm_model_dict.get(model_name, {}))
+ config.update(ONLINE_LLM_MODEL.get(model_name, {}))
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
- # 如果没有设置local_model_path,则认为是在线模型API
- if not config.get("local_model_path"):
+ # 在线模型API
+ if model_name in ONLINE_LLM_MODEL:
config["online_api"] = True
if provider := config.get("provider"):
try:
@@ -222,13 +312,14 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
- config["device"] = llm_device(config.get("device") or LLM_DEVICE)
+ config["model_path"] = get_model_path(model_name)
+ config["device"] = llm_device(config.get("device"))
return config
def get_all_model_worker_configs() -> dict:
result = {}
- model_names = set(llm_model_dict.keys()) | set(FSCHAT_MODEL_WORKERS.keys())
+ model_names = set(FSCHAT_MODEL_WORKERS.keys())
for name in model_names:
if name != "default":
result[name] = get_model_worker_config(name)
@@ -256,7 +347,7 @@ def fschat_openai_api_address() -> str:
host = FSCHAT_OPENAI_API["host"]
port = FSCHAT_OPENAI_API["port"]
- return f"http://{host}:{port}"
+ return f"http://{host}:{port}/v1"
def api_address() -> str:
@@ -275,19 +366,74 @@ def webui_address() -> str:
return f"http://{host}:{port}"
-def set_httpx_timeout(timeout: float = None):
+def get_prompt_template(name: str) -> Optional[str]:
'''
- 设置httpx默认timeout。
- httpx默认timeout是5秒,在请求LLM回答时不够用。
+ 从prompt_config中加载模板内容
+ '''
+ from configs import prompt_config
+ import importlib
+ importlib.reload(prompt_config) # TODO: 检查configs/prompt_config.py文件有修改再重新加载
+
+ return prompt_config.PROMPT_TEMPLATES.get(name)
+
+
+def set_httpx_config(
+ timeout: float = HTTPX_DEFAULT_TIMEOUT,
+ proxy: Union[str, Dict] = None,
+ ):
+ '''
+ 设置httpx默认timeout。httpx默认timeout是5秒,在请求LLM回答时不够用。
+ 将本项目相关服务加入无代理列表,避免fastchat的服务器请求错误。(windows下无效)
+ 对于chatgpt等在线API,如要使用代理需要手动配置。搜索引擎的代理如何处置还需考虑。
'''
import httpx
- from configs.server_config import HTTPX_DEFAULT_TIMEOUT
+ import os
- timeout = timeout or HTTPX_DEFAULT_TIMEOUT
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
+ # 在进程范围内设置系统级代理
+ proxies = {}
+ if isinstance(proxy, str):
+ for n in ["http", "https", "all"]:
+ proxies[n + "_proxy"] = proxy
+ elif isinstance(proxy, dict):
+ for n in ["http", "https", "all"]:
+ if p:= proxy.get(n):
+ proxies[n + "_proxy"] = p
+ elif p:= proxy.get(n + "_proxy"):
+ proxies[n + "_proxy"] = p
+
+ for k, v in proxies.items():
+ os.environ[k] = v
+
+ # set host to bypass proxy
+ no_proxy = [x.strip() for x in os.environ.get("no_proxy", "").split(",") if x.strip()]
+ no_proxy += [
+ # do not use proxy for locahost
+ "http://127.0.0.1",
+ "http://localhost",
+ ]
+ # do not use proxy for user deployed fastchat servers
+ for x in [
+ fschat_controller_address(),
+ fschat_model_worker_address(),
+ fschat_openai_api_address(),
+ ]:
+ host = ":".join(x.split(":")[:2])
+ if host not in no_proxy:
+ no_proxy.append(host)
+ os.environ["NO_PROXY"] = ",".join(no_proxy)
+
+ # TODO: 简单的清除系统代理不是个好的选择,影响太多。似乎修改代理服务器的bypass列表更好。
+ # patch requests to use custom proxies instead of system settings
+ # def _get_proxies():
+ # return {}
+
+ # import urllib.request
+ # urllib.request.getproxies = _get_proxies
+
# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
def detect_device() -> Literal["cuda", "mps", "cpu"]:
@@ -302,13 +448,15 @@ def detect_device() -> Literal["cuda", "mps", "cpu"]:
return "cpu"
-def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
+ device = device or LLM_DEVICE
if device not in ["cuda", "mps", "cpu"]:
device = detect_device()
return device
-def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
+ device = device or EMBEDDING_DEVICE
if device not in ["cuda", "mps", "cpu"]:
device = detect_device()
return device
@@ -333,3 +481,51 @@ def run_in_thread_pool(
for obj in as_completed(tasks):
yield obj.result()
+
+def get_httpx_client(
+ use_async: bool = False,
+ proxies: Union[str, Dict] = None,
+ timeout: float = HTTPX_DEFAULT_TIMEOUT,
+ **kwargs,
+) -> Union[httpx.Client, httpx.AsyncClient]:
+ '''
+ helper to get httpx client with default proxies that bypass local addesses.
+ '''
+ default_proxies = {
+ # do not use proxy for locahost
+ "all://127.0.0.1": None,
+ "all://localhost": None,
+ }
+ # do not use proxy for user deployed fastchat servers
+ for x in [
+ fschat_controller_address(),
+ fschat_model_worker_address(),
+ fschat_openai_api_address(),
+ ]:
+ host = ":".join(x.split(":")[:2])
+ default_proxies.update({host: None})
+
+ # get proxies from system envionrent
+ default_proxies.update({
+ "http://": os.environ.get("http_proxy"),
+ "https://": os.environ.get("https_proxy"),
+ "all://": os.environ.get("all_proxy"),
+ })
+ for host in os.environ.get("no_proxy", "").split(","):
+ if host := host.strip():
+ default_proxies.update({host: None})
+
+ # merge default proxies with user provided proxies
+ if isinstance(proxies, str):
+ proxies = {"all://": proxies}
+
+ if isinstance(proxies, dict):
+ default_proxies.update(proxies)
+
+ # construct Client
+ kwargs.update(timeout=timeout, proxies=default_proxies)
+ if use_async:
+ return httpx.AsyncClient(**kwargs)
+ else:
+ return httpx.Client(**kwargs)
+
diff --git a/startup.py b/startup.py
index b3094445..272c2f8f 100644
--- a/startup.py
+++ b/startup.py
@@ -17,12 +17,21 @@ except:
pass
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
- logger, log_verbose, TEXT_SPLITTER
-from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
- FSCHAT_OPENAI_API, HTTPX_DEFAULT_TIMEOUT)
+from configs import (
+ LOG_PATH,
+ log_verbose,
+ logger,
+ LLM_MODEL,
+ EMBEDDING_MODEL,
+ TEXT_SPLITTER_NAME,
+ FSCHAT_CONTROLLER,
+ FSCHAT_OPENAI_API,
+ API_SERVER,
+ WEBUI_SERVER,
+ HTTPX_DEFAULT_TIMEOUT,
+)
from server.utils import (fschat_controller_address, fschat_model_worker_address,
- fschat_openai_api_address, set_httpx_timeout,
+ fschat_openai_api_address, set_httpx_config, get_httpx_client,
get_model_worker_config, get_all_model_worker_configs,
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
import argparse
@@ -49,112 +58,162 @@ def create_controller_app(
def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
+ """
+ kwargs包含的字段如下:
+ host:
+ port:
+ model_names:[`model_name`]
+ controller_address:
+ worker_address:
+
+
+ 对于online_api:
+ online_api:True
+ worker_class: `provider`
+ 对于离线模型:
+ model_path: `model_name_or_path`,huggingface的repo-id或本地路径
+ device:`LLM_DEVICE`
+ """
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
- from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
+ from fastchat.serve.model_worker import worker_id, logger
import argparse
- import threading
- import fastchat.serve.model_worker
logger.setLevel(log_level)
- # workaround to make program exit with Ctrl+c
- # it should be deleted after pr is merged by fastchat
- def _new_init_heart_beat(self):
- self.register_to_controller()
- self.heart_beat_thread = threading.Thread(
- target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
- )
- self.heart_beat_thread.start()
-
- ModelWorker.init_heart_beat = _new_init_heart_beat
-
parser = argparse.ArgumentParser()
args = parser.parse_args([])
- # default args. should be deleted after pr is merged by fastchat
- args.gpus = None
- args.max_gpu_memory = "20GiB"
- args.load_8bit = False
- args.cpu_offloading = None
- args.gptq_ckpt = None
- args.gptq_wbits = 16
- args.gptq_groupsize = -1
- args.gptq_act_order = False
- args.awq_ckpt = None
- args.awq_wbits = 16
- args.awq_groupsize = -1
- args.num_gpus = 1
- args.model_names = []
- args.conv_template = None
- args.limit_worker_concurrency = 5
- args.stream_interval = 2
- args.no_register = False
- args.embed_in_truncate = False
for k, v in kwargs.items():
setattr(args, k, v)
- if args.gpus:
- if args.num_gpus is None:
- args.num_gpus = len(args.gpus.split(','))
- if len(args.gpus.split(",")) < args.num_gpus:
- raise ValueError(
- f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
- )
- os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
-
# 在线模型API
if worker_class := kwargs.get("worker_class"):
+ from fastchat.serve.model_worker import app
worker = worker_class(model_names=args.model_names,
controller_addr=args.controller_address,
worker_addr=args.worker_address)
+ sys.modules["fastchat.serve.model_worker"].worker = worker
# 本地模型
else:
- # workaround to make program exit with Ctrl+c
- # it should be deleted after pr is merged by fastchat
- def _new_init_heart_beat(self):
- self.register_to_controller()
- self.heart_beat_thread = threading.Thread(
- target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
+ from configs.model_config import VLLM_MODEL_DICT
+ if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
+ import fastchat.serve.vllm_worker
+ from fastchat.serve.vllm_worker import VLLMWorker,app
+ from vllm import AsyncLLMEngine
+ from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs
+ args.tokenizer = args.model_path # 如果tokenizer与model_path不一致在此处添加
+ args.tokenizer_mode = 'auto'
+ args.trust_remote_code= True
+ args.download_dir= None
+ args.load_format = 'auto'
+ args.dtype = 'auto'
+ args.seed = 0
+ args.worker_use_ray = False
+ args.pipeline_parallel_size = 1
+ args.tensor_parallel_size = 1
+ args.block_size = 16
+ args.swap_space = 4 # GiB
+ args.gpu_memory_utilization = 0.90
+ args.max_num_batched_tokens = 2560
+ args.max_num_seqs = 256
+ args.disable_log_stats = False
+ args.conv_template = None
+ args.limit_worker_concurrency = 5
+ args.no_register = False
+ args.num_gpus = 1 # vllm worker的切分是tensor并行,这里填写显卡的数量
+ args.engine_use_ray = False
+ args.disable_log_requests = False
+ if args.model_path:
+ args.model = args.model_path
+ if args.num_gpus > 1:
+ args.tensor_parallel_size = args.num_gpus
+
+ for k, v in kwargs.items():
+ setattr(args, k, v)
+
+ engine_args = AsyncEngineArgs.from_cli_args(args)
+ engine = AsyncLLMEngine.from_engine_args(engine_args)
+
+ worker = VLLMWorker(
+ controller_addr = args.controller_address,
+ worker_addr = args.worker_address,
+ worker_id = worker_id,
+ model_path = args.model_path,
+ model_names = args.model_names,
+ limit_worker_concurrency = args.limit_worker_concurrency,
+ no_register = args.no_register,
+ llm_engine = engine,
+ conv_template = args.conv_template,
+ )
+ sys.modules["fastchat.serve.vllm_worker"].engine = engine
+ sys.modules["fastchat.serve.vllm_worker"].worker = worker
+
+ else:
+ from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker
+ args.gpus = "0" # GPU的编号,如果有多个GPU,可以设置为"0,1,2,3"
+ args.max_gpu_memory = "20GiB"
+ args.num_gpus = 1 # model worker的切分是model并行,这里填写显卡的数量
+
+ args.load_8bit = False
+ args.cpu_offloading = None
+ args.gptq_ckpt = None
+ args.gptq_wbits = 16
+ args.gptq_groupsize = -1
+ args.gptq_act_order = False
+ args.awq_ckpt = None
+ args.awq_wbits = 16
+ args.awq_groupsize = -1
+ args.model_names = []
+ args.conv_template = None
+ args.limit_worker_concurrency = 5
+ args.stream_interval = 2
+ args.no_register = False
+ args.embed_in_truncate = False
+ for k, v in kwargs.items():
+ setattr(args, k, v)
+ if args.gpus:
+ if args.num_gpus is None:
+ args.num_gpus = len(args.gpus.split(','))
+ if len(args.gpus.split(",")) < args.num_gpus:
+ raise ValueError(
+ f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
+ )
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+ gptq_config = GptqConfig(
+ ckpt=args.gptq_ckpt or args.model_path,
+ wbits=args.gptq_wbits,
+ groupsize=args.gptq_groupsize,
+ act_order=args.gptq_act_order,
+ )
+ awq_config = AWQConfig(
+ ckpt=args.awq_ckpt or args.model_path,
+ wbits=args.awq_wbits,
+ groupsize=args.awq_groupsize,
)
- self.heart_beat_thread.start()
- ModelWorker.init_heart_beat = _new_init_heart_beat
+ worker = ModelWorker(
+ controller_addr=args.controller_address,
+ worker_addr=args.worker_address,
+ worker_id=worker_id,
+ model_path=args.model_path,
+ model_names=args.model_names,
+ limit_worker_concurrency=args.limit_worker_concurrency,
+ no_register=args.no_register,
+ device=args.device,
+ num_gpus=args.num_gpus,
+ max_gpu_memory=args.max_gpu_memory,
+ load_8bit=args.load_8bit,
+ cpu_offloading=args.cpu_offloading,
+ gptq_config=gptq_config,
+ awq_config=awq_config,
+ stream_interval=args.stream_interval,
+ conv_template=args.conv_template,
+ embed_in_truncate=args.embed_in_truncate,
+ )
+ sys.modules["fastchat.serve.model_worker"].args = args
+ sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
- gptq_config = GptqConfig(
- ckpt=args.gptq_ckpt or args.model_path,
- wbits=args.gptq_wbits,
- groupsize=args.gptq_groupsize,
- act_order=args.gptq_act_order,
- )
- awq_config = AWQConfig(
- ckpt=args.awq_ckpt or args.model_path,
- wbits=args.awq_wbits,
- groupsize=args.awq_groupsize,
- )
-
- worker = ModelWorker(
- controller_addr=args.controller_address,
- worker_addr=args.worker_address,
- worker_id=worker_id,
- model_path=args.model_path,
- model_names=args.model_names,
- limit_worker_concurrency=args.limit_worker_concurrency,
- no_register=args.no_register,
- device=args.device,
- num_gpus=args.num_gpus,
- max_gpu_memory=args.max_gpu_memory,
- load_8bit=args.load_8bit,
- cpu_offloading=args.cpu_offloading,
- gptq_config=gptq_config,
- awq_config=awq_config,
- stream_interval=args.stream_interval,
- conv_template=args.conv_template,
- embed_in_truncate=args.embed_in_truncate,
- )
- sys.modules["fastchat.serve.model_worker"].args = args
- sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
-
- sys.modules["fastchat.serve.model_worker"].worker = worker
+ sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({args.model_names[0]})"
@@ -194,7 +253,6 @@ def create_openai_api_app(
def _set_app_event(app: FastAPI, started_event: mp.Event = None):
@app.on_event("startup")
async def on_startup():
- set_httpx_timeout()
if started_event is not None:
started_event.set()
@@ -205,6 +263,8 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
from fastapi import Body
import time
import sys
+ from server.utils import set_httpx_config
+ set_httpx_config()
app = create_controller_app(
dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
@@ -216,7 +276,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
@app.post("/release_worker")
def release_worker(
model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
- # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[fschat_controller_address()]),
+ # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[FSCHAT_CONTROLLER_address()]),
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
@@ -242,15 +302,16 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
logger.error(msg)
return {"code": 500, "msg": msg}
- r = httpx.post(worker_address + "/release",
- json={"new_model_name": new_model_name, "keep_origin": keep_origin})
- if r.status_code != 200:
- msg = f"failed to release model: {model_name}"
- logger.error(msg)
- return {"code": 500, "msg": msg}
+ with get_httpx_client() as client:
+ r = client.post(worker_address + "/release",
+ json={"new_model_name": new_model_name, "keep_origin": keep_origin})
+ if r.status_code != 200:
+ msg = f"failed to release model: {model_name}"
+ logger.error(msg)
+ return {"code": 500, "msg": msg}
if new_model_name:
- timer = HTTPX_DEFAULT_TIMEOUT * 2 # wait for new model_worker register
+ timer = HTTPX_DEFAULT_TIMEOUT # wait for new model_worker register
while timer > 0:
models = app._controller.list_models()
if new_model_name in models:
@@ -290,6 +351,8 @@ def run_model_worker(
import uvicorn
from fastapi import Body
import sys
+ from server.utils import set_httpx_config
+ set_httpx_config()
kwargs = get_model_worker_config(model_name)
host = kwargs.pop("host")
@@ -297,7 +360,7 @@ def run_model_worker(
kwargs["model_names"] = [model_name]
kwargs["controller_address"] = controller_address or fschat_controller_address()
kwargs["worker_address"] = fschat_model_worker_address(model_name)
- model_path = kwargs.get("local_model_path", "")
+ model_path = kwargs.get("model_path", "")
kwargs["model_path"] = model_path
app = create_model_worker_app(log_level=log_level, **kwargs)
@@ -328,6 +391,8 @@ def run_model_worker(
def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
import uvicorn
import sys
+ from server.utils import set_httpx_config
+ set_httpx_config()
controller_addr = fschat_controller_address()
app = create_openai_api_app(controller_addr, log_level=log_level) # TODO: not support keys yet.
@@ -344,6 +409,8 @@ def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
def run_api_server(started_event: mp.Event = None):
from server.api import create_app
import uvicorn
+ from server.utils import set_httpx_config
+ set_httpx_config()
app = create_app()
_set_app_event(app, started_event)
@@ -355,6 +422,9 @@ def run_api_server(started_event: mp.Event = None):
def run_webui(started_event: mp.Event = None):
+ from server.utils import set_httpx_config
+ set_httpx_config()
+
host = WEBUI_SERVER["host"]
port = WEBUI_SERVER["port"]
@@ -418,7 +488,7 @@ def parse_args() -> argparse.ArgumentParser:
"-c",
"--controller",
type=str,
- help="specify controller address the worker is registered to. default is server_config.FSCHAT_CONTROLLER",
+ help="specify controller address the worker is registered to. default is FSCHAT_CONTROLLER",
dest="controller_address",
)
parser.add_argument(
@@ -470,19 +540,18 @@ def dump_server_info(after_start=False, args=None):
if args and args.model_name:
models = args.model_name
- print(f"当前使用的分词器:{TEXT_SPLITTER}")
+ print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")
print(f"当前启动的LLM模型:{models} @ {llm_device()}")
for model in models:
- pprint(llm_model_dict[model])
+ pprint(get_model_worker_config(model))
print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
if after_start:
print("\n")
print(f"服务端运行信息:")
if args.openai_api:
- print(f" OpenAI API Server: {fschat_openai_api_address()}/v1")
- print(" (请确认llm_model_dict中配置的api_base_url与上面地址一致。)")
+ print(f" OpenAI API Server: {fschat_openai_api_address()}")
if args.api:
print(f" Chatchat API Server: {api_address()}")
if args.webui:
diff --git a/tests/agent/test_agent_function.py b/tests/agent/test_agent_function.py
new file mode 100644
index 00000000..e860cb7a
--- /dev/null
+++ b/tests/agent/test_agent_function.py
@@ -0,0 +1,40 @@
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+from configs import LLM_MODEL, TEMPERATURE
+from server.utils import get_ChatOpenAI
+from langchain.chains import LLMChain
+from langchain.agents import LLMSingleActionAgent, AgentExecutor
+from server.agent.tools import tools, tool_names
+from langchain.memory import ConversationBufferWindowMemory
+
+memory = ConversationBufferWindowMemory(k=5)
+model = get_ChatOpenAI(
+ model_name=LLM_MODEL,
+ temperature=TEMPERATURE,
+ )
+from server.agent.custom_template import CustomOutputParser, prompt
+
+output_parser = CustomOutputParser()
+llm_chain = LLMChain(llm=model, prompt=prompt)
+agent = LLMSingleActionAgent(
+ llm_chain=llm_chain,
+ output_parser=output_parser,
+ stop=["\nObservation:"],
+ allowed_tools=tool_names
+)
+
+agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, memory=memory, verbose=True)
+
+import pytest
+@pytest.mark.parametrize("text_prompt",
+ ["北京市朝阳区未来24小时天气如何?", # 天气功能函数
+ "计算 (2 + 2312312)/4 是多少?", # 计算功能函数
+ "翻译这句话成中文:Life is the art of drawing sufficient conclusions form insufficient premises."] # 翻译功能函数
+)
+def test_different_agent_function(text_prompt):
+ try:
+ text_answer = agent_executor.run(text_prompt)
+ assert text_answer is not None
+ except Exception as e:
+ pytest.fail(f"agent_function failed with {text_prompt}, error: {str(e)}")
diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py
index ed4e8b21..975f8bcc 100644
--- a/tests/api/test_kb_api.py
+++ b/tests/api/test_kb_api.py
@@ -6,7 +6,7 @@ from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.utils import api_address
-from configs.model_config import VECTOR_SEARCH_TOP_K
+from configs import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path, get_file_path
from pprint import pprint
diff --git a/tests/api/test_kb_api_request.py b/tests/api/test_kb_api_request.py
index 86455282..3c115f1e 100644
--- a/tests/api/test_kb_api_request.py
+++ b/tests/api/test_kb_api_request.py
@@ -6,7 +6,7 @@ from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.utils import api_address
-from configs.model_config import VECTOR_SEARCH_TOP_K
+from configs import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path, get_file_path
from webui_pages.utils import ApiRequest
diff --git a/tests/api/test_llm_api.py b/tests/api/test_llm_api.py
index af5ced8f..89579818 100644
--- a/tests/api/test_llm_api.py
+++ b/tests/api/test_llm_api.py
@@ -6,21 +6,19 @@ from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import FSCHAT_MODEL_WORKERS
-from configs.model_config import LLM_MODEL, llm_model_dict
+from configs.model_config import LLM_MODEL
from server.utils import api_address, get_model_worker_config
from pprint import pprint
import random
+from typing import List
-def get_configured_models():
+def get_configured_models() -> List[str]:
model_workers = list(FSCHAT_MODEL_WORKERS)
if "default" in model_workers:
model_workers.remove("default")
-
- llm_dict = list(llm_model_dict)
-
- return model_workers, llm_dict
+ return model_workers
api_base_url = api_address()
@@ -56,12 +54,9 @@ def test_change_model(api="/llm_model/change"):
running_models = get_running_models()
assert len(running_models) > 0
- model_workers, llm_dict = get_configured_models()
+ model_workers = get_configured_models()
- availabel_new_models = set(model_workers) - set(running_models)
- if len(availabel_new_models) == 0:
- availabel_new_models = set(llm_dict) - set(running_models)
- availabel_new_models = list(availabel_new_models)
+ availabel_new_models = list(set(model_workers) - set(running_models))
assert len(availabel_new_models) > 0
print(availabel_new_models)
diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py
index 14314853..8b98c20d 100644
--- a/tests/api/test_stream_chat_api.py
+++ b/tests/api/test_stream_chat_api.py
@@ -4,7 +4,7 @@ import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
-from configs.model_config import BING_SUBSCRIPTION_KEY
+from configs import BING_SUBSCRIPTION_KEY
from server.utils import api_address
from pprint import pprint
@@ -91,7 +91,7 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"):
print("=" * 30 + api + " output" + "="*30)
for line in response.iter_content(None, decode_unicode=True):
data = json.loads(line)
- if "anser" in data:
+ if "answer" in data:
print(data["answer"], end="", flush=True)
assert "docs" in data and len(data["docs"]) > 0
pprint(data["docs"])
@@ -114,7 +114,7 @@ def test_search_engine_chat(api="/chat/search_engine_chat"):
assert data["msg"] == f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`"
print("\n")
- print("=" * 30 + api + " by {se} output" + "="*30)
+ print("=" * 30 + api + f" by {se} output" + "="*30)
for line in response.iter_content(None, decode_unicode=True):
data = json.loads(line)
if "answer" in data:
diff --git a/tests/custom_splitter/test_different_splitter.py b/tests/custom_splitter/test_different_splitter.py
index fea597e7..2111bae1 100644
--- a/tests/custom_splitter/test_different_splitter.py
+++ b/tests/custom_splitter/test_different_splitter.py
@@ -4,7 +4,7 @@ from transformers import AutoTokenizer
import sys
sys.path.append("../..")
-from configs.model_config import (
+from configs import (
CHUNK_SIZE,
OVERLAP_SIZE
)
diff --git a/tests/online_api/test_fangzhou.py b/tests/online_api/test_fangzhou.py
new file mode 100644
index 00000000..1157537c
--- /dev/null
+++ b/tests/online_api/test_fangzhou.py
@@ -0,0 +1,22 @@
+import sys
+from pathlib import Path
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+
+from server.model_workers.fangzhou import request_volc_api
+from pprint import pprint
+import pytest
+
+
+@pytest.mark.parametrize("version", ["chatglm-6b-model"])
+def test_qianfan(version):
+ messages = [{"role": "user", "content": "hello"}]
+ print("\n" + version + "\n")
+ i = 1
+ for x in request_volc_api(messages, version=version):
+ print(type(x))
+ pprint(x)
+ if chunk := x.choice.message.content:
+ print(chunk)
+ assert x.choice.message
+ i += 1
diff --git a/tests/online_api/test_qianfan.py b/tests/online_api/test_qianfan.py
index 0e8a9487..b4b9b153 100644
--- a/tests/online_api/test_qianfan.py
+++ b/tests/online_api/test_qianfan.py
@@ -8,7 +8,7 @@ from pprint import pprint
import pytest
-@pytest.mark.parametrize("version", MODEL_VERSIONS.keys())
+@pytest.mark.parametrize("version", list(MODEL_VERSIONS.keys())[:2])
def test_qianfan(version):
messages = [{"role": "user", "content": "你好"}]
print("\n" + version + "\n")
diff --git a/tests/online_api/test_qwen.py b/tests/online_api/test_qwen.py
new file mode 100644
index 00000000..001cf606
--- /dev/null
+++ b/tests/online_api/test_qwen.py
@@ -0,0 +1,19 @@
+import sys
+from pathlib import Path
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+
+from server.model_workers.qwen import request_qwen_api
+from pprint import pprint
+import pytest
+
+
+@pytest.mark.parametrize("version", ["qwen-turbo"])
+def test_qwen(version):
+ messages = [{"role": "user", "content": "hello"}]
+ print("\n" + version + "\n")
+
+ for x in request_qwen_api(messages, version=version):
+ print(type(x))
+ pprint(x)
+ assert x["code"] == 200
diff --git a/tests/test_migrate.py b/tests/test_migrate.py
new file mode 100644
index 00000000..d694b026
--- /dev/null
+++ b/tests/test_migrate.py
@@ -0,0 +1,139 @@
+from pathlib import Path
+from pprint import pprint
+import os
+import shutil
+import sys
+root_path = Path(__file__).parent.parent
+sys.path.append(str(root_path))
+
+from server.knowledge_base.kb_service.base import KBServiceFactory
+from server.knowledge_base.utils import get_kb_path, get_doc_path, KnowledgeFile
+from server.knowledge_base.migrate import folder2db, prune_db_docs, prune_folder_files
+
+
+# setup test knowledge base
+kb_name = "test_kb_for_migrate"
+test_files = {
+ "faq.md": str(root_path / "docs" / "faq.md"),
+ "install.md": str(root_path / "docs" / "install.md"),
+}
+
+
+kb_path = get_kb_path(kb_name)
+doc_path = get_doc_path(kb_name)
+
+if not os.path.isdir(doc_path):
+ os.makedirs(doc_path)
+
+for k, v in test_files.items():
+ shutil.copy(v, os.path.join(doc_path, k))
+
+
+def test_recreate_vs():
+ folder2db([kb_name], "recreate_vs")
+
+ kb = KBServiceFactory.get_service_by_name(kb_name)
+ assert kb.exists()
+
+ files = kb.list_files()
+ print(files)
+ for name in test_files:
+ assert name in files
+ path = os.path.join(doc_path, name)
+
+ # list docs based on file name
+ docs = kb.list_docs(file_name=name)
+ assert len(docs) > 0
+ pprint(docs[0])
+ for doc in docs:
+ assert doc.metadata["source"] == path
+
+ # list docs base on metadata
+ docs = kb.list_docs(metadata={"source": path})
+ assert len(docs) > 0
+
+ for doc in docs:
+ assert doc.metadata["source"] == path
+
+
+def test_increament():
+ kb = KBServiceFactory.get_service_by_name(kb_name)
+ kb.clear_vs()
+ assert kb.list_files() == []
+ assert kb.list_docs() == []
+
+ folder2db([kb_name], "increament")
+
+ files = kb.list_files()
+ print(files)
+ for f in test_files:
+ assert f in files
+
+ docs = kb.list_docs(file_name=f)
+ assert len(docs) > 0
+ pprint(docs[0])
+
+ for doc in docs:
+ assert doc.metadata["source"] == os.path.join(doc_path, f)
+
+
+def test_prune_db():
+ del_file, keep_file = list(test_files)[:2]
+ os.remove(os.path.join(doc_path, del_file))
+
+ prune_db_docs([kb_name])
+
+ kb = KBServiceFactory.get_service_by_name(kb_name)
+ files = kb.list_files()
+ print(files)
+ assert del_file not in files
+ assert keep_file in files
+
+ docs = kb.list_docs(file_name=del_file)
+ assert len(docs) == 0
+
+ docs = kb.list_docs(file_name=keep_file)
+ assert len(docs) > 0
+ pprint(docs[0])
+
+ shutil.copy(test_files[del_file], os.path.join(doc_path, del_file))
+
+
+def test_prune_folder():
+ del_file, keep_file = list(test_files)[:2]
+ kb = KBServiceFactory.get_service_by_name(kb_name)
+
+ # delete docs for file
+ kb.delete_doc(KnowledgeFile(del_file, kb_name))
+ files = kb.list_files()
+ print(files)
+ assert del_file not in files
+ assert keep_file in files
+
+ docs = kb.list_docs(file_name=del_file)
+ assert len(docs) == 0
+
+ docs = kb.list_docs(file_name=keep_file)
+ assert len(docs) > 0
+
+ docs = kb.list_docs(file_name=del_file)
+ assert len(docs) == 0
+
+ assert os.path.isfile(os.path.join(doc_path, del_file))
+
+ # prune folder
+ prune_folder_files([kb_name])
+
+ # check result
+ assert not os.path.isfile(os.path.join(doc_path, del_file))
+ assert os.path.isfile(os.path.join(doc_path, keep_file))
+
+
+def test_drop_kb():
+ kb = KBServiceFactory.get_service_by_name(kb_name)
+ kb.drop_kb()
+ assert not kb.exists()
+ assert not os.path.isdir(kb_path)
+
+ kb = KBServiceFactory.get_service_by_name(kb_name)
+ assert kb is None
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index f581a9b5..05c2a4f7 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -1,15 +1,13 @@
import streamlit as st
-from configs.server_config import FSCHAT_MODEL_WORKERS
from webui_pages.utils import *
from streamlit_chatbox import *
from datetime import datetime
from server.chat.search_engine_chat import SEARCH_ENGINES
import os
-from configs.model_config import LLM_MODEL, TEMPERATURE
+from configs import LLM_MODEL, TEMPERATURE
from server.utils import get_model_worker_config
from typing import List, Dict
-
chat_box = ChatBox(
assistant_avatar=os.path.join(
"img",
@@ -18,30 +16,24 @@ chat_box = ChatBox(
)
-def get_messages_history(history_len: int) -> List[Dict]:
+def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
+ '''
+ 返回消息历史。
+ content_in_expander控制是否返回expander元素中的内容,一般导出的时候可以选上,传入LLM的history不需要
+ '''
+
def filter(msg):
- '''
- 针对当前简单文本对话,只返回每条消息的第一个element的内容
- '''
- content = [x._content for x in msg["elements"] if x._output_method in ["markdown", "text"]]
+ content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]]
+ if not content_in_expander:
+ content = [x for x in content if not x._in_expander]
+ content = [x.content for x in content]
+
return {
"role": msg["role"],
- "content": content[0] if content else "",
+ "content": "\n\n".join(content),
}
- # workaround before upgrading streamlit-chatbox.
- def stop(h):
- return False
-
- history = chat_box.filter_history(history_len=100000, filter=filter, stop=stop)
- user_count = 0
- i = 1
- for i in range(1, len(history) + 1):
- if history[-i]["role"] == "user":
- user_count += 1
- if user_count >= history_len:
- break
- return history[-i:]
+ return chat_box.filter_history(history_len=history_len, filter=filter)
def dialogue_page(api: ApiRequest):
@@ -63,6 +55,7 @@ def dialogue_page(api: ApiRequest):
["LLM 对话",
"知识库问答",
"搜索引擎问答",
+ "自定义Agent问答",
],
index=1,
on_change=on_mode_change,
@@ -71,8 +64,9 @@ def dialogue_page(api: ApiRequest):
def on_llm_change():
config = get_model_worker_config(llm_model)
- if not config.get("online_api"): # 只有本地model_worker可以切换模型
+ if not config.get("online_api"): # 只有本地model_worker可以切换模型
st.session_state["prev_llm_model"] = llm_model
+ st.session_state["cur_llm_model"] = st.session_state.llm_model
def llm_model_format_func(x):
if x in running_models:
@@ -80,25 +74,32 @@ def dialogue_page(api: ApiRequest):
return x
running_models = api.list_running_models()
+ available_models = []
config_models = api.list_config_models()
- for x in running_models:
- if x in config_models:
- config_models.remove(x)
- llm_models = running_models + config_models
- cur_model = st.session_state.get("cur_llm_model", LLM_MODEL)
- index = llm_models.index(cur_model)
+ for models in config_models.values():
+ for m in models:
+ if m not in running_models:
+ available_models.append(m)
+ llm_models = running_models + available_models
+ index = llm_models.index(st.session_state.get("cur_llm_model", LLM_MODEL))
llm_model = st.selectbox("选择LLM模型:",
- llm_models,
- index,
- format_func=llm_model_format_func,
- on_change=on_llm_change,
- # key="llm_model",
- )
+ llm_models,
+ index,
+ format_func=llm_model_format_func,
+ on_change=on_llm_change,
+ key="llm_model",
+ )
if (st.session_state.get("prev_llm_model") != llm_model
- and not get_model_worker_config(llm_model).get("online_api")):
+ and not get_model_worker_config(llm_model).get("online_api")
+ and llm_model not in running_models):
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
- r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
- st.session_state["cur_llm_model"] = llm_model
+ prev_model = st.session_state.get("prev_llm_model")
+ r = api.change_llm_model(prev_model, llm_model)
+ if msg := check_error_msg(r):
+ st.error(msg)
+ elif msg := check_success_msg(r):
+ st.success(msg)
+ st.session_state["prev_llm_model"] = llm_model
temperature = st.slider("Temperature:", 0.0, 1.0, TEMPERATURE, 0.05)
history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
@@ -143,17 +144,42 @@ def dialogue_page(api: ApiRequest):
text = ""
r = api.chat_chat(prompt, history=history, model=llm_model, temperature=temperature)
for t in r:
- if error_msg := check_error_msg(t): # check whether error occured
+ if error_msg := check_error_msg(t): # check whether error occured
st.error(error_msg)
break
text += t
chat_box.update_msg(text)
chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标
+
+
+ elif dialogue_mode == "自定义Agent问答":
+ chat_box.ai_say([
+ f"正在思考和寻找工具 ...",])
+ text = ""
+ element_index = 0
+ for d in api.agent_chat(prompt,
+ history=history,
+ model=llm_model,
+ temperature=temperature):
+ try:
+ d = json.loads(d)
+ except:
+ pass
+ if error_msg := check_error_msg(d): # check whether error occured
+ st.error(error_msg)
+
+ elif chunk := d.get("answer"):
+ text += chunk
+ chat_box.update_msg(text, element_index=0)
+ elif chunk := d.get("tools"):
+ element_index += 1
+ chat_box.insert_msg(Markdown("...", in_expander=True, title="使用工具...", state="complete"))
+ chat_box.update_msg("\n\n".join(d.get("tools", [])), element_index=element_index, streaming=False)
+ chat_box.update_msg(text, element_index=0, streaming=False)
elif dialogue_mode == "知识库问答":
- history = get_messages_history(history_len)
chat_box.ai_say([
f"正在查询知识库 `{selected_kb}` ...",
- Markdown("...", in_expander=True, title="知识库匹配结果"),
+ Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
])
text = ""
for d in api.knowledge_base_chat(prompt,
@@ -173,12 +199,13 @@ def dialogue_page(api: ApiRequest):
elif dialogue_mode == "搜索引擎问答":
chat_box.ai_say([
f"正在执行 `{search_engine}` 搜索...",
- Markdown("...", in_expander=True, title="网络搜索结果"),
+ Markdown("...", in_expander=True, title="网络搜索结果", state="complete"),
])
text = ""
for d in api.search_engine_chat(prompt,
search_engine_name=search_engine,
top_k=se_top_k,
+ history=history,
model=llm_model,
temperature=temperature):
if error_msg := check_error_msg(d): # check whether error occured
diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py
index c71da7e4..bf8f0894 100644
--- a/webui_pages/knowledge_base/knowledge_base.py
+++ b/webui_pages/knowledge_base/knowledge_base.py
@@ -6,9 +6,10 @@ import pandas as pd
from server.knowledge_base.utils import get_file_path, LOADER_DICT
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
from typing import Literal, Dict, Tuple
-from configs.model_config import (embedding_model_dict, kbs_config,
- EMBEDDING_MODEL, DEFAULT_VS_TYPE,
- CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
+from configs import (kbs_config,
+ EMBEDDING_MODEL, DEFAULT_VS_TYPE,
+ CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
+from server.utils import list_embed_models
import os
import time
@@ -94,7 +95,7 @@ def knowledge_base_page(api: ApiRequest):
key="vs_type",
)
- embed_models = list(embedding_model_dict.keys())
+ embed_models = list_embed_models()
embed_model = cols[1].selectbox(
"Embedding 模型",
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 26e53206..a2113685 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -1,12 +1,11 @@
# 该文件包含webui通用工具,可以被不同的webui使用
from typing import *
from pathlib import Path
-from configs.model_config import (
+from configs import (
EMBEDDING_MODEL,
DEFAULT_VS_TYPE,
KB_ROOT_PATH,
LLM_MODEL,
- llm_model_dict,
HISTORY_LEN,
TEMPERATURE,
SCORE_THRESHOLD,
@@ -15,9 +14,10 @@ from configs.model_config import (
ZH_TITLE_ENHANCE,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
+ FSCHAT_MODEL_WORKERS,
+ HTTPX_DEFAULT_TIMEOUT,
logger, log_verbose,
)
-from configs.server_config import HTTPX_DEFAULT_TIMEOUT
import httpx
import asyncio
from server.chat.openai_chat import OpenAiChatMsgIn
@@ -26,7 +26,7 @@ import contextlib
import json
import os
from io import BytesIO
-from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address
+from server.utils import run_async, iter_over_async, set_httpx_config, api_address, get_httpx_client
from configs.model_config import NLTK_DATA_PATH
import nltk
@@ -35,7 +35,7 @@ from pprint import pprint
KB_ROOT_PATH = Path(KB_ROOT_PATH)
-set_httpx_timeout()
+set_httpx_config()
class ApiRequest:
@@ -53,6 +53,8 @@ class ApiRequest:
self.base_url = base_url
self.timeout = timeout
self.no_remote_api = no_remote_api
+ self._client = get_httpx_client()
+ self._aclient = get_httpx_client(use_async=True)
if no_remote_api:
logger.warn("将来可能取消对no_remote_api的支持,更新版本时请注意。")
@@ -79,9 +81,9 @@ class ApiRequest:
while retry > 0:
try:
if stream:
- return httpx.stream("GET", url, params=params, **kwargs)
+ return self._client.stream("GET", url, params=params, **kwargs)
else:
- return httpx.get(url, params=params, **kwargs)
+ return self._client.get(url, params=params, **kwargs)
except Exception as e:
msg = f"error when get {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
@@ -98,18 +100,18 @@ class ApiRequest:
) -> Union[httpx.Response, None]:
url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout)
- async with httpx.AsyncClient() as client:
- while retry > 0:
- try:
- if stream:
- return await client.stream("GET", url, params=params, **kwargs)
- else:
- return await client.get(url, params=params, **kwargs)
- except Exception as e:
- msg = f"error when aget {url}: {e}"
- logger.error(f'{e.__class__.__name__}: {msg}',
- exc_info=e if log_verbose else None)
- retry -= 1
+
+ while retry > 0:
+ try:
+ if stream:
+ return await self._aclient.stream("GET", url, params=params, **kwargs)
+ else:
+ return await self._aclient.get(url, params=params, **kwargs)
+ except Exception as e:
+ msg = f"error when aget {url}: {e}"
+ logger.error(f'{e.__class__.__name__}: {msg}',
+ exc_info=e if log_verbose else None)
+ retry -= 1
def post(
self,
@@ -124,11 +126,10 @@ class ApiRequest:
kwargs.setdefault("timeout", self.timeout)
while retry > 0:
try:
- # return requests.post(url, data=data, json=json, stream=stream, **kwargs)
if stream:
- return httpx.stream("POST", url, data=data, json=json, **kwargs)
+ return self._client.stream("POST", url, data=data, json=json, **kwargs)
else:
- return httpx.post(url, data=data, json=json, **kwargs)
+ return self._client.post(url, data=data, json=json, **kwargs)
except Exception as e:
msg = f"error when post {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
@@ -146,18 +147,18 @@ class ApiRequest:
) -> Union[httpx.Response, None]:
url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout)
- async with httpx.AsyncClient() as client:
- while retry > 0:
- try:
- if stream:
- return await client.stream("POST", url, data=data, json=json, **kwargs)
- else:
- return await client.post(url, data=data, json=json, **kwargs)
- except Exception as e:
- msg = f"error when apost {url}: {e}"
- logger.error(f'{e.__class__.__name__}: {msg}',
- exc_info=e if log_verbose else None)
- retry -= 1
+
+ while retry > 0:
+ try:
+ if stream:
+ return await self._client.stream("POST", url, data=data, json=json, **kwargs)
+ else:
+ return await self._client.post(url, data=data, json=json, **kwargs)
+ except Exception as e:
+ msg = f"error when apost {url}: {e}"
+ logger.error(f'{e.__class__.__name__}: {msg}',
+ exc_info=e if log_verbose else None)
+ retry -= 1
def delete(
self,
@@ -173,9 +174,9 @@ class ApiRequest:
while retry > 0:
try:
if stream:
- return httpx.stream("DELETE", url, data=data, json=json, **kwargs)
+ return self._client.stream("DELETE", url, data=data, json=json, **kwargs)
else:
- return httpx.delete(url, data=data, json=json, **kwargs)
+ return self._client.delete(url, data=data, json=json, **kwargs)
except Exception as e:
msg = f"error when delete {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
@@ -193,18 +194,18 @@ class ApiRequest:
) -> Union[httpx.Response, None]:
url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout)
- async with httpx.AsyncClient() as client:
- while retry > 0:
- try:
- if stream:
- return await client.stream("DELETE", url, data=data, json=json, **kwargs)
- else:
- return await client.delete(url, data=data, json=json, **kwargs)
- except Exception as e:
- msg = f"error when adelete {url}: {e}"
- logger.error(f'{e.__class__.__name__}: {msg}',
- exc_info=e if log_verbose else None)
- retry -= 1
+
+ while retry > 0:
+ try:
+ if stream:
+ return await self._aclient.stream("DELETE", url, data=data, json=json, **kwargs)
+ else:
+ return await self._aclient.delete(url, data=data, json=json, **kwargs)
+ except Exception as e:
+ msg = f"error when adelete {url}: {e}"
+ logger.error(f'{e.__class__.__name__}: {msg}',
+ exc_info=e if log_verbose else None)
+ retry -= 1
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
'''
@@ -315,6 +316,7 @@ class ApiRequest:
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
+ prompt_name: str = "llm_chat",
no_remote_api: bool = None,
):
'''
@@ -323,6 +325,41 @@ class ApiRequest:
if no_remote_api is None:
no_remote_api = self.no_remote_api
+ data = {
+ "query": query,
+ "history": history,
+ "stream": stream,
+ "model_name": model,
+ "temperature": temperature,
+ "prompt_name": prompt_name,
+ }
+
+ print(f"received input message:")
+ pprint(data)
+
+ if no_remote_api:
+ from server.chat.chat import chat
+ response = run_async(chat(**data))
+ return self._fastapi_stream2generator(response)
+ else:
+ response = self.post("/chat/chat", json=data, stream=True)
+ return self._httpx_stream2generator(response)
+
+ def agent_chat(
+ self,
+ query: str,
+ history: List[Dict] = [],
+ stream: bool = True,
+ model: str = LLM_MODEL,
+ temperature: float = TEMPERATURE,
+ no_remote_api: bool = None,
+ ):
+ '''
+ 对应api.py/chat/agent_chat 接口
+ '''
+ if no_remote_api is None:
+ no_remote_api = self.no_remote_api
+
data = {
"query": query,
"history": history,
@@ -335,11 +372,11 @@ class ApiRequest:
pprint(data)
if no_remote_api:
- from server.chat.chat import chat
- response = run_async(chat(**data))
+ from server.chat.agent_chat import agent_chat
+ response = run_async(agent_chat(**data))
return self._fastapi_stream2generator(response)
else:
- response = self.post("/chat/chat", json=data, stream=True)
+ response = self.post("/chat/agent_chat", json=data, stream=True)
return self._httpx_stream2generator(response)
def knowledge_base_chat(
@@ -352,6 +389,7 @@ class ApiRequest:
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
+ prompt_name: str = "knowledge_base_chat",
no_remote_api: bool = None,
):
'''
@@ -370,6 +408,7 @@ class ApiRequest:
"model_name": model,
"temperature": temperature,
"local_doc_url": no_remote_api,
+ "prompt_name": prompt_name,
}
print(f"received input message:")
@@ -392,9 +431,11 @@ class ApiRequest:
query: str,
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
+ history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
+ prompt_name: str = "knowledge_base_chat",
no_remote_api: bool = None,
):
'''
@@ -407,9 +448,11 @@ class ApiRequest:
"query": query,
"search_engine_name": search_engine_name,
"top_k": top_k,
+ "history": history,
"stream": stream,
"model_name": model,
"temperature": temperature,
+ "prompt_name": prompt_name,
}
print(f"received input message:")
@@ -766,20 +809,31 @@ class ApiRequest:
"controller_address": controller_address,
}
if no_remote_api:
- from server.llm_api import list_llm_models
- return list_llm_models(**data).data
+ from server.llm_api import list_running_models
+ return list_running_models(**data).data
else:
r = self.post(
- "/llm_model/list_models",
+ "/llm_model/list_running_models",
json=data,
)
return r.json().get("data", [])
- def list_config_models(self):
+ def list_config_models(self, no_remote_api: bool = None) -> Dict[str, List[str]]:
'''
- 获取configs中配置的模型列表
+ 获取configs中配置的模型列表,返回形式为{"type": [model_name1, model_name2, ...], ...}。
+ 如果no_remote_api=True, 从运行ApiRequest的机器上获取;否则从运行api.py的机器上获取。
'''
- return list(llm_model_dict.keys())
+ if no_remote_api is None:
+ no_remote_api = self.no_remote_api
+
+ if no_remote_api:
+ from server.llm_api import list_config_models
+ return list_config_models().data
+ else:
+ r = self.post(
+ "/llm_model/list_config_models",
+ )
+ return r.json().get("data", {})
def stop_llm_model(
self,
@@ -825,13 +879,13 @@ class ApiRequest:
if not model_name or not new_model_name:
return
- if new_model_name == model_name:
+ running_models = self.list_running_models()
+ if new_model_name == model_name or new_model_name in running_models:
return {
"code": 200,
- "msg": "什么都不用做"
+ "msg": "无需切换"
}
- running_models = self.list_running_models()
if model_name not in running_models:
return {
"code": 500,
@@ -839,7 +893,7 @@ class ApiRequest:
}
config_models = self.list_config_models()
- if new_model_name not in config_models:
+ if new_model_name not in config_models.get("local", []):
return {
"code": 500,
"msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"