release 0.2.6 (#1815)
## 🛠 新增功能 - 支持百川在线模型 (@hzg0601 @liunux4odoo in #1623) - 支持 Azure OpenAI 与 claude 等 Langchain 自带模型 (@zRzRzRzRzRzRzR in #1808) - Agent 功能大量更新,支持更多的工具、更换提示词、检索知识库 (@zRzRzRzRzRzRzR in #1626 #1666 #1785) - 加长 32k 模型的历史记录 (@zRzRzRzRzRzRzR in #1629 #1630) - *_chat 接口支持 max_tokens 参数 (@liunux4odoo in #1744) - 实现 API 和 WebUI 的前后端分离 (@liunux4odoo in #1772) - 支持 zlilliz 向量库 (@zRzRzRzRzRzRzR in #1785) - 支持 metaphor 搜索引擎 (@liunux4odoo in #1792) - 支持 p-tuning 模型 (@hzg0601 in #1810) - 更新完善文档和 Wiki (@imClumsyPanda @zRzRzRzRzRzRzR @glide-the in #1680 #1811) ## 🐞 问题修复 - 修复 bge-* 模型匹配超过 1 的问题 (@zRzRzRzRzRzRzR in #1652) - 修复系统代理为空的问题 (@glide-the in #1654) - 修复重建知识库时 `d == self.d assert error` (@liunux4odoo in #1766) - 修复对话历史消息错误 (@liunux4odoo in #1801) - 修复 OpenAI 无法调用的 bug (@zRzRzRzRzRzRzR in #1808) - 修复 windows下 BIND_HOST=0.0.0.0 时对话出错的问题 (@hzg0601 in #1810)
@ -1,22 +0,0 @@
|
||||
# 贡献指南
|
||||
|
||||
欢迎!我们是一个非常友好的社区,非常高兴您想要帮助我们让这个应用程序变得更好。但是,请您遵循一些通用准则以保持组织有序。
|
||||
|
||||
1. 确保为您要修复的错误或要添加的功能创建了一个[问题](https://github.com/imClumsyPanda/langchain-ChatGLM/issues),尽可能保持它们小。
|
||||
2. 请使用 `git pull --rebase` 来拉取和衍合上游的更新。
|
||||
3. 将提交合并为格式良好的提交。在提交说明中单独一行提到要解决的问题,如`Fix #<bug>`(有关更多可以使用的关键字,请参见[将拉取请求链接到问题](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue))。
|
||||
4. 推送到`dev`。在说明中提到正在解决的问题。
|
||||
|
||||
---
|
||||
|
||||
# Contribution Guide
|
||||
|
||||
Welcome! We're a pretty friendly community, and we're thrilled that you want to help make this app even better. However, we ask that you follow some general guidelines to keep things organized around here.
|
||||
|
||||
1. Make sure an [issue](https://github.com/imClumsyPanda/langchain-ChatGLM/issues) is created for the bug you're about to fix, or feature you're about to add. Keep them as small as possible.
|
||||
|
||||
2. Please use `git pull --rebase` to fetch and merge updates from the upstream.
|
||||
|
||||
3. Rebase commits into well-formatted commits. Mention the issue being resolved in the commit message on a line all by itself like `Fixes #<bug>` (refer to [Linking a pull request to an issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue) for more keywords you can use).
|
||||
|
||||
4. Push into `dev`. Mention which bug is being resolved in the description.
|
||||
499
README.md
@ -1,33 +1,27 @@
|
||||

|
||||
|
||||
[](https://t.me/+RjliQ3jnJ1YyN2E9)
|
||||
|
||||
🌍 [READ THIS IN ENGLISH](README_en.md)
|
||||
|
||||
📃 **LangChain-Chatchat** (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。
|
||||
📃 **LangChain-Chatchat** (原 Langchain-ChatGLM)
|
||||
|
||||
基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。
|
||||
|
||||
---
|
||||
|
||||
## 目录
|
||||
|
||||
* [介绍](README.md#介绍)
|
||||
* [环境最低要求](README.md#环境最低要求)
|
||||
* [变更日志](README.md#变更日志)
|
||||
* [模型支持](README.md#模型支持)
|
||||
* [Agent 生态](README.md#Agent-生态)
|
||||
* [Docker 部署](README.md#Docker-部署)
|
||||
* [开发部署](README.md#开发部署)
|
||||
* [软件需求](README.md#软件需求)
|
||||
* [1. 开发环境准备](README.md#1-开发环境准备)
|
||||
* [2. 下载模型至本地](README.md#2-下载模型至本地)
|
||||
* [3. 设置配置项](README.md#3-设置配置项)
|
||||
* [4. 知识库初始化与迁移](README.md#4-知识库初始化与迁移)
|
||||
* [5. 一键启动 API 服务或 Web UI](README.md#5-一键启动-API-服务或-Web-UI)
|
||||
* [常见问题](README.md#常见问题)
|
||||
* [最佳实践](README.md#最佳实践)
|
||||
* [项目 Wiki](README.md#项目-Wiki)
|
||||
* [路线图](README.md#路线图)
|
||||
* [项目交流群](README.md#项目交流群)
|
||||
* [解决的痛点](README.md#解决的痛点)
|
||||
* [快速上手](README.md#快速上手)
|
||||
* [1. 环境配置](README.md#1-环境配置)
|
||||
* [2. 模型下载](README.md#2-模型下载)
|
||||
* [3. 初始化知识库和配置文件](README.md#3-初始化知识库和配置文件)
|
||||
* [4. 一键启动](README.md#4-一键启动)
|
||||
* [5. 启动界面示例](README.md#5-启动界面示例)
|
||||
* [联系我们](README.md#联系我们)
|
||||
* [合作伙伴名单](README.md#合作伙伴名单)
|
||||
|
||||
---
|
||||
|
||||
## 介绍
|
||||
|
||||
@ -51,234 +45,45 @@
|
||||
|
||||
🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/chatchat-space/Langchain-Chatchat/Langchain-Chatchat) 中 `v9` 版本所使用代码已更新至本项目 `v0.2.5` 版本。
|
||||
|
||||
🐳 [Docker 镜像](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.3)
|
||||
🐳 [Docker 镜像](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.3) 已经更新到 ```0.2.3``` 版本, 如果想体验最新内容请源码安装。
|
||||
|
||||
💻 一行命令运行 Docker 🌲:
|
||||
🧩 本项目有一个非常完整的[Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/) , README只是一个简单的介绍,__仅仅是入门教程,能够基础运行__。 如果你想要更深入的了解本项目,或者对相对本项目做出共享。请移步 [Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/) 界面
|
||||
|
||||
```shell
|
||||
docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.3
|
||||
## 解决的痛点
|
||||
|
||||
该项目是一个可以实现 __完全本地化__推理的知识库增强方案, 重点解决数数据安全保护,私域化部署的企业痛点。
|
||||
本开源方案采用```Apache License``,可以免费商用,无需付费。
|
||||
|
||||
我们支持市面上主流的本地大预言模型和Embedding模型,支持开源的本地向量数据库。
|
||||
支持列表详见[Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/)
|
||||
|
||||
|
||||
## 快速上手
|
||||
|
||||
### 1. 环境配置
|
||||
|
||||
+ 首先,确保你的机器安装了 Python 3.10
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 环境最低要求
|
||||
|
||||
想顺利运行本项目代码,请按照以下的最低要求进行配置:
|
||||
+ Python 版本: >= 3.8.5, < 3.11
|
||||
+ CUDA 版本: >= 11.7
|
||||
+ 强烈推荐使用 Python 3.10,部分 Agent 功能可能没有完全支持 Python 3.10 以下版本。
|
||||
|
||||
如果想要顺利在 GPU 运行本地模型(int4 版本),你至少需要以下的硬件配置:
|
||||
|
||||
+ ChatGLM2-6B & LLaMA-7B
|
||||
+ 最低显存要求: 7GB
|
||||
+ 推荐显卡: RTX 3060, RTX 2060
|
||||
+ LLaMA-13B
|
||||
+ 最低显存要求: 11GB
|
||||
+ 推荐显卡: RTX 2060 12GB, RTX 3060 12GB, RTX 3080, RTX A2000
|
||||
+ Qwen-14B-Chat
|
||||
+ 最低显存要求: 13GB
|
||||
+ 推荐显卡: RTX 3090
|
||||
+ LLaMA-30B
|
||||
+ 最低显存要求: 22GB
|
||||
+ 推荐显卡: RTX A5000, RTX 3090, RTX 4090, RTX 6000, Tesla V100, RTX Tesla P40
|
||||
+ LLaMA-65B
|
||||
+ 最低显存要求: 40GB
|
||||
+ 推荐显卡: A100, A40, A6000
|
||||
|
||||
若使用 int8 推理,则显存大致为 int4 推理要求的 1.5 倍;
|
||||
|
||||
若使用 fp16 推理,则显存大致为 int4 推理要求的 2.5 倍。
|
||||
|
||||
💡 例如:使用 fp16 推理 Qwen-7B-Chat 模型,则需要使用 16GB 显存。
|
||||
|
||||
以上仅为估算,实际情况以 nvidia-smi 占用为准。
|
||||
|
||||
---
|
||||
|
||||
## 变更日志
|
||||
|
||||
参见 [版本更新日志](https://github.com/imClumsyPanda/langchain-ChatGLM/releases)。
|
||||
|
||||
从 `0.1.x` 升级过来的用户请注意,需要按照[开发部署](README.md#3-开发部署)过程操作,将现有知识库迁移到新格式,具体见[知识库初始化与迁移](docs/INSTALL.md#知识库初始化与迁移)。
|
||||
|
||||
### `0.2.0` 版本与 `0.1.x` 版本区别
|
||||
|
||||
1. 使用 [FastChat](https://github.com/lm-sys/FastChat) 提供开源 LLM 模型的 API,以 OpenAI API 接口形式接入,提升 LLM 模型加载效果;
|
||||
2. 使用 [langchain](https://github.com/langchain-ai/langchain) 中已有 Chain 的实现,便于后续接入不同类型 Chain,并将对 Agent 接入开展测试;
|
||||
3. 使用 [FastAPI](https://github.com/tiangolo/fastapi) 提供 API 服务,全部接口可在 FastAPI 自动生成的 docs 中开展测试,且所有对话接口支持通过参数设置流式或非流式输出;
|
||||
4. 使用 [Streamlit](https://github.com/streamlit/streamlit) 提供 WebUI 服务,可选是否基于 API 服务启动 WebUI,增加会话管理,可以自定义会话主题并切换,且后续可支持不同形式输出内容的显示;
|
||||
5. 项目中默认 LLM 模型改为 [THUDM/ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b),默认 Embedding 模型改为 [moka-ai/m3e-base](https://huggingface.co/moka-ai/m3e-base),文件加载方式与文段划分方式也有调整,后续将重新实现上下文扩充,并增加可选设置;
|
||||
6. 项目中扩充了对不同类型向量库的支持,除支持 [FAISS](https://github.com/facebookresearch/faiss) 向量库外,还提供 [Milvus](https://github.com/milvus-io/milvus), [PGVector](https://github.com/pgvector/pgvector) 向量库的接入;
|
||||
7. 项目中搜索引擎对话,除 Bing 搜索外,增加 DuckDuckGo 搜索选项,DuckDuckGo 搜索无需配置 API Key,在可访问国外服务环境下可直接使用。
|
||||
|
||||
---
|
||||
|
||||
## 模型支持
|
||||
|
||||
本项目中默认使用的 LLM 模型为 [THUDM/ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b),默认使用的 Embedding 模型为 [moka-ai/m3e-base](https://huggingface.co/moka-ai/m3e-base) 为例。
|
||||
|
||||
### LLM 模型支持
|
||||
|
||||
本项目最新版本中支持接入**本地模型**与**在线 LLM API**。
|
||||
|
||||
本地 LLM 模型接入基于 [FastChat](https://github.com/lm-sys/FastChat) 实现,支持模型如下:
|
||||
|
||||
- [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
|
||||
- Vicuna, Alpaca, LLaMA, Koala
|
||||
- [BlinkDL/RWKV-4-Raven](https://huggingface.co/BlinkDL/rwkv-4-raven)
|
||||
- [camel-ai/CAMEL-13B-Combined-Data](https://huggingface.co/camel-ai/CAMEL-13B-Combined-Data)
|
||||
- [databricks/dolly-v2-12b](https://huggingface.co/databricks/dolly-v2-12b)
|
||||
- [FreedomIntelligence/phoenix-inst-chat-7b](https://huggingface.co/FreedomIntelligence/phoenix-inst-chat-7b)
|
||||
- [h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b](https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b)
|
||||
- [lcw99/polyglot-ko-12.8b-chang-instruct-chat](https://huggingface.co/lcw99/polyglot-ko-12.8b-chang-instruct-chat)
|
||||
- [lmsys/fastchat-t5-3b-v1.0](https://huggingface.co/lmsys/fastchat-t5)
|
||||
- [mosaicml/mpt-7b-chat](https://huggingface.co/mosaicml/mpt-7b-chat)
|
||||
- [Neutralzz/BiLLa-7B-SFT](https://huggingface.co/Neutralzz/BiLLa-7B-SFT)
|
||||
- [nomic-ai/gpt4all-13b-snoozy](https://huggingface.co/nomic-ai/gpt4all-13b-snoozy)
|
||||
- [NousResearch/Nous-Hermes-13b](https://huggingface.co/NousResearch/Nous-Hermes-13b)
|
||||
- [openaccess-ai-collective/manticore-13b-chat-pyg](https://huggingface.co/openaccess-ai-collective/manticore-13b-chat-pyg)
|
||||
- [OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5](https://huggingface.co/OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5)
|
||||
- [project-baize/baize-v2-7b](https://huggingface.co/project-baize/baize-v2-7b)
|
||||
- [Salesforce/codet5p-6b](https://huggingface.co/Salesforce/codet5p-6b)
|
||||
- [StabilityAI/stablelm-tuned-alpha-7b](https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b)
|
||||
- [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
|
||||
- [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
|
||||
- [tiiuae/falcon-40b](https://huggingface.co/tiiuae/falcon-40b)
|
||||
- [timdettmers/guanaco-33b-merged](https://huggingface.co/timdettmers/guanaco-33b-merged)
|
||||
- [togethercomputer/RedPajama-INCITE-7B-Chat](https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Chat)
|
||||
- [WizardLM/WizardLM-13B-V1.0](https://huggingface.co/WizardLM/WizardLM-13B-V1.0)
|
||||
- [WizardLM/WizardCoder-15B-V1.0](https://huggingface.co/WizardLM/WizardCoder-15B-V1.0)
|
||||
- [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
|
||||
- [internlm/internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b)
|
||||
- [Qwen/Qwen-7B-Chat/Qwen-14B-Chat](https://huggingface.co/Qwen/)
|
||||
- [HuggingFaceH4/starchat-beta](https://huggingface.co/HuggingFaceH4/starchat-beta)
|
||||
- [FlagAlpha/Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat) and others
|
||||
- [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B)
|
||||
- [all models of OpenOrca](https://huggingface.co/Open-Orca)
|
||||
- [Spicyboros](https://huggingface.co/jondurbin/spicyboros-7b-2.2?not-for-all-audiences=true) + [airoboros 2.2](https://huggingface.co/jondurbin/airoboros-l2-13b-2.2)
|
||||
- [VMware's OpenLLaMa OpenInstruct](https://huggingface.co/VMware/open-llama-7b-open-instruct)
|
||||
- [baichuan2-7b/baichuan2-13b](https://huggingface.co/baichuan-inc)
|
||||
- 任何 [EleutherAI](https://huggingface.co/EleutherAI) 的 pythia 模型,如 [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b)
|
||||
- 在以上模型基础上训练的任何 [Peft](https://github.com/huggingface/peft) 适配器。为了激活,模型路径中必须有 `peft` 。注意:如果加载多个peft模型,你可以通过在任何模型工作器中设置环境变量 `PEFT_SHARE_BASE_WEIGHTS=true` 来使它们共享基础模型的权重。
|
||||
|
||||
以上模型支持列表可能随 [FastChat](https://github.com/lm-sys/FastChat) 更新而持续更新,可参考 [FastChat 已支持模型列表](https://github.com/lm-sys/FastChat/blob/main/docs/model_support.md)。
|
||||
|
||||
除本地模型外,本项目也支持直接接入 OpenAI API、智谱AI等在线模型,具体设置可参考 `configs/model_configs.py.example` 中的 `ONLINE_LLM_MODEL` 的配置信息。
|
||||
|
||||
在线 LLM 模型目前已支持:
|
||||
|
||||
- [ChatGPT](https://api.openai.com)
|
||||
- [智谱AI](http://open.bigmodel.cn)
|
||||
- [MiniMax](https://api.minimax.chat)
|
||||
- [讯飞星火](https://xinghuo.xfyun.cn)
|
||||
- [百度千帆](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
|
||||
- [字节火山方舟](https://www.volcengine.com/docs/82379)
|
||||
- [阿里云通义千问](https://dashscope.aliyun.com/)
|
||||
- [百川](https://www.baichuan-ai.com/home#api-enter) (个人用户 API_KEY 暂未开放)
|
||||
|
||||
项目中默认使用的 LLM 类型为 `THUDM/ChatGLM2-6B`,如需使用其他 LLM 类型,请在 `configs/model_config.py` 中对 `MODEL_PATH` 和 `LLM_MODEL` 进行修改。
|
||||
|
||||
### Embedding 模型支持
|
||||
|
||||
本项目支持调用 [HuggingFace](https://huggingface.co/models?pipeline_tag=sentence-similarity) 中的 Embedding 模型,已支持的 Embedding 模型如下:
|
||||
|
||||
- [moka-ai/m3e-small](https://huggingface.co/moka-ai/m3e-small)
|
||||
- [moka-ai/m3e-base](https://huggingface.co/moka-ai/m3e-base)
|
||||
- [moka-ai/m3e-large](https://huggingface.co/moka-ai/m3e-large)
|
||||
- [BAAI/bge-small-zh](https://huggingface.co/BAAI/bge-small-zh)
|
||||
- [BAAI/bge-base-zh](https://huggingface.co/BAAI/bge-base-zh)
|
||||
- [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh)
|
||||
- [BAAI/bge-base-zh-v1.5](https://huggingface.co/BAAI/bge-base-zh-v1.5)
|
||||
- [BAAI/bge-large-zh-v1.5](https://huggingface.co/BAAI/bge-large-zh-v1.5)- [BAAI/bge-base-zh-v1.5](https://huggingface.co/BAAI/bge-base-zh-v1.5)
|
||||
- [BAAI/bge-large-zh-v1.5](https://huggingface.co/BAAI/bge-large-zh-v1.5)
|
||||
- [BAAI/bge-large-zh-noinstruct](https://huggingface.co/BAAI/bge-large-zh-noinstruct)
|
||||
- [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-base-zh)
|
||||
- [sensenova/piccolo-large-zh](https://huggingface.co/sensenova/piccolo-large-zh)
|
||||
- [shibing624/text2vec-base-chinese-sentence](https://huggingface.co/shibing624/text2vec-base-chinese-sentence)
|
||||
- [shibing624/text2vec-base-chinese-paraphrase](https://huggingface.co/shibing624/text2vec-base-chinese-paraphrase)
|
||||
- [shibing624/text2vec-base-multilingual](https://huggingface.co/shibing624/text2vec-base-multilingual)
|
||||
- [shibing624/text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese)
|
||||
- [shibing624/text2vec-bge-large-chinese](https://huggingface.co/shibing624/text2vec-bge-large-chinese)
|
||||
- [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese)
|
||||
- [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
|
||||
- [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh)
|
||||
- [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-base-zh)
|
||||
- [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-large-zh)
|
||||
- [OpenAI/text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings)
|
||||
|
||||
项目中默认使用的 Embedding 类型为 `moka-ai/m3e-base`,如需使用其他 Embedding 类型,请在 `configs/model_config.py` 中对 `embedding_model_dict` 和 `EMBEDDING_MODEL` 进行修改。
|
||||
|
||||
### Text Splitter 个性化支持
|
||||
|
||||
本项目支持调用 [Langchain](https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.text_splitter) 的 Text Splitter 分词器以及基于此改进的自定义分词器,已支持的 Text Splitter 类型如下:
|
||||
|
||||
- CharacterTextSplitter
|
||||
- LatexTextSplitter
|
||||
- MarkdownHeaderTextSplitter
|
||||
- MarkdownTextSplitter
|
||||
- NLTKTextSplitter
|
||||
- PythonCodeTextSplitter
|
||||
- RecursiveCharacterTextSplitter
|
||||
- SentenceTransformersTokenTextSplitter
|
||||
- SpacyTextSplitter
|
||||
|
||||
已经支持的定制分词器如下:
|
||||
|
||||
- [AliTextSplitter](text_splitter/ali_text_splitter.py)
|
||||
- [ChineseRecursiveTextSplitter](text_splitter/chinese_recursive_text_splitter.py)
|
||||
- [ChineseTextSplitter](text_splitter/chinese_text_splitter.py)
|
||||
|
||||
项目中默认使用的 Text Splitter 类型为 `ChineseRecursiveTextSplitter`,如需使用其他 Text Splitter 类型,请在 `configs/model_config.py` 中对 `text_splitter_dict` 和 `TEXT_SPLITTER` 进行修改。
|
||||
|
||||
关于如何使用自定义分词器和贡献自己的分词器,可以参考 [如何自定义分词器](docs/splitter.md)。
|
||||
|
||||
---
|
||||
|
||||
## Agent 生态
|
||||
### 基础的 Agent
|
||||
在本版本中,我们实现了一个简单的基于 OpenAI 的 ReAct 的 Agent 模型,目前,经过我们测试,仅有以下两个模型支持:
|
||||
+ OpenAI GPT4
|
||||
+ ChatGLM2-130B
|
||||
|
||||
目前版本的 Agent 仍然需要对提示词进行大量调试。
|
||||
|
||||
### 构建自己的 Agent 工具
|
||||
|
||||
详见 [自定义 Agent 说明](docs/自定义Agent.md)
|
||||
|
||||
---
|
||||
|
||||
## Docker 部署
|
||||
|
||||
🐳 Docker 镜像地址: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.3)`
|
||||
|
||||
```shell
|
||||
docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.3
|
||||
$ python --version
|
||||
Python 3.10.12
|
||||
```
|
||||
接着,创建一个虚拟环境,并在虚拟环境内安装项目的依赖
|
||||
```shell
|
||||
|
||||
- 该版本镜像大小 `35.3GB`,使用 `v0.2.5`,以 `nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04` 为基础镜像
|
||||
- 该版本内置两个 `embedding` 模型:`m3e-large`,`text2vec-bge-large-chinese`,默认启用后者,内置 `chatglm2-6b-32k`
|
||||
- 该版本目标为方便一键部署使用,请确保您已经在Linux发行版上安装了NVIDIA驱动程序
|
||||
- 请注意,您不需要在主机系统上安装CUDA工具包,但需要安装 `NVIDIA Driver` 以及 `NVIDIA Container Toolkit`,请参考[安装指南](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
|
||||
- 首次拉取和启动均需要一定时间,首次启动时请参照下图使用 `docker logs -f <container id>` 查看日志
|
||||
- 如遇到启动过程卡在 `Waiting..` 步骤,建议使用 `docker exec -it <container id> bash` 进入 `/logs/` 目录查看对应阶段日志
|
||||
# 拉取仓库
|
||||
$ git clone https://github.com/chatchat-space/Langchain-Chatchat.git
|
||||
|
||||
---
|
||||
# 进入目录
|
||||
$ cd Langchain-Chatchat
|
||||
|
||||
## 开发部署
|
||||
# 安装全部依赖
|
||||
$ pip install -r requirements.txt
|
||||
$ pip install -r requirements_api.txt
|
||||
$ pip install -r requirements_webui.txt
|
||||
|
||||
### 软件需求
|
||||
|
||||
本项目已在 Python 3.8.1 - 3.10,CUDA 11.7 环境下完成测试。已在 Windows、ARM 架构的 macOS、Linux 系统中完成测试。
|
||||
|
||||
### 1. 开发环境准备
|
||||
|
||||
参见 [开发环境准备](docs/INSTALL.md)。
|
||||
|
||||
**请注意:** `0.2.5` 及更新版本的依赖包与 `0.1.x` 版本依赖包可能发生冲突,强烈建议新建环境后重新安装依赖包。
|
||||
|
||||
### 2. 下载模型至本地
|
||||
# 默认依赖包括基本运行环境(FAISS向量库)。如果要使用 milvus/pg_vector 等向量库,请将 requirements.txt 中相应依赖取消注释再安装。
|
||||
```
|
||||
### 2, 模型下载
|
||||
|
||||
如需在本地或离线环境下运行本项目,需要首先将项目所需的模型下载至本地,通常开源 LLM 与 Embedding 模型可以从 [HuggingFace](https://huggingface.co/models) 下载。
|
||||
|
||||
@ -287,215 +92,57 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
|
||||
下载模型需要先[安装 Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage),然后运行
|
||||
|
||||
```Shell
|
||||
$ git lfs install
|
||||
$ git clone https://huggingface.co/THUDM/chatglm2-6b
|
||||
|
||||
$ git clone https://huggingface.co/moka-ai/m3e-base
|
||||
```
|
||||
### 3. 初始化知识库和配置文件
|
||||
|
||||
### 3. 设置配置项
|
||||
|
||||
复制相关参数配置模板文件 `configs/*_config.py.example`,存储至项目路径下 `./configs` 路径下,并重命名为 `*_config.py`。
|
||||
|
||||
在开始执行 Web UI 或命令行交互前,请先检查 `configs/model_config.py` 和 `configs/server_config.py` 中的各项模型参数设计是否符合需求:
|
||||
|
||||
- 请确认已下载至本地的 LLM 模型本地存储路径(请使用绝对路径)写在 `MODEL_PATH` 对应模型位置,如:
|
||||
|
||||
```
|
||||
"chatglm2-6b": "/Users/xxx/Downloads/chatglm2-6b",
|
||||
```
|
||||
|
||||
- 请确认已下载至本地的 Embedding 模型本地存储路径写在 `MODEL_PATH` 对应模型位置,如:
|
||||
|
||||
```
|
||||
"m3e-base": "/Users/xxx/Downloads/m3e-base",
|
||||
```
|
||||
|
||||
- 请确认本地分词器路径是否已经填写,如:
|
||||
|
||||
```
|
||||
text_splitter_dict = {
|
||||
"ChineseRecursiveTextSplitter": {
|
||||
"source": "huggingface",
|
||||
## 选择tiktoken则使用openai的方法,不填写则默认为字符长度切割方法。
|
||||
"tokenizer_name_or_path": "",
|
||||
## 空格不填则默认使用大模型的分词器。
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
如果你选择使用 OpenAI 的 Embedding 模型,请将模型的 `key` 写入 `ONLINE_LLM_MODEL` 中。使用该模型,你需要能够访问 OpenAI 官方的 API,或设置代理。
|
||||
|
||||
### 4. 知识库初始化与迁移
|
||||
|
||||
当前项目的知识库信息存储在数据库中,在正式运行项目之前请先初始化数据库(我们强烈建议您在执行操作前备份您的知识文件)。
|
||||
|
||||
- 如果您是从 `0.1.x` 版本升级过来的用户,针对已建立的知识库,请确认知识库的向量库类型、Embedding 模型与 `configs/model_config.py` 中默认设置一致,如无变化只需以下命令将现有知识库信息添加到数据库即可:
|
||||
|
||||
```shell
|
||||
$ python init_database.py
|
||||
```
|
||||
|
||||
- 如果您是第一次运行本项目,知识库尚未建立,或者配置文件中的知识库类型、嵌入模型发生变化,或者之前的向量库没有开启 `normalize_L2`,需要以下命令初始化或重建知识库:
|
||||
|
||||
```shell
|
||||
$ python init_database.py --recreate-vs
|
||||
```
|
||||
|
||||
### 5. 一键启动 API 服务或 Web UI
|
||||
|
||||
#### 5.1 启动命令
|
||||
|
||||
一键启动脚本 `startup.py`,一键启动所有 FastChat 服务、API 服务、WebUI 服务,示例代码:
|
||||
按照下列方式初始化自己的知识库和简单的复制配置文件
|
||||
```shell
|
||||
$ python copy_config_example.py
|
||||
$ python init_database.py --recreate-vs
|
||||
```
|
||||
### 4. 一键启动
|
||||
|
||||
按照以下命令启动项目
|
||||
```shell
|
||||
$ python startup.py -a
|
||||
```
|
||||
### 5. 启动界面示例
|
||||
|
||||
并可使用 `Ctrl + C` 直接关闭所有运行服务。如果一次结束不了,可以多按几次。
|
||||
|
||||
可选参数包括 `-a (或--all-webui)`, `--all-api`, `--llm-api`, `-c (或--controller)`, `--openai-api`,
|
||||
`-m (或--model-worker)`, `--api`, `--webui`,其中:
|
||||
|
||||
- `--all-webui` 为一键启动 WebUI 所有依赖服务;
|
||||
- `--all-api` 为一键启动 API 所有依赖服务;
|
||||
- `--llm-api` 为一键启动 FastChat 所有依赖的 LLM 服务;
|
||||
- `--openai-api` 为仅启动 FastChat 的 controller 和 openai-api-server 服务;
|
||||
- 其他为单独服务启动选项。
|
||||
|
||||
更多信息可以通过 `python startup.py -h` 查看
|
||||
|
||||
#### 5.2 启动非默认模型
|
||||
|
||||
若想指定非默认模型,需要用 `--model-name` 选项,示例:
|
||||
|
||||
```shell
|
||||
$ python startup.py --all-webui --model-name Qwen-7B-Chat
|
||||
```
|
||||
|
||||
请注意,指定的模型必须在 `model_config.py` 中进行了配置。
|
||||
|
||||
#### 5.3 多卡加载
|
||||
|
||||
项目支持多卡加载,需在 `startup.py` 中的 `create_model_worker_app` 函数中,修改如下三个参数:
|
||||
|
||||
```python
|
||||
gpus = None,
|
||||
num_gpus = 1,
|
||||
max_gpu_memory = "20GiB"
|
||||
```
|
||||
|
||||
其中,`gpus` 控制使用的显卡的 ID,例如 "0,1";
|
||||
|
||||
`num_gpus` 控制使用的卡数;
|
||||
|
||||
`max_gpu_memory` 控制每个卡使用的显存容量。
|
||||
|
||||
注1:`server_config.py` 的 `FSCHAT_MODEL_WORKERS` 字典中也增加了相关配置,如有需要也可通过修改 `FSCHAT_MODEL_WORKERS` 字典中对应参数实现多卡加载。
|
||||
|
||||
注2:少数情况下,`gpus` 参数会不生效,此时需要通过设置环境变量 `CUDA_VISIBLE_DEVICES` 来指定 torch 可见的 GPU,示例代码:
|
||||
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES=0,1 python startup.py -a
|
||||
```
|
||||
|
||||
#### 5.4 PEFT 加载(包括 lora, p-tuning, prefix tuning, ia3等)
|
||||
|
||||
本项目基于 FastChat 加载 LLM 服务,故需以 FastChat 加载 PEFT 路径,即保证路径名称里必须有 peft 这个词,配置文件的名字为 `adapter_config.json`,peft 路径下包含 .bin 格式的 PEFT 权重,peft 路径在 `startup.py` 中 `create_model_worker_app` 函数的 `args.model_names` 中指定,并开启环境变量 `PEFT_SHARE_BASE_WEIGHTS=true` 参数。
|
||||
|
||||
注:如果上述方式启动失败,则需要以标准的 FastChat 服务启动方式分步启动。PEFT 加载详细步骤参考 [加载 LoRA 微调后模型失效](https://github.com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822)
|
||||
|
||||
#### 5.5 注意事项
|
||||
|
||||
1. `startup.py` 脚本用多进程方式启动各模块的服务,可能会导致打印顺序问题,请等待全部服务发起后再调用,并根据默认或指定端口调用服务(默认 LLM API 服务端口:`127.0.0.1:20000`,默认 API 服务端口:`127.0.0.1:7861`,默认 WebUI 服务端口:`本机IP:8501`)。
|
||||
|
||||
2. 服务启动时间示设备不同而不同,约 3-10 分钟,如长时间没有启动请前往 `./logs`目录下监控日志,定位问题。
|
||||
|
||||
3. 在 Linux 上使用 `Ctrl+C` 退出可能会由于 Linux 的多进程机制导致 multiprocessing 遗留孤儿进程,可通过运行 `shutdown_all.sh` 进行退出
|
||||
|
||||
#### 5.6 启动界面示例:
|
||||
如果正常启动,你将能看到以下界面
|
||||
|
||||
1. FastAPI Docs 界面
|
||||
|
||||

|
||||

|
||||
|
||||
2. Web UI 启动界面示例:
|
||||
|
||||
- Web UI 对话界面:
|
||||
|
||||

|
||||

|
||||
|
||||
- Web UI 知识库管理页面:
|
||||
|
||||

|
||||

|
||||
|
||||
|
||||
### 注意
|
||||
|
||||
以上方式只是为了快速上手,如果需要更多的功能和自定义启动方式 ,请参考[Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/)
|
||||
|
||||
|
||||
---
|
||||
## 联系我们
|
||||
### Telegram
|
||||
[](https://t.me/+RjliQ3jnJ1YyN2E9)
|
||||
|
||||
## 常见问题
|
||||
|
||||
参见 [常见问题](docs/FAQ.md)。
|
||||
|
||||
---
|
||||
|
||||
## 最佳实践
|
||||
|
||||
请参见 [最佳实践](https://github.com/chatchat-space/Langchain-Chatchat/wiki/最佳实践)
|
||||
|
||||
---
|
||||
## 项目 Wiki
|
||||
|
||||
更多项目相关开发介绍、参数配置等信息,请参见 [项目 Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki)
|
||||
|
||||
---
|
||||
|
||||
## 路线图
|
||||
|
||||
- [X] Langchain 应用
|
||||
- [X] 本地数据接入
|
||||
- [X] 接入非结构化文档
|
||||
- [X] .md
|
||||
- [X] .txt
|
||||
- [X] .docx
|
||||
- [ ] 结构化数据接入
|
||||
- [X] .csv
|
||||
- [ ] .xlsx
|
||||
- [ ] 分词及召回
|
||||
- [X] 接入不同类型 TextSplitter
|
||||
- [X] 优化依据中文标点符号设计的 ChineseTextSplitter
|
||||
- [ ] 重新实现上下文拼接召回
|
||||
- [ ] 本地网页接入
|
||||
- [ ] SQL 接入
|
||||
- [ ] 知识图谱/图数据库接入
|
||||
- [X] 搜索引擎接入
|
||||
- [X] Bing 搜索
|
||||
- [X] DuckDuckGo 搜索
|
||||
- [X] Agent 实现
|
||||
- [X] 基础React形式的Agent实现,包括调用计算器等
|
||||
- [X] Langchain 自带的Agent实现和调用
|
||||
- [ ] 更多模型的Agent支持
|
||||
- [ ] 更多工具
|
||||
- [X] LLM 模型接入
|
||||
- [X] 支持通过调用 [FastChat](https://github.com/lm-sys/fastchat) api 调用 llm
|
||||
- [X] 支持 ChatGLM API 等 LLM API 的接入
|
||||
- [X] Embedding 模型接入
|
||||
- [X] 支持调用 HuggingFace 中各开源 Emebdding 模型
|
||||
- [X] 支持 OpenAI Embedding API 等 Embedding API 的接入
|
||||
- [X] 基于 FastAPI 的 API 方式调用
|
||||
- [X] Web UI
|
||||
- [X] 基于 Streamlit 的 Web UI
|
||||
|
||||
---
|
||||
|
||||
## 项目交流群
|
||||
|
||||
### 项目交流群
|
||||
<img src="img/qr_code_67.jpg" alt="二维码" width="300" height="300" />
|
||||
|
||||
🎉 langchain-Chatchat 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
|
||||
|
||||
|
||||
## 关注我们
|
||||
|
||||

|
||||
|
||||
🎉 langchain-Chatchat 项目官方公众号,欢迎扫码关注。
|
||||
|
||||
🎉 Langchain-Chatchat 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
|
||||
### 公众号
|
||||
<img src="img/official_wechat_mp_account.png" alt="图片" width="900" height="300" />
|
||||
🎉 Langchain-Chatchat 项目官方公众号,欢迎扫码关注。
|
||||
|
||||
|
||||
450
README_en.md
@ -1,38 +1,49 @@
|
||||

|
||||
|
||||
[](https://t.me/+RjliQ3jnJ1YyN2E9)
|
||||
|
||||
🌍 [中文文档](README.md)
|
||||
|
||||
📃 **LangChain-Chatchat** (formerly Langchain-ChatGLM): A LLM application aims to implement knowledge and search engine based QA based on Langchain and open-source or remote LLM API.
|
||||
📃 **LangChain-Chatchat** (formerly Langchain-ChatGLM):
|
||||
|
||||
## Content
|
||||
|
||||
* [Introduction](README_en.md#Introduction)
|
||||
* [Change Log](README_en.md#Change-Log)
|
||||
* [Supported Models](README_en.md#Supported-Models)
|
||||
* [Docker Deployment](README_en.md#Docker-Deployment)
|
||||
* [Development](README_en.md#Development)
|
||||
* [Environment Prerequisite](README_en.md#Environment-Prerequisite)
|
||||
* [Preparing Deployment Environment](README_en.md#1.-Preparing-Deployment-Environment)
|
||||
* [Downloading model to local disk](README_en.md#2.-Downloading-model-to-local-disk)
|
||||
* [Setting Configuration](README_en.md#3.-Setting-Configuration)
|
||||
* [Knowledge Base Migration](README_en.md#4.-Knowledge-Base-Migration)
|
||||
* [Launching API Service or WebUI](README_en.md#5.-Launching-API-Service-or-WebUI-with-One-Command)
|
||||
* [FAQ](README_en.md#FAQ)
|
||||
* [Roadmap](README_en.md#Roadmap)
|
||||
A LLM application aims to implement knowledge and search engine based QA based on Langchain and open-source or remote
|
||||
LLM API.
|
||||
|
||||
---
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Introduction](README.md#Introduction)
|
||||
- [Pain Points Addressed](README.md#Pain-Points-Addressed)
|
||||
- [Quick Start](README.md#Quick-Start)
|
||||
- [1. Environment Setup](README.md#1-Environment-Setup)
|
||||
- [2. Model Download](README.md#2-Model-Download)
|
||||
- [3. Initialize Knowledge Base and Configuration Files](README.md#3-Initialize-Knowledge-Base-and-Configuration-Files)
|
||||
- [4. One-Click Startup](README.md#4-One-Click-Startup)
|
||||
- [5. Startup Interface Examples](README.md#5-Startup-Interface-Examples)
|
||||
- [Contact Us](README.md#Contact-Us)
|
||||
- [List of Partner Organizations](README.md#List-of-Partner-Organizations)
|
||||
|
||||
## Introduction
|
||||
|
||||
🤖️ A Q&A application based on local knowledge base implemented using the idea of [langchain](https://github.com/hwchase17/langchain). The goal is to build a KBQA(Knowledge based Q&A) solution that is friendly to Chinese scenarios and open source models and can run both offline and online.
|
||||
🤖️ A Q&A application based on local knowledge base implemented using the idea
|
||||
of [langchain](https://github.com/hwchase17/langchain). The goal is to build a KBQA(Knowledge based Q&A) solution that
|
||||
is friendly to Chinese scenarios and open source models and can run both offline and online.
|
||||
|
||||
💡 Inspried by [document.ai](https://github.com/GanymedeNil/document.ai) and [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216) , we build a local knowledge base question answering application that can be implemented using an open source model or remote LLM api throughout the process. In the latest version of this project, [FastChat](https://github.com/lm-sys/FastChat) is used to access Vicuna, Alpaca, LLaMA, Koala, RWKV and many other models. Relying on [langchain](https:// github.com/langchain-ai/langchain) , this project supports calling services through the API provided based on [FastAPI](https://github.com/tiangolo/fastapi), or using the WebUI based on [Streamlit](https://github.com /streamlit/streamlit) .
|
||||
💡 Inspried by [document.ai](https://github.com/GanymedeNil/document.ai)
|
||||
and [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216) , we build a local knowledge base question
|
||||
answering application that can be implemented using an open source model or remote LLM api throughout the process. In
|
||||
the latest version of this project, [FastChat](https://github.com/lm-sys/FastChat) is used to access Vicuna, Alpaca,
|
||||
LLaMA, Koala, RWKV and many other models. Relying on [langchain](https://github.com/langchain-ai/langchain) , this
|
||||
project supports calling services through the API provided based on [FastAPI](https://github.com/tiangolo/fastapi), or
|
||||
using the WebUI based on [Streamlit](https://github.com/streamlit/streamlit).
|
||||
|
||||
✅ Relying on the open source LLM and Embedding models, this project can realize full-process **offline private deployment**. At the same time, this project also supports the call of OpenAI GPT API- and Zhipu API, and will continue to expand the access to various models and remote APIs in the future.
|
||||
✅ Relying on the open source LLM and Embedding models, this project can realize full-process **offline private
|
||||
deployment**. At the same time, this project also supports the call of OpenAI GPT API- and Zhipu API, and will continue
|
||||
to expand the access to various models and remote APIs in the future.
|
||||
|
||||
⛓️ The implementation principle of this project is shown in the graph below. The main process includes: loading files -> reading text -> text segmentation -> text vectorization -> question vectorization -> matching the `top-k` most similar to the question vector in the text vector -> The matched text is added to `prompt `as context and question -> submitted to `LLM` to generate an answer.
|
||||
⛓️ The implementation principle of this project is shown in the graph below. The main process includes: loading files ->
|
||||
reading text -> text segmentation -> text vectorization -> question vectorization -> matching the `top-k` most similar
|
||||
to the question vector in the text vector -> The matched text is added to `prompt `as context and question -> submitted
|
||||
to `LLM` to generate an answer.
|
||||
|
||||
📺[video introdution](https://www.bilibili.com/video/BV13M4y1e7cN/?share_source=copy_web&vd_source=e6c5aafe684f30fbe41925d61ca6d514)
|
||||
|
||||
@ -42,374 +53,121 @@ The main process analysis from the aspect of document process:
|
||||
|
||||

|
||||
|
||||
🚩 The training or fined-tuning are not involved in the project, but still, one always can improve performance by do these.
|
||||
🚩 The training or fined-tuning are not involved in the project, but still, one always can improve performance by do
|
||||
these.
|
||||
|
||||
🌐 [AutoDL image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5) is supported, and in v9 the codes are update to v0.2.5.
|
||||
🌐 [AutoDL image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5) is supported, and in v9 the codes are update
|
||||
to v0.2.5.
|
||||
|
||||
🐳 [Docker image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5)
|
||||
|
||||
💻 Run Docker with one command:
|
||||
## Pain Points Addressed
|
||||
|
||||
```shell
|
||||
docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5
|
||||
This project is a solution for enhancing knowledge bases with fully localized inference, specifically addressing the
|
||||
pain points of data security and private deployments for businesses.
|
||||
This open-source solution is under the Apache License and can be used for commercial purposes for free, with no fees
|
||||
required.
|
||||
We support mainstream local large prophecy models and Embedding models available in the market, as well as open-source
|
||||
local vector databases. For a detailed list of supported models and databases, please refer to
|
||||
our [Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/)
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Environment Setup
|
||||
|
||||
First, make sure your machine has Python 3.10 installed.
|
||||
|
||||
```
|
||||
$ python --version
|
||||
Python 3.10.12
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Environment Minimum Requirements
|
||||
|
||||
To run this code smoothly, please configure it according to the following minimum requirements:
|
||||
+ Python version: >= 3.8.5, < 3.11
|
||||
+ Cuda version: >= 11.7
|
||||
+ Python 3.10 is highly recommended, some Agent features may not be fully supported below Python 3.10.
|
||||
|
||||
If you want to run the native model (int4 version) on the GPU without problems, you need at least the following hardware configuration.
|
||||
|
||||
+ chatglm2-6b & LLaMA-7B Minimum RAM requirement: 7GB Recommended graphics cards: RTX 3060, RTX 2060
|
||||
+ LLaMA-13B Minimum graphics memory requirement: 11GB Recommended cards: RTX 2060 12GB, RTX3060 12GB, RTX3080, RTXA2000
|
||||
+ Qwen-14B-Chat Minimum memory requirement: 13GB Recommended graphics card: RTX 3090
|
||||
|
||||
+ LLaMA-30B Minimum Memory Requirement: 22GB Recommended Cards: RTX A5000,RTX 3090,RTX 4090,RTX 6000,Tesla V100,RTX Tesla P40
|
||||
+ Minimum memory requirement for LLaMA-65B: 40GB Recommended cards: A100,A40,A6000
|
||||
If int8 then memory x1.5 fp16 x2.5 requirement.
|
||||
For example: using fp16 to reason about the Qwen-7B-Chat model requires 16GB of video memory.
|
||||
|
||||
The above is only an estimate, the actual situation is based on nvidia-smi occupancy.
|
||||
|
||||
## Change Log
|
||||
|
||||
plese refer to [version change log](https://github.com/imClumsyPanda/langchain-ChatGLM/releases)
|
||||
|
||||
### Current Features
|
||||
|
||||
* **Consistent LLM Service based on FastChat**. The project use [FastChat](https://github.com/lm-sys/FastChat) to provide the API service of the open source LLM models and access it in the form of OpenAI API interface to improve the loading effect of the LLM model;
|
||||
* **Chain and Agent based on Langchian**. Use the existing Chain implementation in [langchain](https://github.com/langchain-ai/langchain) to facilitate subsequent access to different types of Chain, and will test Agent access;
|
||||
* **Full fuction API service based on FastAPI**. All interfaces can be tested in the docs automatically generated by [FastAPI](https://github.com/tiangolo/fastapi), and all dialogue interfaces support streaming or non-streaming output through parameters. ;
|
||||
* **WebUI service based on Streamlit**. With [Streamlit](https://github.com/streamlit/streamlit), you can choose whether to start WebUI based on API services, add session management, customize session themes and switch, and will support different display of content forms of output in the future;
|
||||
* **Abundant open source LLM and Embedding models**. The default LLM model in the project is changed to [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b), and the default Embedding model is changed to [moka-ai/m3e-base](https:// huggingface.co/moka-ai/m3e-base), the file loading method and the paragraph division method have also been adjusted. In the future, context expansion will be re-implemented and optional settings will be added;
|
||||
* **Multiply vector libraries**. The project has expanded support for different types of vector libraries. Including [FAISS](https://github.com/facebookresearch/faiss), [Milvus](https://github.com/milvus -io/milvus), and [PGVector](https://github.com/pgvector/pgvector);
|
||||
* **Varied Search engines**. We provide two search engines now: Bing and DuckDuckGo. DuckDuckGo search does not require configuring an API Key and can be used directly in environments with access to foreign services.
|
||||
|
||||
## Supported Models
|
||||
|
||||
The default LLM model in the project is changed to [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b), and the default Embedding model is changed to [moka-ai/m3e-base](https:// huggingface.co/moka-ai/m3e-base).
|
||||
|
||||
### Supported LLM models
|
||||
|
||||
The project use [FastChat](https://github.com/lm-sys/FastChat) to provide the API service of the open source LLM models, supported models include:
|
||||
|
||||
- [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
|
||||
- Vicuna, Alpaca, LLaMA, Koala
|
||||
- [BlinkDL/RWKV-4-Raven](https://huggingface.co/BlinkDL/rwkv-4-raven)
|
||||
- [camel-ai/CAMEL-13B-Combined-Data](https://huggingface.co/camel-ai/CAMEL-13B-Combined-Data)
|
||||
- [databricks/dolly-v2-12b](https://huggingface.co/databricks/dolly-v2-12b)
|
||||
- [FreedomIntelligence/phoenix-inst-chat-7b](https://huggingface.co/FreedomIntelligence/phoenix-inst-chat-7b)
|
||||
- [h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b](https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b)
|
||||
- [lcw99/polyglot-ko-12.8b-chang-instruct-chat](https://huggingface.co/lcw99/polyglot-ko-12.8b-chang-instruct-chat)
|
||||
- [lmsys/fastchat-t5-3b-v1.0](https://huggingface.co/lmsys/fastchat-t5)
|
||||
- [mosaicml/mpt-7b-chat](https://huggingface.co/mosaicml/mpt-7b-chat)
|
||||
- [Neutralzz/BiLLa-7B-SFT](https://huggingface.co/Neutralzz/BiLLa-7B-SFT)
|
||||
- [nomic-ai/gpt4all-13b-snoozy](https://huggingface.co/nomic-ai/gpt4all-13b-snoozy)
|
||||
- [NousResearch/Nous-Hermes-13b](https://huggingface.co/NousResearch/Nous-Hermes-13b)
|
||||
- [openaccess-ai-collective/manticore-13b-chat-pyg](https://huggingface.co/openaccess-ai-collective/manticore-13b-chat-pyg)
|
||||
- [OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5](https://huggingface.co/OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5)
|
||||
- [project-baize/baize-v2-7b](https://huggingface.co/project-baize/baize-v2-7b)
|
||||
- [Salesforce/codet5p-6b](https://huggingface.co/Salesforce/codet5p-6b)
|
||||
- [StabilityAI/stablelm-tuned-alpha-7b](https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b)
|
||||
- [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
|
||||
- [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
|
||||
- [tiiuae/falcon-40b](https://huggingface.co/tiiuae/falcon-40b)
|
||||
- [timdettmers/guanaco-33b-merged](https://huggingface.co/timdettmers/guanaco-33b-merged)
|
||||
- [togethercomputer/RedPajama-INCITE-7B-Chat](https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Chat)
|
||||
- [WizardLM/WizardLM-13B-V1.0](https://huggingface.co/WizardLM/WizardLM-13B-V1.0)
|
||||
- [WizardLM/WizardCoder-15B-V1.0](https://huggingface.co/WizardLM/WizardCoder-15B-V1.0)
|
||||
- [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
|
||||
- [internlm/internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b)
|
||||
- [Qwen/Qwen-7B-Chat/Qwen-14B-Chat](https://huggingface.co/Qwen/)
|
||||
- [HuggingFaceH4/starchat-beta](https://huggingface.co/HuggingFaceH4/starchat-beta)
|
||||
- [FlagAlpha/Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat) and other models of FlagAlpha
|
||||
- [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B)
|
||||
- [all models of OpenOrca](https://huggingface.co/Open-Orca)
|
||||
- [Spicyboros](https://huggingface.co/jondurbin/spicyboros-7b-2.2?not-for-all-audiences=true) + [airoboros 2.2](https://huggingface.co/jondurbin/airoboros-l2-13b-2.2)
|
||||
- [baichuan2-7b/baichuan2-13b](https://huggingface.co/baichuan-inc)
|
||||
- [VMware's OpenLLaMa OpenInstruct](https://huggingface.co/VMware/open-llama-7b-open-instruct)
|
||||
|
||||
* Any [EleutherAI](https://huggingface.co/EleutherAI) pythia model such as [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b)(任何 [EleutherAI](https://huggingface.co/EleutherAI) 的 pythia 模型,如 [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b))
|
||||
* Any [Peft](https://github.com/huggingface/peft) adapter trained on top of a model above. To activate, must have `peft` in the model path. Note: If loading multiple peft models, you can have them share the base model weights by setting the environment variable `PEFT_SHARE_BASE_WEIGHTS=true` in any model worker.
|
||||
|
||||
|
||||
The above model support list may be updated continuously as [FastChat](https://github.com/lm-sys/FastChat) is updated, see [FastChat Supported Models List](https://github.com/lm-sys/FastChat/blob/main /docs/model_support.md).
|
||||
In addition to local models, this project also supports direct access to online models such as OpenAI API, Wisdom Spectrum AI, etc. For specific settings, please refer to the configuration information of `llm_model_dict` in `configs/model_configs.py.example`.
|
||||
Online LLM models are currently supported:
|
||||
|
||||
- [ChatGPT](https://api.openai.com)
|
||||
- [Smart Spectrum AI](http://open.bigmodel.cn)
|
||||
- [MiniMax](https://api.minimax.chat)
|
||||
- [Xunfei Starfire](https://xinghuo.xfyun.cn)
|
||||
- [Baidu Qianfan](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
|
||||
- [Aliyun Tongyi Qianqian](https://dashscope.aliyun.com/)
|
||||
|
||||
The default LLM type used in the project is `THUDM/chatglm2-6b`, if you need to use other LLM types, please modify `llm_model_dict` and `LLM_MODEL` in [configs/model_config.py].
|
||||
|
||||
### Supported Embedding models
|
||||
|
||||
Following models are tested by developers with Embedding class of [HuggingFace](https://huggingface.co/models?pipeline_tag=sentence-similarity):
|
||||
|
||||
- [moka-ai/m3e-small](https://huggingface.co/moka-ai/m3e-small)
|
||||
- [moka-ai/m3e-base](https://huggingface.co/moka-ai/m3e-base)
|
||||
- [moka-ai/m3e-large](https://huggingface.co/moka-ai/m3e-large)
|
||||
- [BAAI/bge-small-zh](https://huggingface.co/BAAI/bge-small-zh)
|
||||
- [BAAI/bge-base-zh](https://huggingface.co/BAAI/bge-base-zh)
|
||||
- [BAAI/bge-base-zh-v1.5](https://huggingface.co/BAAI/bge-base-zh-v1.5)
|
||||
- [BAAI/bge-large-zh-v1.5](https://huggingface.co/BAAI/bge-large-zh-v1.5)
|
||||
- [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh)
|
||||
- [BAAI/bge-large-zh-noinstruct](https://huggingface.co/BAAI/bge-large-zh-noinstruct)
|
||||
- [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-base-zh)
|
||||
- [sensenova/piccolo-large-zh](https://huggingface.co/sensenova/piccolo-large-zh)
|
||||
- [shibing624/text2vec-base-chinese-sentence](https://huggingface.co/shibing624/text2vec-base-chinese-sentence)
|
||||
- [shibing624/text2vec-base-chinese-paraphrase](https://huggingface.co/shibing624/text2vec-base-chinese-paraphrase)
|
||||
- [shibing624/text2vec-base-multilingual](https://huggingface.co/shibing624/text2vec-base-multilingual)
|
||||
- [shibing624/text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese)
|
||||
- [shibing624/text2vec-bge-large-chinese](https://huggingface.co/shibing624/text2vec-bge-large-chinese)
|
||||
- [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese)
|
||||
- [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
|
||||
- [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh)
|
||||
- [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-base-zh)
|
||||
- [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-large-zh)
|
||||
- [OpenAI/text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings)
|
||||
|
||||
The default Embedding type used in the project is `sensenova/piccolo-base-zh`, if you want to use other Embedding types, please modify `embedding_model_dict` and `embedding_model_dict` and `embedding_model_dict` in [configs/model_config.py]. MODEL` in [configs/model_config.py].
|
||||
|
||||
### Build your own Agent tool!
|
||||
|
||||
See [Custom Agent Instructions](docs/自定义Agent.md) for details.
|
||||
|
||||
---
|
||||
|
||||
## Docker Deployment
|
||||
|
||||
🐳 Docker image path: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5)`
|
||||
Then, create a virtual environment and install the project's dependencies within the virtual environment.
|
||||
|
||||
```shell
|
||||
docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5
|
||||
|
||||
# 拉取仓库
|
||||
$ git clone https://github.com/chatchat-space/Langchain-Chatchat.git
|
||||
|
||||
# 进入目录
|
||||
$ cd Langchain-Chatchat
|
||||
|
||||
# 安装全部依赖
|
||||
$ pip install -r requirements.txt
|
||||
$ pip install -r requirements_api.txt
|
||||
$ pip install -r requirements_webui.txt
|
||||
|
||||
# 默认依赖包括基本运行环境(FAISS向量库)。如果要使用 milvus/pg_vector 等向量库,请将 requirements.txt 中相应依赖取消注释再安装。
|
||||
```
|
||||
|
||||
- The image size of this version is `33.9GB`, using `v0.2.5`, with `nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04` as the base image
|
||||
- This version has a built-in `embedding` model: `m3e-large`, built-in `chatglm2-6b-32k`
|
||||
- This version is designed to facilitate one-click deployment. Please make sure you have installed the NVIDIA driver on your Linux distribution.
|
||||
- Please note that you do not need to install the CUDA toolkit on the host system, but you need to install the `NVIDIA Driver` and the `NVIDIA Container Toolkit`, please refer to the [Installation Guide](https://docs.nvidia.com/datacenter/cloud -native/container-toolkit/latest/install-guide.html)
|
||||
- It takes a certain amount of time to pull and start for the first time. When starting for the first time, please refer to the figure below to use `docker logs -f <container id>` to view the log.
|
||||
- If the startup process is stuck in the `Waiting..` step, it is recommended to use `docker exec -it <container id> bash` to enter the `/logs/` directory to view the corresponding stage logs
|
||||
### Model Download
|
||||
|
||||
---
|
||||
If you need to run this project locally or in an offline environment, you must first download the required models for
|
||||
the project. Typically, open-source LLM and Embedding models can be downloaded from HuggingFace.
|
||||
|
||||
## Development
|
||||
Taking the default LLM model used in this project, [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b), and
|
||||
the Embedding model [moka-ai/m3e-base](https://huggingface.co/moka-ai/m3e-base) as examples:
|
||||
|
||||
### Environment Prerequisite
|
||||
|
||||
The project is tested under Python3.8-python 3.10, CUDA 11.0-CUDA11.7, Windows, macOS of ARM architecture, and Linux platform.
|
||||
|
||||
### 1. Preparing Deployment Environment
|
||||
|
||||
Please refer to [install.md](docs/INSTALL.md)
|
||||
|
||||
### 2. Downloading model to local disk
|
||||
|
||||
**For offline deployment only!**
|
||||
|
||||
If you want to run this project in a local or offline environment, you need to first download the models required for the project to your local computer. Usually the open source LLM and Embedding models can be downloaded from [HuggingFace](https://huggingface.co/models).
|
||||
|
||||
Take the LLM model [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) and Embedding model [moka-ai/m3e-base](https://huggingface. co/moka-ai/m3e-base) for example:
|
||||
|
||||
To download the model, you need to [install Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage), and then run:
|
||||
To download the models, you need to first
|
||||
install [Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage)
|
||||
and then run:
|
||||
|
||||
```Shell
|
||||
$ git lfs install
|
||||
$ git clone https://huggingface.co/THUDM/chatglm2-6b
|
||||
|
||||
$ git clone https://huggingface.co/moka-ai/m3e-base
|
||||
```
|
||||
|
||||
### 3. Setting Configuration
|
||||
Copy the model-related parameter configuration template file [configs/model_config.py.example](configs/model_config.py.example) and store it under the project path `. /configs` path and rename it `model_config.py`.
|
||||
### Initializing the Knowledge Base and Config File
|
||||
|
||||
Copy the service-related parameter configuration template file [configs/server_config.py.example](configs/server_config.py.example) and store it under the project path `. /configs` path and rename it `server_config.py`.
|
||||
Follow the steps below to initialize your own knowledge base and config file:
|
||||
|
||||
Before you start executing the Web UI or command line interactions, check that each of the items in [configs/model_config.py](configs/model_config.py) and [configs/server_config.py](configs/server_config.py) The model parameters are designed to meet the requirements:
|
||||
```shell
|
||||
$ python copy_config_example.py
|
||||
$ python init_database.py --recreate-vs
|
||||
```
|
||||
|
||||
- Please make sure that the local storage path of the downloaded LLM model is written in the `local_model_path` attribute of the corresponding model in `llm_model_dict`, e.g..
|
||||
```
|
||||
"chatglm2-6b":"/Users/xxx/Downloads/chatglm2-6b",
|
||||
### One-Click Launch
|
||||
|
||||
```
|
||||
|
||||
- Please make sure that the local storage path of the downloaded Embedding model is written in `embedding_model_dict` corresponding to the model location, e.g.:
|
||||
|
||||
```
|
||||
"m3e-base":"/Users/xxx/Downloads/m3e-base", ``` Please make sure that the local storage path of the downloaded Embedding model is written in the location of the corresponding model, e.g.
|
||||
```
|
||||
|
||||
- Please make sure that the local participle path is filled in, e.g.:
|
||||
|
||||
```
|
||||
text_splitter_dict = {
|
||||
"ChineseRecursiveTextSplitter": {
|
||||
"source": "huggingface", ## Select tiktoken to use openai's method, don't fill it in then it defaults to character length cutting method.
|
||||
"tokenizer_name_or_path": "", ## Leave blank to use the big model of the tokeniser.
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Knowledge Base Migration
|
||||
|
||||
The knowledge base information is stored in the database, please initialize the database before running the project (we strongly recommend one back up the knowledge files before performing operations).
|
||||
|
||||
- If you migrate from `0.1.x`, for the established knowledge base, please confirm that the vector library type and Embedding model of the knowledge base are consistent with the default settings in `configs/model_config.py`, if there is no change, simply add the existing repository information to the database with the following command:
|
||||
|
||||
```shell
|
||||
$ python init_database.py
|
||||
```
|
||||
- If you are a beginner of the project whose knowledge base has not been established, or the knowledge base type and embedding model in the configuration file have changed, or the previous vector library did not enable `normalize_L2`, you need the following command to initialize or rebuild the knowledge base:
|
||||
|
||||
```shell
|
||||
$ python init_database.py --recreate-vs
|
||||
```
|
||||
|
||||
### 5. Launching API Service or WebUI with One Command
|
||||
|
||||
#### 5.1 Command
|
||||
|
||||
The script is `startuppy`, you can luanch all fastchat related, API,WebUI service with is, here is an example:
|
||||
To start the project, run the following command:
|
||||
|
||||
```shell
|
||||
$ python startup.py -a
|
||||
```
|
||||
|
||||
optional args including: `-a(or --all-webui), --all-api, --llm-api, -c(or --controller),--openai-api, -m(or --model-worker), --api, --webui`, where:
|
||||
### Example of Launch Interface
|
||||
|
||||
* `--all-webui` means to launch all related services of WEBUI
|
||||
* `--all-api` means to launch all related services of API
|
||||
* `--llm-api` means to launch all related services of FastChat
|
||||
* `--openai-api` means to launch controller and openai-api-server of FastChat only
|
||||
* `model-worker` means to launch model worker of FastChat only
|
||||
* any other optional arg is to launch one particular function only
|
||||
1. FastAPI docs interface
|
||||
|
||||
#### 5.2 Launch none-default model
|
||||

|
||||
|
||||
If you want to specify a none-default model, use `--model-name` arg, here is a example:
|
||||
2. webui page
|
||||
|
||||
```shell
|
||||
$ python startup.py --all-webui --model-name Qwen-7B-Chat
|
||||
```
|
||||
- Web UI dialog page:
|
||||
|
||||
#### 5.3 Load model with multi-gpus
|
||||

|
||||
|
||||
If you want to load model with multi-gpus, then the following three parameters in `startup.create_model_worker_app` should be changed:
|
||||
- Web UI knowledge base management page:
|
||||
|
||||
```python
|
||||
gpus=None,
|
||||
num_gpus=1,
|
||||
max_gpu_memory="20GiB"
|
||||
```
|
||||

|
||||
|
||||
where:
|
||||
### Note
|
||||
|
||||
* `gpus` is about specifying the gpus' ID, such as '0,1';
|
||||
* `num_gpus` is about specifying the number of gpus to be used under `gpus`;
|
||||
* `max_gpu_memory` is about specifying the gpu memory of every gpu.
|
||||
|
||||
note:
|
||||
|
||||
* These parameters now can be specified by `server_config.FSCHST_MODEL_WORKERD`.
|
||||
* In some extreme senses, `gpus` doesn't work, then one should specify the used gpus with environment variable `CUDA_VISIBLE_DEVICES`, here is an example:
|
||||
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES=0,1 python startup.py -a
|
||||
```
|
||||
|
||||
#### 5.4 Load PEFT
|
||||
|
||||
Including lora,p-tuning,prefix tuning, prompt tuning,ia3
|
||||
|
||||
This project loads the LLM service based on FastChat, so one must load the PEFT in a FastChat way, that is, ensure that the word `peft` must be in the path name, the name of the configuration file must be `adapter_config.json`, and the path contains PEFT weights in `.bin` format. The peft path is specified in `args.model_names` of the `create_model_worker_app` function in `startup.py`, and enable the environment variable `PEFT_SHARE_BASE_WEIGHTS=true` parameter.
|
||||
|
||||
If the above method fails, you need to start standard fastchat service step by step. Step-by-step procedure could be found Section 6. For further steps, please refer to [Model invalid after loading lora fine-tuning](https://github. com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822).
|
||||
|
||||
#### **5.5 Some Notes**
|
||||
|
||||
1. **The `startup.py` uses multi-process mode to start the services of each module, which may cause printing order problems. Please wait for all services to be initiated before calling, and call the service according to the default or specified port (default LLM API service port: `127.0.0.1:8888 `, default API service port:`127.0.0.1:7861 `, default WebUI service port: `127.0.0.1: 8501`)**
|
||||
2. **The startup time of the service differs across devices, usually it takes 3-10 minutes. If it does not start for a long time, please go to the `./logs` directory to monitor the logs and locate the problem.**
|
||||
3. **Using ctrl+C to exit on Linux may cause orphan processes due to the multi-process mechanism of Linux. You can exit through `shutdown_all.sh`**
|
||||
|
||||
#### 5.6 Interface Examples
|
||||
|
||||
The API, chat interface of WebUI, and knowledge management interface of WebUI are list below respectively.
|
||||
|
||||
1. FastAPI docs
|
||||
|
||||

|
||||
|
||||
2. Chat Interface of WebUI
|
||||
|
||||
- Dialogue interface of WebUI
|
||||
|
||||

|
||||
|
||||
- Knowledge management interface of WebUI
|
||||
|
||||

|
||||
|
||||
## FAQ
|
||||
|
||||
Please refer to [FAQ](docs/FAQ.md)
|
||||
The above instructions are provided for a quick start. If you need more features or want to customize the launch method,
|
||||
please refer to the [Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/).
|
||||
|
||||
---
|
||||
|
||||
## Roadmap
|
||||
## Contact Us
|
||||
|
||||
- [X] Langchain applications
|
||||
### Telegram
|
||||
|
||||
- [X] Load local documents
|
||||
- [X] Unstructured documents
|
||||
- [X] .md
|
||||
- [X] .txt
|
||||
- [X] .docx
|
||||
- [ ] Structured documents
|
||||
- [X] .csv
|
||||
- [ ] .xlsx
|
||||
- [] TextSplitter and Retriever
|
||||
- [X] multiple TextSplitter
|
||||
- [X] ChineseTextSplitter
|
||||
- [ ] Reconstructed Context Retriever
|
||||
- [ ] Webpage
|
||||
- [ ] SQL
|
||||
- [ ] Knowledge Database
|
||||
- [X] Search Engines
|
||||
- [X] Bing
|
||||
- [X] DuckDuckGo
|
||||
- [X] Agent
|
||||
- [X] Agent implementation in the form of basic React, including calls to calculators, etc.
|
||||
- [X] Langchain's own Agent implementation and calls
|
||||
- [ ] More Agent support for models
|
||||
- [ ] More tools
|
||||
- [X] LLM Models
|
||||
- [X] [FastChat](https://github.com/lm-sys/fastchat) -based LLM Models
|
||||
- [ ] Mutiply Remote LLM API
|
||||
- [X] Embedding Models
|
||||
- [X] HuggingFace -based Embedding models
|
||||
- [ ] Mutiply Remote Embedding API
|
||||
- [X] FastAPI-based API
|
||||
- [X] Web UI
|
||||
- [X] Streamlit -based Web UI
|
||||
[](https://t.me/+RjliQ3jnJ1YyN2E9)
|
||||
|
||||
---
|
||||
### WeChat Group、
|
||||
|
||||
## Wechat Group
|
||||
<img src="img/qr_code_67.jpg" alt="二维码" width="300" height="300" />
|
||||
|
||||
<img src="img/qr_code_64.jpg" alt="QR Code" width="300" height="300" />
|
||||
### WeChat Official Account
|
||||
|
||||
🎉 langchain-Chatchat project WeChat exchange group, if you are also interested in this project, welcome to join the group chat to participate in the discussion and exchange.
|
||||
|
||||
## Follow us
|
||||
|
||||
<img src="img/official_account.png" alt="image" width="900" height="300" />
|
||||
🎉 langchain-Chatchat project official public number, welcome to scan the code to follow.
|
||||
<img src="img/official_wechat_mp_account.png" alt="图片" width="900" height="300" />
|
||||
|
||||
@ -5,4 +5,4 @@ from .server_config import *
|
||||
from .prompt_config import *
|
||||
|
||||
|
||||
VERSION = "v0.2.5"
|
||||
VERSION = "v0.2.6"
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import os
|
||||
|
||||
# 默认使用的知识库
|
||||
DEFAULT_KNOWLEDGE_BASE = "samples"
|
||||
|
||||
# 默认向量库类型。可选:faiss, milvus, pg.
|
||||
# 默认向量库类型。可选:faiss, milvus(离线) & zilliz(在线), pg.
|
||||
DEFAULT_VS_TYPE = "faiss"
|
||||
|
||||
# 缓存向量库数量(针对FAISS)
|
||||
@ -19,6 +21,9 @@ VECTOR_SEARCH_TOP_K = 3
|
||||
# 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右
|
||||
SCORE_THRESHOLD = 1
|
||||
|
||||
# 默认搜索引擎。可选:bing, duckduckgo, metaphor
|
||||
DEFAULT_SEARCH_ENGINE = "duckduckgo"
|
||||
|
||||
# 搜索引擎匹配结题数量
|
||||
SEARCH_ENGINE_TOP_K = 3
|
||||
|
||||
@ -36,19 +41,26 @@ BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
|
||||
# 是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG
|
||||
BING_SUBSCRIPTION_KEY = ""
|
||||
|
||||
# metaphor搜索需要KEY
|
||||
METAPHOR_API_KEY = ""
|
||||
|
||||
|
||||
# 是否开启中文标题加强,以及标题增强的相关配置
|
||||
# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记;
|
||||
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
|
||||
ZH_TITLE_ENHANCE = False
|
||||
|
||||
|
||||
# 每个知识库的初始化介绍,用于在初始化知识库时显示和Agent调用,没写则没有介绍,不会被Agent调用。
|
||||
KB_INFO = {
|
||||
"知识库名称": "知识库介绍",
|
||||
"samples": "关于本项目issue的解答",
|
||||
}
|
||||
# 通常情况下不需要更改以下内容
|
||||
|
||||
# 知识库默认存储路径
|
||||
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
|
||||
if not os.path.exists(KB_ROOT_PATH):
|
||||
os.mkdir(KB_ROOT_PATH)
|
||||
|
||||
# 数据库默认存储路径。
|
||||
# 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。
|
||||
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
|
||||
@ -65,6 +77,13 @@ kbs_config = {
|
||||
"password": "",
|
||||
"secure": False,
|
||||
},
|
||||
"zilliz": {
|
||||
"host": "in01-a7ce524e41e3935.ali-cn-hangzhou.vectordb.zilliz.com.cn",
|
||||
"port": "19530",
|
||||
"user": "",
|
||||
"password": "",
|
||||
"secure": True,
|
||||
},
|
||||
"pg": {
|
||||
"connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
|
||||
}
|
||||
@ -74,11 +93,11 @@ kbs_config = {
|
||||
text_splitter_dict = {
|
||||
"ChineseRecursiveTextSplitter": {
|
||||
"source": "huggingface", ## 选择tiktoken则使用openai的方法
|
||||
"tokenizer_name_or_path": "gpt2",
|
||||
"tokenizer_name_or_path": "",
|
||||
},
|
||||
"SpacyTextSplitter": {
|
||||
"source": "huggingface",
|
||||
"tokenizer_name_or_path": "",
|
||||
"tokenizer_name_or_path": "gpt2",
|
||||
},
|
||||
"RecursiveCharacterTextSplitter": {
|
||||
"source": "tiktoken",
|
||||
@ -44,7 +44,7 @@ MODEL_PATH = {
|
||||
"chatglm2-6b-int4": "THUDM/chatglm2-6b-int4",
|
||||
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
|
||||
|
||||
"baichuan2-13b": "baichuan-inc/Baichuan-13B-Chat",
|
||||
"baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat",
|
||||
"baichuan2-7b":"baichuan-inc/Baichuan2-7B-Chat",
|
||||
|
||||
"baichuan-7b": "baichuan-inc/Baichuan-7B",
|
||||
@ -90,9 +90,8 @@ MODEL_PATH = {
|
||||
"Qwen-14B-Chat":"Qwen/Qwen-14B-Chat",
|
||||
},
|
||||
}
|
||||
|
||||
# 选用的 Embedding 名称
|
||||
EMBEDDING_MODEL = "m3e-base" # 可以尝试最新的嵌入式sota模型:piccolo-large-zh
|
||||
EMBEDDING_MODEL = "m3e-base" # 可以尝试最新的嵌入式sota模型:bge-large-zh-v1.5
|
||||
|
||||
|
||||
# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
|
||||
@ -112,7 +111,8 @@ TEMPERATURE = 0.7
|
||||
# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
|
||||
|
||||
|
||||
ONLINE_LLM_MODEL = {
|
||||
LANGCHAIN_LLM_MODEL = {
|
||||
# 不需要走Fschat封装的,Langchain直接支持的模型。
|
||||
# 调用chatgpt时如果报出: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
|
||||
# Max retries exceeded with url: /v1/chat/completions
|
||||
# 则需要将urllib3版本修改为1.25.11
|
||||
@ -128,11 +128,29 @@ ONLINE_LLM_MODEL = {
|
||||
# 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI.
|
||||
# 需要添加代理访问(正常开的代理软件可能会拦截不上)需要设置配置openai_proxy 或者 使用环境遍历OPENAI_PROXY 进行设置
|
||||
# 比如: "openai_proxy": 'http://127.0.0.1:4780'
|
||||
"gpt-3.5-turbo": {
|
||||
|
||||
# 这些配置文件的名字不能改动
|
||||
"Azure-OpenAI": {
|
||||
"deployment_name": "your Azure deployment name",
|
||||
"model_version": "0701",
|
||||
"openai_api_type": "azure",
|
||||
"api_base_url": "https://your Azure point.azure.com",
|
||||
"api_version": "2023-07-01-preview",
|
||||
"api_key": "your Azure api key",
|
||||
"openai_proxy": "",
|
||||
},
|
||||
"OpenAI": {
|
||||
"model_name": "your openai model name(such as gpt-4)",
|
||||
"api_base_url": "https://api.openai.com/v1",
|
||||
"api_key": "your OPENAI_API_KEY",
|
||||
"openai_proxy": "your OPENAI_PROXY",
|
||||
"openai_proxy": "",
|
||||
},
|
||||
"Anthropic": {
|
||||
"model_name": "your claude model name(such as claude2-100k)",
|
||||
"api_key":"your ANTHROPIC_API_KEY",
|
||||
}
|
||||
}
|
||||
ONLINE_LLM_MODEL = {
|
||||
# 线上模型。请在server_config中为每个在线API设置不同的端口
|
||||
# 具体注册及api key获取请前往 http://open.bigmodel.cn
|
||||
"zhipu-api": {
|
||||
|
||||
@ -9,41 +9,106 @@
|
||||
# - context: 从检索结果拼接的知识文本
|
||||
# - question: 用户提出的问题
|
||||
|
||||
# Agent对话支持的变量:
|
||||
|
||||
PROMPT_TEMPLATES = {
|
||||
# LLM对话模板
|
||||
"llm_chat": "{{ input }}",
|
||||
# - tools: 可用的工具列表
|
||||
# - tool_names: 可用的工具名称列表
|
||||
# - history: 用户和Agent的对话历史
|
||||
# - input: 用户输入内容
|
||||
# - agent_scratchpad: Agent的思维记录
|
||||
|
||||
# 基于本地知识问答的提示词模板
|
||||
"knowledge_base_chat":
|
||||
PROMPT_TEMPLATES = {}
|
||||
|
||||
PROMPT_TEMPLATES["llm_chat"] = {
|
||||
"default": "{{ input }}",
|
||||
|
||||
"py":
|
||||
"""
|
||||
<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
|
||||
<已知信息>{{ context }}</已知信息>、
|
||||
<问题>{{ question }}</问题>""",
|
||||
|
||||
# 基于agent的提示词模板
|
||||
"agent_chat":
|
||||
你是一个聪明的代码助手,请你给我写出简单的py代码。 \n
|
||||
{{ input }}
|
||||
"""
|
||||
Answer the following questions as best you can. You have access to the following tools:
|
||||
|
||||
{tools}
|
||||
Use the following format:
|
||||
|
||||
Question: the input question you must answer
|
||||
Thought: you should always think about what to do
|
||||
Action: the action to take, should be one of [{tool_names}]
|
||||
Action Input: the input to the action
|
||||
Observation: the result of the action
|
||||
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
|
||||
Thought: I now know the final answer
|
||||
Final Answer: the final answer to the original input question
|
||||
|
||||
Begin!
|
||||
|
||||
history:
|
||||
{history}
|
||||
|
||||
Question: {input}
|
||||
Thought: {agent_scratchpad}
|
||||
"""
|
||||
,
|
||||
}
|
||||
|
||||
PROMPT_TEMPLATES["knowledge_base_chat"] = {
|
||||
"default":
|
||||
"""
|
||||
<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
|
||||
<已知信息>{{ context }}</已知信息>、
|
||||
<问题>{{ question }}</问题>
|
||||
""",
|
||||
"text":
|
||||
"""
|
||||
<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>
|
||||
<已知信息>{{ context }}</已知信息>、
|
||||
<问题>{{ question }}</问题>
|
||||
""",
|
||||
}
|
||||
PROMPT_TEMPLATES["search_engine_chat"] = {
|
||||
"default":
|
||||
"""
|
||||
<指令>这是我搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 </指令>
|
||||
<已知信息>{{ context }}</已知信息>、
|
||||
<问题>{{ question }}</问题>
|
||||
""",
|
||||
|
||||
"search":
|
||||
"""
|
||||
<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>
|
||||
<已知信息>{{ context }}</已知信息>、
|
||||
<问题>{{ question }}</问题>
|
||||
""",
|
||||
}
|
||||
PROMPT_TEMPLATES["agent_chat"] = {
|
||||
"default":
|
||||
"""
|
||||
Answer the following questions as best you can. If it is in order, you can use some tools appropriately.You have access to the following tools:
|
||||
|
||||
{tools}
|
||||
|
||||
Please note that the "知识库查询工具" is information about the "西交利物浦大学" ,and if a question is asked about it, you must answer with the knowledge base,
|
||||
Please note that the "天气查询工具" can only be used once since Question begin.
|
||||
|
||||
Use the following format:
|
||||
Question: the input question you must answer1
|
||||
Thought: you should always think about what to do and what tools to use.
|
||||
Action: the action to take, should be one of [{tool_names}]
|
||||
Action Input: the input to the action
|
||||
Observation: the result of the action
|
||||
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
|
||||
Thought: I now know the final answer
|
||||
Final Answer: the final answer to the original input question
|
||||
|
||||
|
||||
Begin!
|
||||
history:
|
||||
{history}
|
||||
Question: {input}
|
||||
Thought: {agent_scratchpad}
|
||||
""",
|
||||
"ChatGLM":
|
||||
"""
|
||||
请请严格按照提供的思维方式来思考。你的知识不一定正确,所以你一定要用提供的工具来思考,并给出用户答案。
|
||||
你有以下工具可以使用:
|
||||
{tools}
|
||||
```
|
||||
Question: 用户的提问或者观察到的信息,
|
||||
Thought: 你应该思考该做什么,是根据工具的结果来回答问题,还是决定使用什么工具。
|
||||
Action: 需要使用的工具,应该是在[{tool_names}]中的一个。
|
||||
Action Input: 传入工具的内容
|
||||
Observation: 工具给出的答案(不是你生成的)
|
||||
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
|
||||
Thought: 通过工具给出的答案,你是否能回答Question。
|
||||
Final Answer是你的答案
|
||||
|
||||
现在,我们开始!
|
||||
你和用户的历史记录:
|
||||
History:
|
||||
{history}
|
||||
|
||||
用户开始以提问:
|
||||
Question: {input}
|
||||
Thought: {agent_scratchpad}
|
||||
|
||||
""",
|
||||
}
|
||||
|
||||
@ -9,7 +9,7 @@ HTTPX_DEFAULT_TIMEOUT = 300.0
|
||||
OPEN_CROSS_DOMAIN = False
|
||||
|
||||
# 各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host
|
||||
DEFAULT_BIND_HOST = "0.0.0.0"
|
||||
DEFAULT_BIND_HOST = "0.0.0.0" if sys.platform != "win32" else "127.0.0.1"
|
||||
|
||||
# webui.py server
|
||||
WEBUI_SERVER = {
|
||||
@ -32,6 +32,7 @@ FSCHAT_OPENAI_API = {
|
||||
# fastchat model_worker server
|
||||
# 这些模型必须是在model_config.MODEL_PATH或ONLINE_MODEL中正确配置的。
|
||||
# 在启动startup.py时,可用通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL
|
||||
# 必须在这里添加的模型才会出现在WEBUI中可选模型列表里(LLM_MODEL会自动添加)
|
||||
FSCHAT_MODEL_WORKERS = {
|
||||
# 所有模型共用的默认配置,可在模型专项配置中进行覆盖。
|
||||
"default": {
|
||||
@ -39,7 +40,8 @@ FSCHAT_MODEL_WORKERS = {
|
||||
"port": 20002,
|
||||
"device": LLM_DEVICE,
|
||||
# False,'vllm',使用的推理加速框架,使用vllm如果出现HuggingFace通信问题,参见doc/FAQ
|
||||
"infer_turbo": "vllm" if sys.platform.startswith("linux") else False,
|
||||
# vllm对一些模型支持还不成熟,暂时默认关闭
|
||||
"infer_turbo": False,
|
||||
|
||||
# model_worker多卡加载需要配置的参数
|
||||
# "gpus": None, # 使用的GPU,以str的格式指定,如"0,1",如失效请使用CUDA_VISIBLE_DEVICES="0,1"等形式指定
|
||||
@ -97,21 +99,24 @@ FSCHAT_MODEL_WORKERS = {
|
||||
"zhipu-api": { # 请为每个要运行的在线API设置不同的端口
|
||||
"port": 21001,
|
||||
},
|
||||
"minimax-api": {
|
||||
"port": 21002,
|
||||
},
|
||||
"xinghuo-api": {
|
||||
"port": 21003,
|
||||
},
|
||||
"qianfan-api": {
|
||||
"port": 21004,
|
||||
},
|
||||
"fangzhou-api": {
|
||||
"port": 21005,
|
||||
},
|
||||
"qwen-api": {
|
||||
"port": 21006,
|
||||
},
|
||||
# "minimax-api": {
|
||||
# "port": 21002,
|
||||
# },
|
||||
# "xinghuo-api": {
|
||||
# "port": 21003,
|
||||
# },
|
||||
# "qianfan-api": {
|
||||
# "port": 21004,
|
||||
# },
|
||||
# "fangzhou-api": {
|
||||
# "port": 21005,
|
||||
# },
|
||||
# "qwen-api": {
|
||||
# "port": 21006,
|
||||
# },
|
||||
# "baichuan-api": {
|
||||
# "port": 21007,
|
||||
# },
|
||||
}
|
||||
|
||||
# fastchat multi model worker server
|
||||
|
||||
@ -1,32 +0,0 @@
|
||||
## 变更日志
|
||||
|
||||
**[2023/04/15]**
|
||||
|
||||
1. 重构项目结构,在根目录下保留命令行 Demo [cli_demo.py](../cli_demo.py) 和 Web UI Demo [webui.py](../webui.py);
|
||||
2. 对 Web UI 进行改进,修改为运行 Web UI 后首先按照 [configs/model_config.py](../configs/model_config.py) 默认选项加载模型,并增加报错提示信息等;
|
||||
3. 对常见问题进行补充说明。
|
||||
|
||||
**[2023/04/12]**
|
||||
|
||||
1. 替换 Web UI 中的样例文件,避免出现 Ubuntu 中出现因文件编码无法读取的问题;
|
||||
2. 替换`knowledge_based_chatglm.py`中的 prompt 模版,避免出现因 prompt 模版包含中英双语导致 chatglm 返回内容错乱的问题。
|
||||
|
||||
**[2023/04/11]**
|
||||
|
||||
1. 加入 Web UI V0.1 版本(感谢 [@liangtongt](https://github.com/liangtongt));
|
||||
2. `README.md`中增加常见问题(感谢 [@calcitem](https://github.com/calcitem) 和 [@bolongliu](https://github.com/bolongliu));
|
||||
3. 增加 LLM 和 Embedding 模型运行设备是否可用`cuda`、`mps`、`cpu`的自动判断。
|
||||
4. 在`knowledge_based_chatglm.py`中增加对`filepath`的判断,在之前支持单个文件导入的基础上,现支持单个文件夹路径作为输入,输入后将会遍历文件夹中各个文件,并在命令行中显示每个文件是否成功加载。
|
||||
|
||||
**[2023/04/09]**
|
||||
|
||||
1. 使用`langchain`中的`RetrievalQA`替代之前选用的`ChatVectorDBChain`,替换后可以有效减少提问 2-3 次后因显存不足而停止运行的问题;
|
||||
2. 在`knowledge_based_chatglm.py`中增加`EMBEDDING_MODEL`、`VECTOR_SEARCH_TOP_K`、`LLM_MODEL`、`LLM_HISTORY_LEN`、`REPLY_WITH_SOURCE`参数值设置;
|
||||
3. 增加 GPU 显存需求更小的`chatglm-6b-int4`、`chatglm-6b-int4-qe`作为 LLM 模型备选项;
|
||||
4. 更正`README.md`中的代码错误(感谢 [@calcitem](https://github.com/calcitem))。
|
||||
|
||||
**[2023/04/07]**
|
||||
|
||||
1. 解决加载 ChatGLM 模型时发生显存占用为双倍的问题 (感谢 [@suc16](https://github.com/suc16) 和 [@myml](https://github.com/myml)) ;
|
||||
2. 新增清理显存机制;
|
||||
3. 新增`nghuyong/ernie-3.0-nano-zh`和`nghuyong/ernie-3.0-base-zh`作为 Embedding 模型备选项,相比`GanymedeNil/text2vec-large-chinese`占用显存资源更少 (感谢 [@lastrei](https://github.com/lastrei))。
|
||||
223
docs/FAQ.md
@ -1,223 +0,0 @@
|
||||
### 常见问题
|
||||
|
||||
Q1: 本项目支持哪些文件格式?
|
||||
|
||||
A1: 目前已测试支持 txt、docx、md、pdf 格式文件,更多文件格式请参考 [langchain 文档](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html)。目前已知文档中若含有特殊字符,可能存在文件无法加载的问题。
|
||||
|
||||
---
|
||||
|
||||
Q2: 使用过程中 Python 包 `nltk`发生了 `Resource punkt not found.`报错,该如何解决?
|
||||
|
||||
A2: 方法一:https://github.com/nltk/nltk_data/raw/gh-pages/packages/tokenizers/punkt.zip 中的 `packages/tokenizers` 解压,放到 `nltk_data/tokenizers` 存储路径下。
|
||||
|
||||
`nltk_data` 存储路径可以通过 `nltk.data.path` 查询。
|
||||
|
||||
方法二:执行python代码
|
||||
|
||||
```
|
||||
import nltk
|
||||
nltk.download()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
Q3: 使用过程中 Python 包 `nltk`发生了 `Resource averaged_perceptron_tagger not found.`报错,该如何解决?
|
||||
|
||||
A3: 方法一:将 https://github.com/nltk/nltk_data/blob/gh-pages/packages/taggers/averaged_perceptron_tagger.zip 下载,解压放到 `nltk_data/taggers` 存储路径下。
|
||||
|
||||
`nltk_data` 存储路径可以通过 `nltk.data.path` 查询。
|
||||
|
||||
方法二:执行python代码
|
||||
|
||||
```
|
||||
import nltk
|
||||
nltk.download()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
Q4: 本项目可否在 colab 中运行?
|
||||
|
||||
A4: 可以尝试使用 chatglm-6b-int4 模型在 colab 中运行,需要注意的是,如需在 colab 中运行 Web UI,需将 `webui.py`中 `demo.queue(concurrency_count=3).launch( server_name='0.0.0.0', share=False, inbrowser=False)`中参数 `share`设置为 `True`。
|
||||
|
||||
---
|
||||
|
||||
Q5: 在 Anaconda 中使用 pip 安装包无效如何解决?
|
||||
|
||||
A5: 此问题是系统环境问题,详细见 [在Anaconda中使用pip安装包无效问题](在Anaconda中使用pip安装包无效问题.md)
|
||||
|
||||
---
|
||||
|
||||
Q6: 本项目中所需模型如何下载至本地?
|
||||
|
||||
A6: 本项目中使用的模型均为 `huggingface.com`中可下载的开源模型,以默认选择的 `chatglm-6b`和 `text2vec-large-chinese`模型为例,下载模型可执行如下代码:
|
||||
|
||||
```shell
|
||||
# 安装 git lfs
|
||||
$ git lfs install
|
||||
|
||||
# 下载 LLM 模型
|
||||
$ git clone https://huggingface.co/THUDM/chatglm-6b /your_path/chatglm-6b
|
||||
|
||||
# 下载 Embedding 模型
|
||||
$ git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese /your_path/text2vec
|
||||
|
||||
# 模型需要更新时,可打开模型所在文件夹后拉取最新模型文件/代码
|
||||
$ git pull
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
Q7: `huggingface.com`中模型下载速度较慢怎么办?
|
||||
|
||||
A7: 可使用本项目用到的模型权重文件百度网盘地址:
|
||||
|
||||
- ernie-3.0-base-zh.zip 链接: https://pan.baidu.com/s/1CIvKnD3qzE-orFouA8qvNQ?pwd=4wih
|
||||
- ernie-3.0-nano-zh.zip 链接: https://pan.baidu.com/s/1Fh8fgzVdavf5P1omAJJ-Zw?pwd=q6s5
|
||||
- text2vec-large-chinese.zip 链接: https://pan.baidu.com/s/1sMyPzBIXdEzHygftEoyBuA?pwd=4xs7
|
||||
- chatglm-6b-int4-qe.zip 链接: https://pan.baidu.com/s/1DDKMOMHtNZccOOBGWIOYww?pwd=22ji
|
||||
- chatglm-6b-int4.zip 链接: https://pan.baidu.com/s/1pvZ6pMzovjhkA6uPcRLuJA?pwd=3gjd
|
||||
- chatglm-6b.zip 链接: https://pan.baidu.com/s/1B-MpsVVs1GHhteVBetaquw?pwd=djay
|
||||
|
||||
---
|
||||
|
||||
Q8: 下载完模型后,如何修改代码以执行本地模型?
|
||||
|
||||
A8: 模型下载完成后,请在 [configs/model_config.py](../configs/model_config.py) 文件中,对 `embedding_model_dict`和 `llm_model_dict`参数进行修改,如把 `llm_model_dict`从
|
||||
|
||||
```python
|
||||
embedding_model_dict = {
|
||||
"ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
|
||||
"ernie-base": "nghuyong/ernie-3.0-base-zh",
|
||||
"text2vec": "GanymedeNil/text2vec-large-chinese"
|
||||
}
|
||||
```
|
||||
|
||||
修改为
|
||||
|
||||
```python
|
||||
embedding_model_dict = {
|
||||
"ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
|
||||
"ernie-base": "nghuyong/ernie-3.0-base-zh",
|
||||
"text2vec": "/Users/liuqian/Downloads/ChatGLM-6B/text2vec-large-chinese"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
Q9: 执行 `python cli_demo.py`过程中,显卡内存爆了,提示 "OutOfMemoryError: CUDA out of memory"
|
||||
|
||||
A9: 将 `VECTOR_SEARCH_TOP_K` 和 `LLM_HISTORY_LEN` 的值调低,比如 `VECTOR_SEARCH_TOP_K = 5` 和 `LLM_HISTORY_LEN = 2`,这样由 `query` 和 `context` 拼接得到的 `prompt` 会变短,会减少内存的占用。或者打开量化,请在 [configs/model_config.py](../configs/model_config.py) 文件中,对 `LOAD_IN_8BIT`参数进行修改
|
||||
|
||||
---
|
||||
|
||||
Q10: 执行 `pip install -r requirements.txt` 过程中遇到 python 包,如 langchain 找不到对应版本的问题
|
||||
|
||||
A10: 更换 pypi 源后重新安装,如阿里源、清华源等,网络条件允许时建议直接使用 pypi.org 源,具体操作命令如下:
|
||||
|
||||
```shell
|
||||
# 使用 pypi 源
|
||||
$ pip install -r requirements.txt -i https://pypi.python.org/simple
|
||||
```
|
||||
|
||||
或
|
||||
|
||||
```shell
|
||||
# 使用阿里源
|
||||
$ pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/
|
||||
```
|
||||
|
||||
或
|
||||
|
||||
```shell
|
||||
# 使用清华源
|
||||
$ pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
Q11: 启动 api.py 时 upload_file 接口抛出 `partially initialized module 'charset_normalizer' has no attribute 'md__mypyc' (most likely due to a circular import)`
|
||||
|
||||
A11: 这是由于 charset_normalizer 模块版本过高导致的,需要降低低 charset_normalizer 的版本,测试在 charset_normalizer==2.1.0 上可用。
|
||||
|
||||
---
|
||||
|
||||
Q12: 调用api中的 `bing_search_chat` 接口时,报出 `Failed to establish a new connection: [Errno 110] Connection timed out`
|
||||
|
||||
A12: 这是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG--!
|
||||
|
||||
---
|
||||
|
||||
Q13: 加载 chatglm-6b-int8 或 chatglm-6b-int4 抛出 `RuntimeError: Only Tensors of floating point andcomplex dtype can require gradients`
|
||||
|
||||
A13: 疑为 chatglm 的 quantization 的问题或 torch 版本差异问题,针对已经变为 Parameter 的 torch.zeros 矩阵也执行 Parameter 操作,从而抛出 `RuntimeError: Only Tensors of floating point andcomplex dtype can require gradients`。解决办法是在 chatglm 项目的原始文件中的 quantization.py 文件 374 行改为:
|
||||
|
||||
```
|
||||
try:
|
||||
self.weight =Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
|
||||
except Exception as e:
|
||||
pass
|
||||
```
|
||||
|
||||
如果上述方式不起作用,则在.cache/hugggingface/modules/目录下针对chatglm项目的原始文件中的quantization.py文件执行上述操作,若软链接不止一个,按照错误提示选择正确的路径。
|
||||
|
||||
注:虽然模型可以顺利加载但在cpu上仍存在推理失败的可能:即针对每个问题,模型一直输出gugugugu。
|
||||
|
||||
因此,最好不要试图用cpu加载量化模型,原因可能是目前python主流量化包的量化操作是在gpu上执行的,会天然地存在gap。
|
||||
|
||||
---
|
||||
|
||||
Q14: 修改配置中路径后,加载 text2vec-large-chinese 依然提示 `WARNING: No sentence-transformers model found with name text2vec-large-chinese. Creating a new one with MEAN pooling.`
|
||||
|
||||
A14: 尝试更换 embedding,如 text2vec-base-chinese,请在 [configs/model_config.py](../configs/model_config.py) 文件中,修改 `text2vec-base`参数为本地路径,绝对路径或者相对路径均可
|
||||
|
||||
---
|
||||
|
||||
Q15: 使用pg向量库建表报错
|
||||
|
||||
A15: 需要手动安装对应的vector扩展(连接pg执行 CREATE EXTENSION IF NOT EXISTS vector)
|
||||
|
||||
---
|
||||
|
||||
Q16: pymilvus 连接超时
|
||||
|
||||
A16.pymilvus版本需要匹配和milvus对应否则会超时参考pymilvus==2.1.3
|
||||
|
||||
Q16: 使用vllm推理加速框架时,已经下载了模型但出现HuggingFace通信问题
|
||||
|
||||
A16: 参照如下代码修改python环境下/site-packages/vllm/model_executor/weight_utils.py文件的prepare_hf_model_weights函数如下对应代码:
|
||||
|
||||
```python
|
||||
|
||||
if not is_local:
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
model_path_temp = os.path.join(
|
||||
os.getenv("HOME"),
|
||||
".cache/huggingface/hub",
|
||||
"models--" + model_name_or_path.replace("/", "--"),
|
||||
"snapshots/",
|
||||
)
|
||||
downloaded = False
|
||||
if os.path.exists(model_path_temp):
|
||||
temp_last_dir = os.listdir(model_path_temp)[-1]
|
||||
model_path_temp = os.path.join(model_path_temp, temp_last_dir)
|
||||
base_pattern = os.path.join(model_path_temp, "pytorch_model*.bin")
|
||||
files = glob.glob(base_pattern)
|
||||
if len(files) > 0:
|
||||
downloaded = True
|
||||
|
||||
if downloaded:
|
||||
hf_folder = model_path_temp
|
||||
else:
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
hf_folder = snapshot_download(model_name_or_path,
|
||||
allow_patterns=allow_patterns,
|
||||
cache_dir=cache_dir,
|
||||
tqdm_class=Disabledtqdm)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
|
||||
|
||||
```
|
||||
@ -1,63 +0,0 @@
|
||||
# 安装
|
||||
|
||||
## 环境检查
|
||||
|
||||
```shell
|
||||
# 首先,确信你的机器安装了 Python 3.8 - 3.10 版本
|
||||
$ python --version
|
||||
Python 3.8.13
|
||||
|
||||
# 如果低于这个版本,可使用conda安装环境
|
||||
$ conda create -p /your_path/env_name python=3.8
|
||||
|
||||
# 激活环境
|
||||
$ source activate /your_path/env_name
|
||||
|
||||
# 或,conda安装,不指定路径, 注意以下,都将/your_path/env_name替换为env_name
|
||||
$ conda create -n env_name python=3.8
|
||||
$ conda activate env_name # Activate the environment
|
||||
|
||||
# 更新py库
|
||||
$ pip3 install --upgrade pip
|
||||
|
||||
# 关闭环境
|
||||
$ source deactivate /your_path/env_name
|
||||
|
||||
# 删除环境
|
||||
$ conda env remove -p /your_path/env_name
|
||||
```
|
||||
|
||||
## 项目依赖
|
||||
|
||||
```shell
|
||||
# 拉取仓库
|
||||
$ git clone https://github.com/chatchat-space/Langchain-Chatchat.git
|
||||
|
||||
# 进入目录
|
||||
$ cd Langchain-Chatchat
|
||||
|
||||
# 安装全部依赖
|
||||
$ pip install -r requirements.txt
|
||||
|
||||
# 默认依赖包括基本运行环境(FAISS向量库)。如果要使用 milvus/pg_vector 等向量库,请将 requirements.txt 中相应依赖取消注释再安装。
|
||||
```
|
||||
|
||||
此外,为方便用户 API 与 webui 分离运行,可单独根据运行需求安装依赖包。
|
||||
|
||||
- 如果只需运行 API,可执行:
|
||||
```shell
|
||||
$ pip install -r requirements_api.txt
|
||||
|
||||
# 默认依赖包括基本运行环境(FAISS向量库)。如果要使用 milvus/pg_vector 等向量库,请将 requirements.txt 中相应依赖取消注释再安装。
|
||||
```
|
||||
|
||||
- 如果只需运行 WebUI,可执行:
|
||||
```shell
|
||||
$ pip install -r requirements_webui.txt
|
||||
```
|
||||
|
||||
|
||||
|
||||
注:使用 `langchain.document_loaders.UnstructuredFileLoader` 进行 `.docx` 等格式非结构化文件接入时,可能需要依据文档进行其他依赖包的安装,请参考 [langchain 文档](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html)。
|
||||
|
||||
|
||||
@ -1,114 +0,0 @@
|
||||
## Issue with Installing Packages Using pip in Anaconda
|
||||
|
||||
## Problem
|
||||
|
||||
Recently, when running open-source code, I encountered an issue: after creating a virtual environment with conda and switching to the new environment, using pip to install packages would be "ineffective." Here, "ineffective" means that the packages installed with pip are not in this new environment.
|
||||
|
||||
------
|
||||
|
||||
## Analysis
|
||||
|
||||
1. First, create a test environment called test: `conda create -n test`
|
||||
2. Activate the test environment: `conda activate test`
|
||||
3. Use pip to install numpy: `pip install numpy`. You'll find that numpy already exists in the default environment.
|
||||
|
||||
```powershell
|
||||
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
Requirement already satisfied: numpy in c:\programdata\anaconda3\lib\site-packages (1.20.3)
|
||||
```
|
||||
|
||||
4. Check the information of pip: `pip show pip`
|
||||
|
||||
```powershell
|
||||
Name: pip
|
||||
Version: 21.2.4
|
||||
Summary: The PyPA recommended tool for installing Python packages.
|
||||
Home-page: https://pip.pypa.io/
|
||||
Author: The pip developers
|
||||
Author-email: distutils-sig@python.org
|
||||
License: MIT
|
||||
Location: c:\programdata\anaconda3\lib\site-packages
|
||||
Requires:
|
||||
Required-by:
|
||||
```
|
||||
|
||||
5. We can see that the current pip is in the default conda environment. This explains why the package is not in the new virtual environment when we directly use pip to install packages - because the pip being used belongs to the default environment, the installed package either already exists or is installed directly into the default environment.
|
||||
|
||||
------
|
||||
|
||||
## Solution
|
||||
|
||||
1. We can directly use the conda command to install new packages, but sometimes conda may not have certain packages/libraries, so we still need to use pip to install.
|
||||
2. We can first use the conda command to install the pip package for the current virtual environment, and then use pip to install new packages.
|
||||
|
||||
```powershell
|
||||
# Use conda to install the pip package
|
||||
(test) PS C:\Users\Administrator> conda install pip
|
||||
Collecting package metadata (current_repodata.json): done
|
||||
Solving environment: done
|
||||
....
|
||||
done
|
||||
|
||||
# Display the information of the current pip, and find that pip is in the test environment
|
||||
(test) PS C:\Users\Administrator> pip show pip
|
||||
Name: pip
|
||||
Version: 21.2.4
|
||||
Summary: The PyPA recommended tool for installing Python packages.
|
||||
Home-page: https://pip.pypa.io/
|
||||
Author: The pip developers
|
||||
Author-email: distutils-sig@python.org
|
||||
License: MIT
|
||||
Location: c:\programdata\anaconda3\envs\test\lib\site-packages
|
||||
Requires:
|
||||
Required-by:
|
||||
|
||||
# Now use pip to install the numpy package, and it is installed successfully
|
||||
(test) PS C:\Users\Administrator> pip install numpy
|
||||
Looking in indexes:
|
||||
https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
Collecting numpy
|
||||
Using cached https://pypi.tuna.tsinghua.edu.cn/packages/4b/23/140ec5a509d992fe39db17200e96c00fd29603c1531ce633ef93dbad5e9e/numpy-1.22.2-cp39-cp39-win_amd64.whl (14.7 MB)
|
||||
Installing collected packages: numpy
|
||||
Successfully installed numpy-1.22.2
|
||||
|
||||
# Use pip list to view the currently installed packages, no problem
|
||||
(test) PS C:\Users\Administrator> pip list
|
||||
Package Version
|
||||
------------ ---------
|
||||
certifi 2021.10.8
|
||||
numpy 1.22.2
|
||||
pip 21.2.4
|
||||
setuptools 58.0.4
|
||||
wheel 0.37.1
|
||||
wincertstore 0.2
|
||||
```
|
||||
|
||||
## Supplement
|
||||
|
||||
1. The reason I didn't notice this problem before might be because the packages installed in the virtual environment were of a specific version, which overwrote the packages in the default environment. The main issue was actually a lack of careful observation:), otherwise, I could have noticed `Successfully uninstalled numpy-xxx` **default version** and `Successfully installed numpy-1.20.3` **specified version**.
|
||||
2. During testing, I found that if the Python version is specified when creating a new package, there shouldn't be this issue. I guess this is because pip will be installed in the virtual environment, while in our case, including pip, no packages were installed, so the default environment's pip was used.
|
||||
3. There's a question: I should have specified the Python version when creating a new virtual environment before, but I still used the default environment's pip package. However, I just couldn't reproduce the issue successfully on two different machines, which led to the second point mentioned above.
|
||||
4. After encountering the problem mentioned in point 3, I solved it by using `python -m pip install package-name`, adding `python -m` before pip. As for why, you can refer to the answer on [StackOverflow](https://stackoverflow.com/questions/41060382/using-pip-to-install-packages-to-anaconda-environment):
|
||||
|
||||
>1. If you have a non-conda pip as your default pip but conda python as your default python (as below):
|
||||
>
|
||||
>```shell
|
||||
>>which -a pip
|
||||
>/home/<user>/.local/bin/pip
|
||||
>/home/<user>/.conda/envs/newenv/bin/pip
|
||||
>/usr/bin/pip
|
||||
>
|
||||
>>which -a python
|
||||
>/home/<user>/.conda/envs/newenv/bin/python
|
||||
>/usr/bin/python
|
||||
>```
|
||||
>
|
||||
>2. Then, instead of calling `pip install <package>` directly, you can use the module flag -m in python so that it installs with the anaconda python
|
||||
>
|
||||
>```shell
|
||||
>python -m pip install <package>
|
||||
>```
|
||||
>
|
||||
>3. This will install the package to the anaconda library directory rather than the library directory associated with the (non-anaconda) pip
|
||||
>4. The reason for doing this is as follows: the pip command references a specific pip file/shortcut (which -a pip will tell you which one). Similarly, the python command references a specific python file (which -a python will tell you which one). For one reason or another, these two commands can become out of sync, so your "default" pip is in a different folder than your default python and therefore is associated with different versions of python.
|
||||
>5. In contrast, the python -m pip construct does not use the shortcut that the pip command points to. Instead, it asks python to find its pip version and use that version to install a package.
|
||||
@ -1,49 +0,0 @@
|
||||
version: '3.5'
|
||||
|
||||
services:
|
||||
etcd:
|
||||
container_name: milvus-etcd
|
||||
image: quay.io/coreos/etcd:v3.5.0
|
||||
environment:
|
||||
- ETCD_AUTO_COMPACTION_MODE=revision
|
||||
- ETCD_AUTO_COMPACTION_RETENTION=1000
|
||||
- ETCD_QUOTA_BACKEND_BYTES=4294967296
|
||||
- ETCD_SNAPSHOT_COUNT=50000
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
|
||||
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
|
||||
|
||||
minio:
|
||||
container_name: milvus-minio
|
||||
image: minio/minio:RELEASE.2022-03-17T06-34-49Z
|
||||
environment:
|
||||
MINIO_ACCESS_KEY: minioadmin
|
||||
MINIO_SECRET_KEY: minioadmin
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
|
||||
command: minio server /minio_data
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
standalone:
|
||||
container_name: milvus-standalone
|
||||
image: milvusdb/milvus:v2.1.3
|
||||
command: ["milvus", "run", "standalone"]
|
||||
environment:
|
||||
ETCD_ENDPOINTS: etcd:2379
|
||||
MINIO_ADDRESS: minio:9000
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
|
||||
ports:
|
||||
- "19530:19530"
|
||||
- "9091:9091"
|
||||
depends_on:
|
||||
- "etcd"
|
||||
- "minio"
|
||||
|
||||
networks:
|
||||
default:
|
||||
name: milvus
|
||||
@ -1,13 +0,0 @@
|
||||
version: "3.8"
|
||||
services:
|
||||
postgresql:
|
||||
image: ankane/pgvector:v0.4.1
|
||||
container_name: langchain_chatchat-pg-db
|
||||
environment:
|
||||
POSTGRES_DB: langchain_chatchat
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
ports:
|
||||
- 5432:5432
|
||||
volumes:
|
||||
- ./data:/var/lib/postgresql/data
|
||||
@ -1,24 +0,0 @@
|
||||
## 如何自定义分词器
|
||||
|
||||
### 在哪里写,哪些文件要改
|
||||
1. 在```text_splitter```文件夹下新建一个文件,文件名为您的分词器名字,比如`my_splitter.py`,然后在`__init__.py`中导入您的分词器,如下所示:
|
||||
```python
|
||||
from .my_splitter import MySplitter
|
||||
```
|
||||
|
||||
2. 修改```config/model_config.py```文件,将您的分词器名字添加到```text_splitter_dict```中,如下所示:
|
||||
```python
|
||||
MySplitter: {
|
||||
"source": "huggingface", ## 选择tiktoken则使用openai的方法
|
||||
"tokenizer_name_or_path": "your tokenizer", #如果选择huggingface则使用huggingface的方法,部分tokenizer需要从Huggingface下载
|
||||
}
|
||||
TEXT_SPLITTER = "MySplitter"
|
||||
```
|
||||
|
||||
完成上述步骤后,就能使用自己的分词器了。
|
||||
|
||||
### 如何贡献您的分词器
|
||||
|
||||
1. 将您的分词器所在的代码文件放在```text_splitter```文件夹下,文件名为您的分词器名字,比如`my_splitter.py`,然后在`__init__.py`中导入您的分词器。
|
||||
2. 发起PR,并说明您的分词器面向的场景或者改进之处。我们非常期待您能举例一个具体的应用场景。
|
||||
3. 在Readme.md中添加您的分词器的使用方法和支持说明。
|
||||
@ -1,8 +0,0 @@
|
||||
向量库环境 docker-compose.yml 文件在 docs/docker/vector_db 中
|
||||
|
||||
以 milvus 为例
|
||||
```shell
|
||||
cd docs/docker/vector_db/milvus
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
@ -1,37 +0,0 @@
|
||||
# 启动API服务
|
||||
|
||||
## 通过py文件启动
|
||||
可以通过直接执行`api.py`文件启动API服务,默认以ip:0.0.0.0和port:7861启动http和ws服务。
|
||||
```shell
|
||||
python api.py
|
||||
```
|
||||
同时,启动时支持StartOption所列的模型加载参数,同时还支持IP和端口设置。
|
||||
```shell
|
||||
python api.py --model-name chatglm-6b-int8 --port 7862
|
||||
```
|
||||
|
||||
## 通过cli.bat/cli.sh启动
|
||||
也可以通过命令行控制文件继续启动。
|
||||
```shell
|
||||
cli.sh api --help
|
||||
```
|
||||
其他可设置参数和上述py文件启动方式相同。
|
||||
|
||||
|
||||
# 以https、wss启动API服务
|
||||
## 本地创建ssl相关证书文件
|
||||
如果没有正式签发的CA证书,可以[安装mkcert](https://github.com/FiloSottile/mkcert#installation)工具, 然后用如下指令生成本地CA证书:
|
||||
```shell
|
||||
mkcert -install
|
||||
mkcert api.example.com 47.123.123.123 localhost 127.0.0.1 ::1
|
||||
```
|
||||
默认回车保存在当前目录下,会有以生成指令第一个域名命名为前缀命名的两个pem文件。
|
||||
|
||||
附带两个文件参数启动即可。
|
||||
````shell
|
||||
python api --port 7862 --ssl_keyfile api.example.com+4-key.pem --ssl_certfile api.example.com+4.pem
|
||||
|
||||
./cli.sh api --port 7862 --ssl_keyfile api.example.com+4-key.pem --ssl_certfile api.example.com+4.pem
|
||||
````
|
||||
|
||||
此外可以通过前置Nginx转发实现类似效果,可另行查阅相关资料。
|
||||
@ -1,125 +0,0 @@
|
||||
## 在 Anaconda 中使用 pip 安装包无效问题
|
||||
|
||||
## 问题
|
||||
|
||||
最近在跑开源代码的时候遇到的问题:使用 conda 创建虚拟环境并切换到新的虚拟环境后,再使用 pip 来安装包会“无效”。这里的“无效”指的是使用 pip 安装的包不在这个新的环境中。
|
||||
|
||||
------
|
||||
|
||||
## 分析
|
||||
|
||||
1、首先创建一个测试环境 test,`conda create -n test`
|
||||
|
||||
2、激活该测试环境,`conda activate test`
|
||||
|
||||
3、使用 pip 安装 numpy,`pip install numpy`,会发现 numpy 已经存在默认的环境中
|
||||
|
||||
```powershell
|
||||
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
Requirement already satisfied: numpy in c:\programdata\anaconda3\lib\site-packages (1.20.3)
|
||||
```
|
||||
|
||||
4、这时候看一下 pip 的信息,`pip show pip`
|
||||
|
||||
```powershell
|
||||
Name: pip
|
||||
Version: 21.2.4
|
||||
Summary: The PyPA recommended tool for installing Python packages.
|
||||
Home-page: https://pip.pypa.io/
|
||||
Author: The pip developers
|
||||
Author-email: distutils-sig@python.org
|
||||
License: MIT
|
||||
Location: c:\programdata\anaconda3\lib\site-packages
|
||||
Requires:
|
||||
Required-by:
|
||||
```
|
||||
|
||||
5、可以发现当前 pip 是在默认的 conda 环境中。这也就解释了当我们直接使用 pip 安装包时为什么包不在这个新的虚拟环境中,因为使用的 pip 属于默认环境,安装的包要么已经存在,要么直接装到默认环境中去了。
|
||||
|
||||
------
|
||||
|
||||
## 解决
|
||||
|
||||
1、我们可以直接使用 conda 命令安装新的包,但有些时候 conda 可能没有某些包/库,所以还是得用 pip 安装
|
||||
|
||||
2、我们可以先使用 conda 命令为当前虚拟环境安装 pip 包,再使用 pip 安装新的包
|
||||
|
||||
```powershell
|
||||
# 使用 conda 安装 pip 包
|
||||
(test) PS C:\Users\Administrator> conda install pip
|
||||
Collecting package metadata (current_repodata.json): done
|
||||
Solving environment: done
|
||||
....
|
||||
done
|
||||
|
||||
# 显示当前 pip 的信息,发现 pip 在测试环境 test 中
|
||||
(test) PS C:\Users\Administrator> pip show pip
|
||||
Name: pip
|
||||
Version: 21.2.4
|
||||
Summary: The PyPA recommended tool for installing Python packages.
|
||||
Home-page: https://pip.pypa.io/
|
||||
Author: The pip developers
|
||||
Author-email: distutils-sig@python.org
|
||||
License: MIT
|
||||
Location: c:\programdata\anaconda3\envs\test\lib\site-packages
|
||||
Requires:
|
||||
Required-by:
|
||||
|
||||
# 再使用 pip 安装 numpy 包,成功安装
|
||||
(test) PS C:\Users\Administrator> pip install numpy
|
||||
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
Collecting numpy
|
||||
Using cached https://pypi.tuna.tsinghua.edu.cn/packages/4b/23/140ec5a509d992fe39db17200e96c00fd29603c1531ce633ef93dbad5e9e/numpy-1.22.2-cp39-cp39-win_amd64.whl (14.7 MB)
|
||||
Installing collected packages: numpy
|
||||
Successfully installed numpy-1.22.2
|
||||
|
||||
# 使用 pip list 查看当前安装的包,没有问题
|
||||
(test) PS C:\Users\Administrator> pip list
|
||||
Package Version
|
||||
------------ ---------
|
||||
certifi 2021.10.8
|
||||
numpy 1.22.2
|
||||
pip 21.2.4
|
||||
setuptools 58.0.4
|
||||
wheel 0.37.1
|
||||
wincertstore 0.2
|
||||
```
|
||||
|
||||
------
|
||||
|
||||
## 补充
|
||||
|
||||
1、之前没有发现这个问题可能时因为在虚拟环境中安装的包是指定版本的,覆盖了默认环境中的包。其实主要还是观察不仔细:),不然可以发现 `Successfully uninstalled numpy-xxx`【默认版本】 以及 `Successfully installed numpy-1.20.3`【指定版本】
|
||||
|
||||
2、测试时发现如果在新建包的时候指定了 python 版本的话应该是没有这个问题的,猜测时因为会在虚拟环境中安装好 pip ,而我们这里包括 pip 在内啥包也没有装,所以使用的是默认环境的 pip
|
||||
|
||||
3、有个问题,之前我在创建新的虚拟环境时应该指定了 python 版本,但还是使用的默认环境的 pip 包,但是刚在在两台机器上都没有复现成功,于是有了上面的第 2 点
|
||||
|
||||
4、出现了第 3 点的问题后,我当时是使用 `python -m pip install package-name` 解决的,在 pip 前面加上了 python -m。至于为什么,可以参考 [StackOverflow](https://stackoverflow.com/questions/41060382/using-pip-to-install-packages-to-anaconda-environment) 上的回答:
|
||||
|
||||
> 1、如果你有一个非 conda 的 pip 作为你的默认 pip,但是 conda 的 python 是你的默认 python(如下):
|
||||
>
|
||||
> ```shell
|
||||
> >which -a pip
|
||||
> /home/<user>/.local/bin/pip
|
||||
> /home/<user>/.conda/envs/newenv/bin/pip
|
||||
> /usr/bin/pip
|
||||
>
|
||||
> >which -a python
|
||||
> /home/<user>/.conda/envs/newenv/bin/python
|
||||
> /usr/bin/python
|
||||
> ```
|
||||
>
|
||||
> 2、然后,而不是直接调用 `pip install <package>`,你可以在 python 中使用模块标志 -m,以便它使用 anaconda python 进行安装
|
||||
>
|
||||
> ```shell
|
||||
>python -m pip install <package>
|
||||
> ```
|
||||
>
|
||||
> 3、这将把包安装到 anaconda 库目录,而不是与(非anaconda) pip 关联的库目录
|
||||
>
|
||||
> 4、这样做的原因如下:命令 pip 引用了一个特定的 pip 文件 / 快捷方式(which -a pip 会告诉你是哪一个)。类似地,命令 python 引用一个特定的 python 文件(which -a python 会告诉你是哪个)。由于这样或那样的原因,这两个命令可能变得不同步,因此你的“默认” pip 与你的默认 python 位于不同的文件夹中,因此与不同版本的 python 相关联。
|
||||
>
|
||||
> 5、与此相反,python -m pip 构造不使用 pip 命令指向的快捷方式。相反,它要求 python 找到它的pip 版本,并使用该版本安装一个包。
|
||||
|
||||
-
|
||||
@ -1,80 +0,0 @@
|
||||
## 自定义属于自己的Agent
|
||||
### 1. 创建自己的Agent工具
|
||||
+ 开发者在```server/agent```文件中创建一个自己的文件,并将其添加到```tools.py```中。这样就完成了Tools的设定。
|
||||
|
||||
+ 当您创建了一个```custom_agent.py```文件,其中包含一个```work```函数,那么您需要在```tools.py```中添加如下代码:
|
||||
```python
|
||||
from custom_agent import work
|
||||
Tool.from_function(
|
||||
func=work,
|
||||
name="该函数的名字",
|
||||
description=""
|
||||
)
|
||||
```
|
||||
+ 请注意,如果你确定在某一个工程中不会使用到某个工具,可以将其从Tools中移除,降低模型分类错误导致使用错误工具的风险。
|
||||
|
||||
### 2. 修改 custom_template.py文件
|
||||
开发者需要根据自己选择的大模型设定适合该模型的Agent Prompt和自自定义返回格式。
|
||||
在我们的代码中,提供了默认的两种方式,一种是适配于GPT和Qwen的提示词:
|
||||
```python
|
||||
"""
|
||||
Answer the following questions as best you can. You have access to the following tools:
|
||||
|
||||
{tools}
|
||||
Use the following format:
|
||||
|
||||
Question: the input question you must answer
|
||||
Thought: you should always think about what to do
|
||||
Action: the action to take, should be one of [{tool_names}]
|
||||
Action Input: the input to the action
|
||||
Observation: the result of the action
|
||||
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
|
||||
Thought: I now know the final answer
|
||||
Final Answer: the final answer to the original input question
|
||||
|
||||
Begin!
|
||||
|
||||
history:
|
||||
{history}
|
||||
|
||||
Question: {input}
|
||||
Thought: {agent_scratchpad}
|
||||
"""
|
||||
```
|
||||
|
||||
另一种是适配于GLM-130B的提示词:
|
||||
```python
|
||||
"""
|
||||
尽可能地回答以下问题。你可以使用以下工具:{tools}
|
||||
请按照以下格式进行:
|
||||
Question: 需要你回答的输入问题
|
||||
Thought: 你应该总是思考该做什么
|
||||
Action: 需要使用的工具,应该是[{tool_names}]中的一个
|
||||
Action Input: 传入工具的内容
|
||||
Observation: 行动的结果
|
||||
... (这个Thought/Action/Action Input/Observation可以重复N次)
|
||||
Thought: 我现在知道最后的答案
|
||||
Final Answer: 对原始输入问题的最终答案
|
||||
|
||||
现在开始!
|
||||
|
||||
之前的对话:
|
||||
{history}
|
||||
|
||||
New question: {input}
|
||||
Thought: {agent_scratchpad}
|
||||
"""
|
||||
```
|
||||
|
||||
### 3. 局限性
|
||||
1. 在我们的实验中,小于70B级别的模型,若不经过微调,很难达到较好的效果。因此,我们建议开发者使用大于70B级别的模型进行微调,以达到更好的效果。
|
||||
2. 由于Agent的脆弱性,temperture参数的设置对于模型的效果有很大的影响。我们建议开发者在使用自定义Agent时,对于不同的模型,将其设置成0.1以下,以达到更好的效果。
|
||||
3. 即使使用了大于70B级别的模型,开发者也应该在Prompt上进行深度优化,以让模型能成功的选择工具并完成任务。
|
||||
|
||||
|
||||
### 4. 我们已经支持的Agent
|
||||
我们为开发者编写了三个运用大模型执行的Agent,分别是:
|
||||
1. 翻译工具,实现对输入的任意语言翻译。
|
||||
2. 数学工具,使用LLMMathChain 实现数学计算。
|
||||
3. 天气工具,使用自定义的LLMWetherChain实现天气查询,调用和风天气API。
|
||||
4. 我们支持Langchain支持的Agent工具,在代码中,我们已经提供了Shell和Google Search两个工具的实现。
|
||||
|
Before Width: | Height: | Size: 249 KiB |
|
Before Width: | Height: | Size: 27 KiB |
BIN
img/LLM_success.png
Normal file
|
After Width: | Height: | Size: 148 KiB |
BIN
img/agent_continue.png
Normal file
|
After Width: | Height: | Size: 101 KiB |
BIN
img/agent_success.png
Normal file
|
After Width: | Height: | Size: 84 KiB |
|
Before Width: | Height: | Size: 204 KiB |
BIN
img/fastapi_docs_026.png
Normal file
|
After Width: | Height: | Size: 75 KiB |
BIN
img/init_knowledge_base.jpg
Normal file
|
After Width: | Height: | Size: 75 KiB |
BIN
img/knowledge_base_success.jpg
Normal file
|
After Width: | Height: | Size: 114 KiB |
|
Before Width: | Height: | Size: 27 KiB After Width: | Height: | Size: 27 KiB |
9
img/partners/autodl.svg
Normal file
|
After Width: | Height: | Size: 123 KiB |
9
img/partners/aws.svg
Normal file
|
After Width: | Height: | Size: 42 KiB |
9
img/partners/chatglm.svg
Normal file
|
After Width: | Height: | Size: 6.3 KiB |
9
img/partners/zhenfund.svg
Normal file
@ -0,0 +1,9 @@
|
||||
<svg width="654" height="213" viewBox="0 0 654 213" fill="none" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<rect x="654" width="213" height="654" transform="rotate(90 654 0)" fill="url(#pattern0)"/>
|
||||
<defs>
|
||||
<pattern id="pattern0" patternContentUnits="objectBoundingBox" width="1" height="1">
|
||||
<use xlink:href="#image0_237_57" transform="matrix(0.0204695 0 0 0.00666667 -0.00150228 0)"/>
|
||||
</pattern>
|
||||
<image id="image0_237_57" width="49" height="150" xlink:href=""/>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 4.9 KiB |
|
Before Width: | Height: | Size: 108 KiB |
|
Before Width: | Height: | Size: 188 KiB |
|
Before Width: | Height: | Size: 240 KiB |
|
Before Width: | Height: | Size: 237 KiB |
|
Before Width: | Height: | Size: 170 KiB |
@ -1,3 +1,5 @@
|
||||
import sys
|
||||
sys.path.append(".")
|
||||
from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files
|
||||
from configs.model_config import NLTK_DATA_PATH
|
||||
import nltk
|
||||
|
||||
@ -1,31 +1,33 @@
|
||||
langchain>=0.0.302
|
||||
fschat[model_worker]==0.2.29
|
||||
openai
|
||||
langchain>=0.0.319
|
||||
langchain-experimental>=0.0.30
|
||||
fschat[model_worker]==0.2.31
|
||||
xformers>=0.0.22.post4
|
||||
openai>=0.28.1
|
||||
sentence_transformers
|
||||
transformers==4.33.3
|
||||
torch>=2.0.1
|
||||
transformers>=4.34
|
||||
torch>=2.0.1 # 推荐2.1
|
||||
torchvision
|
||||
torchaudio
|
||||
fastapi>=0.103.1
|
||||
nltk~=3.8.1
|
||||
fastapi>=0.104
|
||||
nltk>=3.8.1
|
||||
uvicorn~=0.23.1
|
||||
starlette~=0.27.0
|
||||
pydantic~=1.10.11
|
||||
unstructured[all-docs]>=0.10.4
|
||||
unstructured[all-docs]>=0.10.12
|
||||
python-magic-bin; sys_platform == 'win32'
|
||||
SQLAlchemy==2.0.19
|
||||
faiss-cpu
|
||||
accelerate
|
||||
spacy
|
||||
PyMuPDF==1.22.5
|
||||
rapidocr_onnxruntime>=1.3.2
|
||||
PyMuPDF
|
||||
rapidocr_onnxruntime
|
||||
|
||||
requests
|
||||
pathlib
|
||||
pytest
|
||||
scikit-learn
|
||||
numexpr
|
||||
vllm==0.1.7; sys_platform == "linux"
|
||||
vllm>=0.2.0; sys_platform == "linux"
|
||||
# online api libs
|
||||
# zhipuai
|
||||
# dashscope>=1.10.0 # qwen
|
||||
@ -42,7 +44,7 @@ pandas~=2.0.3
|
||||
streamlit>=1.26.0
|
||||
streamlit-option-menu>=0.3.6
|
||||
streamlit-antd-components>=0.1.11
|
||||
streamlit-chatbox>=1.1.9
|
||||
streamlit-chatbox==1.1.10
|
||||
streamlit-aggrid>=0.3.4.post3
|
||||
httpx~=0.24.1
|
||||
watchdog
|
||||
|
||||
@ -1,12 +1,14 @@
|
||||
langchain>=0.0.302
|
||||
fschat[model_worker]==0.2.29
|
||||
openai
|
||||
sentence_transformers
|
||||
transformers>=4.33.0
|
||||
torch >=2.0.1
|
||||
langchain>=0.0.319
|
||||
langchain-experimental>=0.0.30
|
||||
fschat[model_worker]==0.2.31
|
||||
xformers>=0.0.22.post4
|
||||
openai>=0.28.1
|
||||
sentence_transformers>=2.2.2
|
||||
transformers>=4.34
|
||||
torch>=2.1
|
||||
torchvision
|
||||
torchaudio
|
||||
fastapi>=0.103.1
|
||||
fastapi>=0.104
|
||||
nltk~=3.8.1
|
||||
uvicorn~=0.23.1
|
||||
starlette~=0.27.0
|
||||
@ -24,7 +26,8 @@ pathlib
|
||||
pytest
|
||||
scikit-learn
|
||||
numexpr
|
||||
vllm==0.1.7; sys_platform == "linux"
|
||||
|
||||
vllm>=0.2.0; sys_platform == "linux"
|
||||
|
||||
|
||||
# online api libs
|
||||
@ -32,6 +35,7 @@ vllm==0.1.7; sys_platform == "linux"
|
||||
# dashscope>=1.10.0 # qwen
|
||||
# qianfan
|
||||
# volcengine>=1.0.106 # fangzhou
|
||||
# duckduckgo-searchd #duckduckgo搜索
|
||||
|
||||
# uncomment libs if you want to use corresponding vector store
|
||||
# pymilvus==2.1.3 # requires milvus==2.1.3
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
numpy~=1.24.4
|
||||
pandas~=2.0.3
|
||||
streamlit>=1.26.0
|
||||
streamlit>=1.27.2
|
||||
streamlit-option-menu>=0.3.6
|
||||
streamlit-antd-components>=0.1.11
|
||||
streamlit-chatbox>=1.1.9
|
||||
streamlit-antd-components>=0.2.3
|
||||
streamlit-chatbox==1.1.10
|
||||
streamlit-aggrid>=0.3.4.post3
|
||||
httpx~=0.24.1
|
||||
nltk
|
||||
httpx>=0.25.0
|
||||
nltk>=3.8.1
|
||||
watchdog
|
||||
websockets
|
||||
|
||||
4
server/agent/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from .model_contain import *
|
||||
from .callbacks import *
|
||||
from .custom_template import *
|
||||
from .tools import *
|
||||
@ -20,7 +20,7 @@ class Status:
|
||||
agent_action: int = 4
|
||||
agent_finish: int = 5
|
||||
error: int = 6
|
||||
make_tool: int = 7
|
||||
tool_finish: int = 7
|
||||
|
||||
|
||||
class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
@ -34,6 +34,15 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
|
||||
parent_run_id: UUID | None = None, tags: List[str] | None = None,
|
||||
metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
|
||||
|
||||
# 对于截断不能自理的大模型,我来帮他截断
|
||||
stop_words = ["Observation:", "Thought","\"","(", "\n","\t"]
|
||||
for stop_word in stop_words:
|
||||
index = input_str.find(stop_word)
|
||||
if index != -1:
|
||||
input_str = input_str[:index]
|
||||
break
|
||||
|
||||
self.cur_tool = {
|
||||
"tool_name": serialized["name"],
|
||||
"input_str": input_str,
|
||||
@ -44,13 +53,14 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
"final_answer": "",
|
||||
"error": "",
|
||||
}
|
||||
# print("\nInput Str:",self.cur_tool["input_str"])
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
|
||||
tags: List[str] | None = None, **kwargs: Any) -> None:
|
||||
self.out = True
|
||||
self.out = True ## 重置输出
|
||||
self.cur_tool.update(
|
||||
status=Status.agent_finish,
|
||||
status=Status.tool_finish,
|
||||
output_str=output.replace("Answer:", ""),
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
@ -64,20 +74,22 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
if token:
|
||||
if "Action" in token:
|
||||
self.out = False
|
||||
self.cur_tool.update(
|
||||
status=Status.running,
|
||||
llm_token="\n\n",
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
if self.out:
|
||||
self.cur_tool.update(
|
||||
if "Action" in token: ## 减少重复输出
|
||||
before_action = token.split("Action")[0]
|
||||
self.cur_tool.update(
|
||||
status=Status.running,
|
||||
llm_token=before_action + "\n",
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
self.out = False
|
||||
|
||||
if token and self.out:
|
||||
self.cur_tool.update(
|
||||
status=Status.running,
|
||||
llm_token=token,
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
|
||||
self.cur_tool.update(
|
||||
@ -85,17 +97,31 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
llm_token="",
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
self.out = True
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.cur_tool.update(
|
||||
status=Status.complete,
|
||||
status=Status.start,
|
||||
llm_token="",
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
self.cur_tool.update(
|
||||
status=Status.complete,
|
||||
llm_token="\n",
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
|
||||
self.out = True
|
||||
self.cur_tool.update(
|
||||
status=Status.error,
|
||||
error=str(error),
|
||||
@ -107,4 +133,10 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# 返回最终答案
|
||||
self.cur_tool.update(
|
||||
status=Status.agent_finish,
|
||||
final_answer=finish.return_values["output"],
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
self.cur_tool = {}
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
from __future__ import annotations
|
||||
from langchain.agents import Tool, AgentOutputParser
|
||||
from langchain.prompts import StringPromptTemplate
|
||||
from typing import List, Union
|
||||
from typing import List
|
||||
from langchain.schema import AgentAction, AgentFinish
|
||||
import re
|
||||
|
||||
from server.agent import model_container
|
||||
class CustomPromptTemplate(StringPromptTemplate):
|
||||
# The template to use
|
||||
template: str
|
||||
@ -19,40 +18,74 @@ class CustomPromptTemplate(StringPromptTemplate):
|
||||
for action, observation in intermediate_steps:
|
||||
thoughts += action.log
|
||||
thoughts += f"\nObservation: {observation}\nThought: "
|
||||
# Set the agent_scratchpad variable to that value
|
||||
# Set the agent_scratchpad variable to that value
|
||||
kwargs["agent_scratchpad"] = thoughts
|
||||
# Create a tools variable from the list of tools provided
|
||||
kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
|
||||
# Create a list of tool names for the tools provided
|
||||
kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
|
||||
# Return the formatted templatepr
|
||||
# print( self.template.format(**kwargs), end="\n\n")
|
||||
return self.template.format(**kwargs)
|
||||
class CustomOutputParser(AgentOutputParser):
|
||||
|
||||
def parse(self, llm_output: str) -> AgentFinish | AgentAction | str:
|
||||
|
||||
class CustomOutputParser(AgentOutputParser):
|
||||
begin: bool = False
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.begin = True
|
||||
|
||||
def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
|
||||
# Check if agent should finish
|
||||
support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型
|
||||
if not any(agent in model_container.MODEL for agent in support_agent) and self.begin:
|
||||
self.begin = False
|
||||
stop_words = ["Observation:"]
|
||||
min_index = len(llm_output)
|
||||
for stop_word in stop_words:
|
||||
index = llm_output.find(stop_word)
|
||||
if index != -1 and index < min_index:
|
||||
min_index = index
|
||||
llm_output = llm_output[:min_index]
|
||||
|
||||
if "Final Answer:" in llm_output:
|
||||
self.begin = True
|
||||
return AgentFinish(
|
||||
# Return values is generally always a dictionary with a single `output` key
|
||||
# It is not recommended to try anything else at the moment :)
|
||||
return_values={"output": llm_output.replace("Final Answer:", "").strip()},
|
||||
return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()},
|
||||
log=llm_output,
|
||||
)
|
||||
|
||||
# Parse out the action and action input
|
||||
regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
|
||||
match = re.search(regex, llm_output, re.DOTALL)
|
||||
if not match:
|
||||
parts = llm_output.split("Action:")
|
||||
if len(parts) < 2:
|
||||
return AgentFinish(
|
||||
return_values={"output": f"调用agent失败: `{llm_output}`"},
|
||||
log=llm_output,
|
||||
)
|
||||
action = match.group(1).strip()
|
||||
action_input = match.group(2)
|
||||
|
||||
action = parts[1].split("Action Input:")[0].strip()
|
||||
action_input = parts[1].split("Action Input:")[1].strip()
|
||||
|
||||
# 原来的正则化检查方式
|
||||
# regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
|
||||
# print("llm_output",llm_output)
|
||||
# match = re.search(regex, llm_output, re.DOTALL)
|
||||
# print("match",match)
|
||||
# if not match:
|
||||
# return AgentFinish(
|
||||
# return_values={"output": f"调用agent失败: `{llm_output}`"},
|
||||
# log=llm_output,
|
||||
# )
|
||||
# action = match.group(1).strip()
|
||||
# action_input = match.group(2)
|
||||
|
||||
# Return the action and action input
|
||||
|
||||
try:
|
||||
ans = AgentAction(
|
||||
tool=action,
|
||||
tool_input=action_input.strip(" ").strip('"'),
|
||||
log=llm_output
|
||||
tool=action,
|
||||
tool_input=action_input.strip(" ").strip('"'),
|
||||
log=llm_output
|
||||
)
|
||||
return ans
|
||||
except:
|
||||
@ -60,6 +93,3 @@ class CustomOutputParser(AgentOutputParser):
|
||||
return_values={"output": f"调用agent失败: `{llm_output}`"},
|
||||
log=llm_output,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@ -1,8 +0,0 @@
|
||||
import os
|
||||
os.environ["GOOGLE_CSE_ID"] = ""
|
||||
os.environ["GOOGLE_API_KEY"] = ""
|
||||
|
||||
from langchain.tools import GoogleSearchResults
|
||||
def google_search(query: str):
|
||||
tool = GoogleSearchResults()
|
||||
return tool.run(tool_input=query)
|
||||
8
server/agent/model_contain.py
Normal file
@ -0,0 +1,8 @@
|
||||
|
||||
## 由于工具类无法传参,所以使用全局变量来传递模型和对应的知识库介绍
|
||||
class ModelContainer:
|
||||
def __init__(self):
|
||||
self.MODEL = None
|
||||
self.DATABASE = None
|
||||
|
||||
model_container = ModelContainer()
|
||||
@ -1,40 +0,0 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
from server.agent.math import calculate
|
||||
from server.agent.translator import translate
|
||||
from server.agent.weather import weathercheck
|
||||
from server.agent.shell import shell
|
||||
from server.agent.google_search import google_search
|
||||
from langchain.agents import Tool
|
||||
|
||||
tools = [
|
||||
Tool.from_function(
|
||||
func=calculate,
|
||||
name="计算器工具",
|
||||
description="进行简单的数学运算"
|
||||
),
|
||||
Tool.from_function(
|
||||
func=translate,
|
||||
name="翻译工具",
|
||||
description="翻译各种语言"
|
||||
),
|
||||
Tool.from_function(
|
||||
func=weathercheck,
|
||||
name="天气查询工具",
|
||||
description="查询天气",
|
||||
),
|
||||
Tool.from_function(
|
||||
func=shell,
|
||||
name="shell工具",
|
||||
description="使用命令行工具输出",
|
||||
),
|
||||
Tool.from_function(
|
||||
func=google_search,
|
||||
name="谷歌搜索工具",
|
||||
description="使用谷歌搜索",
|
||||
)
|
||||
]
|
||||
tool_names = [tool.name for tool in tools]
|
||||
10
server/agent/tools/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
## 导入所有的工具类
|
||||
from .search_knowledge_simple import knowledge_search_simple
|
||||
from .search_all_knowledge_once import knowledge_search_once
|
||||
from .search_all_knowledge_more import knowledge_search_more
|
||||
from .calculate import calculate
|
||||
from .translator import translate
|
||||
from .weather import weathercheck
|
||||
from .shell import shell
|
||||
from .search_internet import search_internet
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
## 单独运行的时候需要添加
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.chains import LLMMathChain
|
||||
from server.utils import wrap_done, get_ChatOpenAI
|
||||
from configs.model_config import LLM_MODEL, TEMPERATURE
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.callbacks.manager import CallbackManagerForToolRun
|
||||
from server.agent import model_container
|
||||
|
||||
_PROMPT_TEMPLATE = """将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。
|
||||
_PROMPT_TEMPLATE = """
|
||||
将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。
|
||||
问题: ${{包含数学问题的问题。}}
|
||||
```text
|
||||
${{解决问题的单行数学表达式}}
|
||||
@ -60,11 +63,12 @@ PROMPT = PromptTemplate(
|
||||
|
||||
|
||||
def calculate(query: str):
|
||||
model = get_ChatOpenAI(
|
||||
streaming=False,
|
||||
model_name=LLM_MODEL,
|
||||
temperature=TEMPERATURE,
|
||||
)
|
||||
model = model_container.MODEL
|
||||
llm_math = LLMMathChain.from_llm(model, verbose=True, prompt=PROMPT)
|
||||
ans = llm_math.run(query)
|
||||
return ans
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = calculate("2的三次方")
|
||||
print("答案:",result)
|
||||
|
||||
296
server/agent/tools/search_all_knowledge_more.py
Normal file
@ -0,0 +1,296 @@
|
||||
## 单独运行的时候需要添加
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
|
||||
|
||||
import json
|
||||
import re
|
||||
import warnings
|
||||
from typing import Dict
|
||||
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.pydantic_v1 import Extra, root_validator
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from typing import List, Any, Optional
|
||||
from langchain.prompts import PromptTemplate
|
||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD
|
||||
import asyncio
|
||||
from server.agent import model_container
|
||||
|
||||
|
||||
async def search_knowledge_base_iter(database: str, query: str) -> str:
|
||||
response = await knowledge_base_chat(query=query,
|
||||
knowledge_base_name=database,
|
||||
model_name=model_container.MODEL.model_name,
|
||||
temperature=0.01,
|
||||
history=[],
|
||||
top_k=VECTOR_SEARCH_TOP_K,
|
||||
max_tokens=None,
|
||||
prompt_name="default",
|
||||
score_threshold=SCORE_THRESHOLD,
|
||||
stream=False)
|
||||
|
||||
contents = ""
|
||||
async for data in response.body_iterator: # 这里的data是一个json字符串
|
||||
data = json.loads(data)
|
||||
contents += data["answer"]
|
||||
docs = data["docs"]
|
||||
return contents
|
||||
|
||||
|
||||
async def search_knowledge_multiple(queries) -> List[str]:
|
||||
# queries 应该是一个包含多个 (database, query) 元组的列表
|
||||
tasks = [search_knowledge_base_iter(database, query) for database, query in queries]
|
||||
results = await asyncio.gather(*tasks)
|
||||
# 结合每个查询结果,并在每个查询结果前添加一个自定义的消息
|
||||
combined_results = []
|
||||
for (database, _), result in zip(queries, results):
|
||||
message = f"\n查询到 {database} 知识库的相关信息:\n{result}"
|
||||
combined_results.append(message)
|
||||
|
||||
return combined_results
|
||||
|
||||
|
||||
def search_knowledge(queries) -> str:
|
||||
responses = asyncio.run(search_knowledge_multiple(queries))
|
||||
# 输出每个整合的查询结果
|
||||
contents = ""
|
||||
for response in responses:
|
||||
contents += response + "\n\n"
|
||||
return contents
|
||||
|
||||
|
||||
_PROMPT_TEMPLATE = """
|
||||
用户会提出一个需要你查询知识库的问题,你应该对问题进行理解和拆解,并在知识库中查询相关的内容。
|
||||
|
||||
对于每个知识库,你输出的内容应该是一个一行的字符串,这行字符串包含知识库名称和查询内容,中间用逗号隔开,不要有多余的文字和符号。你可以同时查询多个知识库,下面这个例子就是同时查询两个知识库的内容。
|
||||
|
||||
例子:
|
||||
|
||||
robotic,机器人男女比例是多少
|
||||
bigdata,大数据的就业情况如何
|
||||
|
||||
|
||||
这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能,你应该参考他们的功能来帮助你思考
|
||||
|
||||
{database_names}
|
||||
|
||||
你的回答格式应该按照下面的内容,请注意```text 等标记都必须输出,这是我用来提取答案的标记。
|
||||
|
||||
|
||||
Question: ${{用户的问题}}
|
||||
|
||||
```text
|
||||
${{知识库名称,查询问题,不要带有任何除了,之外的符号}}
|
||||
|
||||
```output
|
||||
数据库查询的结果
|
||||
|
||||
|
||||
|
||||
这是一个完整的问题拆分和提问的例子:
|
||||
|
||||
|
||||
问题: 分别对比机器人和大数据专业的就业情况并告诉我哪儿专业的就业情况更好?
|
||||
|
||||
```text
|
||||
robotic,机器人专业的就业情况
|
||||
bigdata,大数据专业的就业情况
|
||||
|
||||
|
||||
|
||||
现在,我们开始作答
|
||||
问题: {question}
|
||||
"""
|
||||
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["question", "database_names"],
|
||||
template=_PROMPT_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
class LLMKnowledgeChain(LLMChain):
|
||||
llm_chain: LLMChain
|
||||
llm: Optional[BaseLanguageModel] = None
|
||||
"""[Deprecated] LLM wrapper to use."""
|
||||
prompt: BasePromptTemplate = PROMPT
|
||||
"""[Deprecated] Prompt to use to translate to python if necessary."""
|
||||
database_names: Dict[str, str] = None
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
if "llm" in values:
|
||||
warnings.warn(
|
||||
"Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
|
||||
"Please instantiate with llm_chain argument or using the from_llm "
|
||||
"class method."
|
||||
)
|
||||
if "llm_chain" not in values and values["llm"] is not None:
|
||||
prompt = values.get("prompt", PROMPT)
|
||||
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
|
||||
return values
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Expect output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _evaluate_expression(self, queries) -> str:
|
||||
try:
|
||||
output = search_knowledge(queries)
|
||||
except Exception as e:
|
||||
output = "输入的信息有误或不存在知识库,错误信息如下:\n"
|
||||
return output + str(e)
|
||||
return output
|
||||
|
||||
def _process_llm_result(
|
||||
self,
|
||||
llm_output: str,
|
||||
run_manager: CallbackManagerForChainRun
|
||||
) -> Dict[str, str]:
|
||||
|
||||
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
||||
|
||||
llm_output = llm_output.strip()
|
||||
# text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
|
||||
text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
|
||||
if text_match:
|
||||
expression = text_match.group(1).strip()
|
||||
cleaned_input_str = (expression.replace("\"", "").replace("“", "").
|
||||
replace("”", "").replace("```", "").strip())
|
||||
lines = cleaned_input_str.split("\n")
|
||||
# 使用逗号分割每一行,然后形成一个(数据库,查询)元组的列表
|
||||
|
||||
queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
|
||||
run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", verbose=self.verbose)
|
||||
output = self._evaluate_expression(queries)
|
||||
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
answer = "Answer: " + output
|
||||
elif llm_output.startswith("Answer:"):
|
||||
answer = llm_output
|
||||
elif "Answer:" in llm_output:
|
||||
answer = llm_output.split("Answer:")[-1]
|
||||
else:
|
||||
return {self.output_key: f"输入的格式不对:\n {llm_output}"}
|
||||
return {self.output_key: answer}
|
||||
|
||||
async def _aprocess_llm_result(
|
||||
self,
|
||||
llm_output: str,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> Dict[str, str]:
|
||||
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
||||
llm_output = llm_output.strip()
|
||||
text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
|
||||
if text_match:
|
||||
|
||||
expression = text_match.group(1).strip()
|
||||
cleaned_input_str = (
|
||||
expression.replace("\"", "").replace("“", "").replace("”", "").replace("```", "").strip())
|
||||
lines = cleaned_input_str.split("\n")
|
||||
queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
|
||||
await run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue",
|
||||
verbose=self.verbose)
|
||||
|
||||
output = self._evaluate_expression(queries)
|
||||
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
answer = "Answer: " + output
|
||||
elif llm_output.startswith("Answer:"):
|
||||
answer = llm_output
|
||||
elif "Answer:" in llm_output:
|
||||
answer = "Answer: " + llm_output.split("Answer:")[-1]
|
||||
else:
|
||||
raise ValueError(f"unknown format from LLM: {llm_output}")
|
||||
return {self.output_key: answer}
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
_run_manager.on_text(inputs[self.input_key])
|
||||
self.database_names = model_container.DATABASE
|
||||
data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
|
||||
llm_output = self.llm_chain.predict(
|
||||
database_names=data_formatted_str,
|
||||
question=inputs[self.input_key],
|
||||
stop=["```output"],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
return self._process_llm_result(llm_output, _run_manager)
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
await _run_manager.on_text(inputs[self.input_key])
|
||||
self.database_names = model_container.DATABASE
|
||||
data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
|
||||
llm_output = await self.llm_chain.apredict(
|
||||
database_names=data_formatted_str,
|
||||
question=inputs[self.input_key],
|
||||
stop=["```output"],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "llm_knowledge_chain"
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate = PROMPT,
|
||||
**kwargs: Any,
|
||||
):
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(llm_chain=llm_chain, **kwargs)
|
||||
|
||||
|
||||
def knowledge_search_more(query: str):
|
||||
model = model_container.MODEL
|
||||
llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT)
|
||||
ans = llm_knowledge.run(query)
|
||||
return ans
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = knowledge_search_more("机器人和大数据在代码教学上有什么区别")
|
||||
print(result)
|
||||
|
||||
# 这是一个正常的切割
|
||||
# queries = [
|
||||
# ("bigdata", "大数据专业的男女比例"),
|
||||
# ("robotic", "机器人专业的优势")
|
||||
# ]
|
||||
# result = search_knowledge(queries)
|
||||
# print(result)
|
||||
234
server/agent/tools/search_all_knowledge_once.py
Normal file
@ -0,0 +1,234 @@
|
||||
## 单独运行的时候需要添加
|
||||
# import sys
|
||||
# import os
|
||||
# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
|
||||
|
||||
import re
|
||||
import warnings
|
||||
from typing import Dict
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.pydantic_v1 import Extra, root_validator
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from typing import List, Any, Optional
|
||||
from langchain.prompts import PromptTemplate
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD
|
||||
|
||||
import asyncio
|
||||
from server.agent import model_container
|
||||
|
||||
|
||||
async def search_knowledge_base_iter(database: str, query: str):
|
||||
response = await knowledge_base_chat(query=query,
|
||||
knowledge_base_name=database,
|
||||
model_name=model_container.MODEL.model_name,
|
||||
temperature=0.01,
|
||||
history=[],
|
||||
top_k=VECTOR_SEARCH_TOP_K,
|
||||
max_tokens=None,
|
||||
prompt_name="knowledge_base_chat",
|
||||
score_threshold=SCORE_THRESHOLD,
|
||||
stream=False)
|
||||
|
||||
contents = ""
|
||||
async for data in response.body_iterator: # 这里的data是一个json字符串
|
||||
data = json.loads(data)
|
||||
contents += data["answer"]
|
||||
docs = data["docs"]
|
||||
return contents
|
||||
|
||||
|
||||
_PROMPT_TEMPLATE = """
|
||||
用户会提出一个需要你查询知识库的问题,你应该按照我提供的思想进行思考
|
||||
Question: ${{用户的问题}}
|
||||
这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能:
|
||||
|
||||
{database_names}
|
||||
|
||||
你的回答格式应该按照下面的内容,请注意,格式内的```text 等标记都必须输出,这是我用来提取答案的标记。
|
||||
```text
|
||||
${{知识库的名称}}
|
||||
```
|
||||
```output
|
||||
数据库查询的结果
|
||||
```
|
||||
答案: ${{答案}}
|
||||
|
||||
现在,这是我的问题:
|
||||
问题: {question}
|
||||
|
||||
"""
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["question", "database_names"],
|
||||
template=_PROMPT_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
class LLMKnowledgeChain(LLMChain):
|
||||
llm_chain: LLMChain
|
||||
llm: Optional[BaseLanguageModel] = None
|
||||
"""[Deprecated] LLM wrapper to use."""
|
||||
prompt: BasePromptTemplate = PROMPT
|
||||
"""[Deprecated] Prompt to use to translate to python if necessary."""
|
||||
database_names: Dict[str, str] = model_container.DATABASE
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
if "llm" in values:
|
||||
warnings.warn(
|
||||
"Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
|
||||
"Please instantiate with llm_chain argument or using the from_llm "
|
||||
"class method."
|
||||
)
|
||||
if "llm_chain" not in values and values["llm"] is not None:
|
||||
prompt = values.get("prompt", PROMPT)
|
||||
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
|
||||
return values
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Expect output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _evaluate_expression(self, dataset, query) -> str:
|
||||
try:
|
||||
output = asyncio.run(search_knowledge_base_iter(dataset, query))
|
||||
except Exception as e:
|
||||
output = "输入的信息有误或不存在知识库"
|
||||
return output
|
||||
return output
|
||||
|
||||
def _process_llm_result(
|
||||
self,
|
||||
llm_output: str,
|
||||
llm_input: str,
|
||||
run_manager: CallbackManagerForChainRun
|
||||
) -> Dict[str, str]:
|
||||
|
||||
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
||||
|
||||
llm_output = llm_output.strip()
|
||||
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
|
||||
if text_match:
|
||||
database = text_match.group(1).strip()
|
||||
output = self._evaluate_expression(database, llm_input)
|
||||
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
answer = "Answer: " + output
|
||||
elif llm_output.startswith("Answer:"):
|
||||
answer = llm_output
|
||||
elif "Answer:" in llm_output:
|
||||
answer = "Answer: " + llm_output.split("Answer:")[-1]
|
||||
else:
|
||||
return {self.output_key: f"输入的格式不对: {llm_output}"}
|
||||
return {self.output_key: answer}
|
||||
|
||||
async def _aprocess_llm_result(
|
||||
self,
|
||||
llm_output: str,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> Dict[str, str]:
|
||||
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
||||
llm_output = llm_output.strip()
|
||||
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
|
||||
if text_match:
|
||||
expression = text_match.group(1)
|
||||
output = self._evaluate_expression(expression)
|
||||
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
answer = "Answer: " + output
|
||||
elif llm_output.startswith("Answer:"):
|
||||
answer = llm_output
|
||||
elif "Answer:" in llm_output:
|
||||
answer = "Answer: " + llm_output.split("Answer:")[-1]
|
||||
else:
|
||||
raise ValueError(f"unknown format from LLM: {llm_output}")
|
||||
return {self.output_key: answer}
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
_run_manager.on_text(inputs[self.input_key])
|
||||
data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
|
||||
llm_output = self.llm_chain.predict(
|
||||
database_names=data_formatted_str,
|
||||
question=inputs[self.input_key],
|
||||
stop=["```output"],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
return self._process_llm_result(llm_output, inputs[self.input_key], _run_manager)
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
await _run_manager.on_text(inputs[self.input_key])
|
||||
data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
|
||||
llm_output = await self.llm_chain.apredict(
|
||||
database_names=data_formatted_str,
|
||||
question=inputs[self.input_key],
|
||||
stop=["```output"],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "llm_knowledge_chain"
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate = PROMPT,
|
||||
**kwargs: Any,
|
||||
):
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(llm_chain=llm_chain, **kwargs)
|
||||
|
||||
|
||||
def knowledge_search_once(query: str):
|
||||
model = model_container.MODEL
|
||||
llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT)
|
||||
ans = llm_knowledge.run(query)
|
||||
return ans
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = knowledge_search_once("大数据的男女比例")
|
||||
print(result)
|
||||
39
server/agent/tools/search_internet.py
Normal file
@ -0,0 +1,39 @@
|
||||
## 单独运行的时候需要添加
|
||||
# import sys
|
||||
# import os
|
||||
# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
|
||||
|
||||
import json
|
||||
from server.chat import search_engine_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K
|
||||
import asyncio
|
||||
from server.agent import model_container
|
||||
|
||||
async def search_engine_iter(query: str):
|
||||
response = await search_engine_chat(query=query,
|
||||
search_engine_name="bing", # 这里切换搜索引擎
|
||||
model_name=model_container.MODEL.model_name,
|
||||
temperature=0.01, # Agent 搜索互联网的时候,温度设置为0.01
|
||||
history=[],
|
||||
top_k = VECTOR_SEARCH_TOP_K,
|
||||
max_tokens= None, # Agent 搜索互联网的时候,max_tokens设置为None
|
||||
prompt_name = "default",
|
||||
stream=False)
|
||||
|
||||
contents = ""
|
||||
|
||||
async for data in response.body_iterator: # 这里的data是一个json字符串
|
||||
data = json.loads(data)
|
||||
contents = data["answer"]
|
||||
docs = data["docs"]
|
||||
|
||||
return contents
|
||||
|
||||
def search_internet(query: str):
|
||||
|
||||
return asyncio.run(search_engine_iter(query))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = search_internet("今天星期几")
|
||||
print("答案:",result)
|
||||
38
server/agent/tools/search_knowledge_simple.py
Normal file
@ -0,0 +1,38 @@
|
||||
## 最简单的版本,只支持固定的知识库
|
||||
|
||||
# ## 单独运行的时候需要添加
|
||||
# import sys
|
||||
# import os
|
||||
# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
|
||||
|
||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD
|
||||
import json
|
||||
import asyncio
|
||||
from server.agent import model_container
|
||||
|
||||
async def search_knowledge_base_iter(database: str, query: str) -> str:
|
||||
response = await knowledge_base_chat(query=query,
|
||||
knowledge_base_name=database,
|
||||
model_name=model_container.MODEL.model_name,
|
||||
temperature=0.01,
|
||||
history=[],
|
||||
top_k=VECTOR_SEARCH_TOP_K,
|
||||
prompt_name="knowledge_base_chat",
|
||||
score_threshold=SCORE_THRESHOLD,
|
||||
stream=False)
|
||||
|
||||
contents = ""
|
||||
async for data in response.body_iterator: # 这里的data是一个json字符串
|
||||
data = json.loads(data)
|
||||
contents = data["answer"]
|
||||
docs = data["docs"]
|
||||
return contents
|
||||
|
||||
def knowledge_search_simple(query: str):
|
||||
return asyncio.run(search_knowledge_base_iter(query))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = knowledge_search_simple("大数据男女比例")
|
||||
print("答案:",result)
|
||||
39
server/agent/tools/translator.py
Normal file
@ -0,0 +1,39 @@
|
||||
## 单独运行的时候需要添加
|
||||
# import sys
|
||||
# import os
|
||||
# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.chains import LLMChain
|
||||
from server.agent import model_container
|
||||
|
||||
_PROMPT_TEMPLATE = '''
|
||||
# 指令
|
||||
接下来,作为一个专业的翻译专家,当我给出句子或段落时,你将提供通顺且具有可读性的对应语言的翻译。注意:
|
||||
1. 确保翻译结果流畅且易于理解
|
||||
2. 无论提供的是陈述句或疑问句,只进行翻译
|
||||
3. 不添加与原文无关的内容
|
||||
|
||||
问题: ${{用户需要翻译的原文和目标语言}}
|
||||
答案: 你翻译结果
|
||||
|
||||
现在,这是我的问题:
|
||||
问题: {question}
|
||||
|
||||
'''
|
||||
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["question"],
|
||||
template=_PROMPT_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
def translate(query: str):
|
||||
model = model_container.MODEL
|
||||
llm_translate = LLMChain(llm=model, prompt=PROMPT)
|
||||
ans = llm_translate.run(query)
|
||||
return ans
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = translate("Can Love remember the question and the answer? 这句话如何诗意的翻译成中文")
|
||||
print("答案:",result)
|
||||
@ -1,14 +1,11 @@
|
||||
## 使用和风天气API查询天气
|
||||
## 使用和风天气API查询天气,这个模型仅仅对免费的API进行了适配
|
||||
## 这个模型的提示词非常复杂,我们推荐使用GPT4模型进行运行
|
||||
from __future__ import annotations
|
||||
|
||||
## 单独运行的时候需要添加
|
||||
import sys
|
||||
import os
|
||||
# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
|
||||
from server.utils import get_ChatOpenAI
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
import re
|
||||
import warnings
|
||||
@ -25,10 +22,72 @@ from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
import requests
|
||||
from typing import List, Any, Optional
|
||||
from configs.model_config import LLM_MODEL, TEMPERATURE
|
||||
from datetime import datetime
|
||||
from langchain.prompts import PromptTemplate
|
||||
from server.agent import model_container
|
||||
|
||||
## 使用和风天气API查询天气
|
||||
KEY = ""
|
||||
KEY = "ac880e5a877042809ac7ffdd19d95b0d"
|
||||
#key长这样,这里提供了示例的key,这个key没法使用,你需要自己去注册和风天气的账号,然后在这里填入你的key
|
||||
|
||||
|
||||
|
||||
_PROMPT_TEMPLATE = """
|
||||
用户会提出一个关于天气的问题,你的目标是拆分出用户问题中的区,市 并按照我提供的工具回答。
|
||||
例如 用户提出的问题是: 上海浦东未来1小时天气情况?
|
||||
则 提取的市和区是: 上海 浦东
|
||||
如果用户提出的问题是: 上海未来1小时天气情况?
|
||||
则 提取的市和区是: 上海 None
|
||||
请注意以下内容:
|
||||
1. 如果你没有找到区的内容,则一定要使用 None 替代,否则程序无法运行
|
||||
2. 如果用户没有指定市 则直接返回缺少信息
|
||||
|
||||
问题: ${{用户的问题}}
|
||||
|
||||
你的回答格式应该按照下面的内容,请注意,格式内的```text 等标记都必须输出,这是我用来提取答案的标记。
|
||||
```text
|
||||
|
||||
${{拆分的市和区,中间用空格隔开}}
|
||||
```
|
||||
... weathercheck(市 区)...
|
||||
```output
|
||||
|
||||
${{提取后的答案}}
|
||||
```
|
||||
答案: ${{答案}}
|
||||
|
||||
|
||||
|
||||
这是一个例子:
|
||||
问题: 上海浦东未来1小时天气情况?
|
||||
|
||||
|
||||
```text
|
||||
上海 浦东
|
||||
```
|
||||
...weathercheck(上海 浦东)...
|
||||
|
||||
```output
|
||||
预报时间: 1小时后
|
||||
具体时间: 今天 18:00
|
||||
温度: 24°C
|
||||
天气: 多云
|
||||
风向: 西南风
|
||||
风速: 7级
|
||||
湿度: 88%
|
||||
降水概率: 16%
|
||||
|
||||
Answer: 上海浦东一小时后的天气是多云。
|
||||
|
||||
现在,这是我的问题:
|
||||
|
||||
问题: {question}
|
||||
"""
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["question"],
|
||||
template=_PROMPT_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
def get_city_info(location, adm, key):
|
||||
base_url = 'https://geoapi.qweather.com/v2/city/lookup?'
|
||||
@ -38,12 +97,9 @@ def get_city_info(location, adm, key):
|
||||
return data
|
||||
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def format_weather_data(data):
|
||||
def format_weather_data(data,place):
|
||||
hourly_forecast = data['hourly']
|
||||
formatted_data = ''
|
||||
formatted_data = f"\n 这是查询到的关于{place}未来24小时的天气信息: \n"
|
||||
for forecast in hourly_forecast:
|
||||
# 将预报时间转换为datetime对象
|
||||
forecast_time = datetime.strptime(forecast['fxTime'], '%Y-%m-%dT%H:%M%z')
|
||||
@ -71,12 +127,11 @@ def format_weather_data(data):
|
||||
elif hours_diff >= 24:
|
||||
# 如果超过24小时,转换为天数
|
||||
days_diff = hours_diff // 24
|
||||
hours_diff_str = str(int(days_diff)) + '天后'
|
||||
hours_diff_str = str(int(days_diff)) + '天'
|
||||
else:
|
||||
hours_diff_str = str(int(hours_diff)) + '小时后'
|
||||
hours_diff_str = str(int(hours_diff)) + '小时'
|
||||
# 将预报时间和当前时间的差值添加到输出中
|
||||
formatted_data += '预报时间: ' + hours_diff_str + '\n'
|
||||
formatted_data += '具体时间: ' + forecast_time_str + '\n'
|
||||
formatted_data += '预报时间: ' + forecast_time_str + ' 距离现在有: ' + hours_diff_str + '\n'
|
||||
formatted_data += '温度: ' + forecast['temp'] + '°C\n'
|
||||
formatted_data += '天气: ' + forecast['text'] + '\n'
|
||||
formatted_data += '风向: ' + forecast['windDir'] + '\n'
|
||||
@ -84,53 +139,54 @@ def format_weather_data(data):
|
||||
formatted_data += '湿度: ' + forecast['humidity'] + '%\n'
|
||||
formatted_data += '降水概率: ' + forecast['pop'] + '%\n'
|
||||
# formatted_data += '降水量: ' + forecast['precip'] + 'mm\n'
|
||||
formatted_data += '\n\n'
|
||||
formatted_data += '\n'
|
||||
return formatted_data
|
||||
|
||||
|
||||
def get_weather(key, location_id, time: str = "24"):
|
||||
if time:
|
||||
url = "https://devapi.qweather.com/v7/weather/" + time + "h?"
|
||||
else:
|
||||
time = "3" # 免费订阅只能查看3天的天气
|
||||
url = "https://devapi.qweather.com/v7/weather/" + time + "d?"
|
||||
def get_weather(key, location_id,place):
|
||||
url = "https://devapi.qweather.com/v7/weather/24h?"
|
||||
params = {
|
||||
'location': location_id,
|
||||
'key': key,
|
||||
}
|
||||
response = requests.get(url, params=params)
|
||||
data = response.json()
|
||||
return format_weather_data(data)
|
||||
return format_weather_data(data,place)
|
||||
|
||||
|
||||
def split_query(query):
|
||||
parts = query.split()
|
||||
location = parts[0] if parts[0] != 'None' else parts[1]
|
||||
adm = parts[1]
|
||||
time = parts[2]
|
||||
return location, adm, time
|
||||
adm = parts[0]
|
||||
location = parts[1] if parts[1] != 'None' else adm
|
||||
return location, adm
|
||||
|
||||
|
||||
def weather(query):
|
||||
location, adm, time = split_query(query)
|
||||
location, adm= split_query(query)
|
||||
key = KEY
|
||||
if time != "None" and int(time) > 24:
|
||||
return "只能查看24小时内的天气,无法回答"
|
||||
if time == "None":
|
||||
time = "24" # 免费的版本只能24小时内的天气
|
||||
if key == "":
|
||||
return "请先在代码中填入和风天气API Key"
|
||||
city_info = get_city_info(location=location, adm=adm, key=key)
|
||||
location_id = city_info['location'][0]['id']
|
||||
weather_data = get_weather(key=key, location_id=location_id, time=time)
|
||||
return weather_data
|
||||
|
||||
try:
|
||||
city_info = get_city_info(location=location, adm=adm, key=key)
|
||||
location_id = city_info['location'][0]['id']
|
||||
place = adm + "市" + location + "区"
|
||||
|
||||
weather_data = get_weather(key=key, location_id=location_id,place=place)
|
||||
return weather_data + "以上是查询到的天气信息,请你查收\n"
|
||||
except KeyError:
|
||||
try:
|
||||
city_info = get_city_info(location=adm, adm=adm, key=key)
|
||||
location_id = city_info['location'][0]['id']
|
||||
place = adm + "市"
|
||||
weather_data = get_weather(key=key, location_id=location_id,place=place)
|
||||
return weather_data + "重要提醒:用户提供的市和区中,区的信息不存在,或者出现错别字,因此该信息是关于市的天气,请你查收\n"
|
||||
except KeyError:
|
||||
return "输入的地区不存在,无法提供天气预报"
|
||||
class LLMWeatherChain(Chain):
|
||||
llm_chain: LLMChain
|
||||
llm: Optional[BaseLanguageModel] = None
|
||||
"""[Deprecated] LLM wrapper to use."""
|
||||
prompt: BasePromptTemplate
|
||||
prompt: BasePromptTemplate = PROMPT
|
||||
"""[Deprecated] Prompt to use to translate to python if necessary."""
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
@ -175,8 +231,6 @@ class LLMWeatherChain(Chain):
|
||||
output = weather(expression)
|
||||
except Exception as e:
|
||||
output = "输入的信息有误,请再次尝试"
|
||||
# raise ValueError(f"错误: {expression},输入的信息不对")
|
||||
|
||||
return output
|
||||
|
||||
def _process_llm_result(
|
||||
@ -198,7 +252,7 @@ class LLMWeatherChain(Chain):
|
||||
elif "Answer:" in llm_output:
|
||||
answer = "Answer: " + llm_output.split("Answer:")[-1]
|
||||
else:
|
||||
raise ValueError(f"unknown format from LLM: {llm_output}")
|
||||
return {self.output_key: f"输入的格式不对: {llm_output},应该输入 (市 区)的组合"}
|
||||
return {self.output_key: answer}
|
||||
|
||||
async def _aprocess_llm_result(
|
||||
@ -209,6 +263,7 @@ class LLMWeatherChain(Chain):
|
||||
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
||||
llm_output = llm_output.strip()
|
||||
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
|
||||
|
||||
if text_match:
|
||||
expression = text_match.group(1)
|
||||
output = self._evaluate_expression(expression)
|
||||
@ -259,107 +314,19 @@ class LLMWeatherChain(Chain):
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate,
|
||||
prompt: BasePromptTemplate = PROMPT,
|
||||
**kwargs: Any,
|
||||
) -> LLMWeatherChain:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(llm_chain=llm_chain, **kwargs)
|
||||
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
_PROMPT_TEMPLATE = """用户将会向您咨询天气问题,您不需要自己回答天气问题,而是将用户提问的信息提取出来区,市和时间三个元素后使用我为你编写好的工具进行查询并返回结果,格式为 区+市+时间 每个元素用空格隔开。如果缺少信息,则用 None 代替。
|
||||
问题: ${{用户的问题}}
|
||||
|
||||
```text
|
||||
|
||||
${{拆分的区,市和时间}}
|
||||
```
|
||||
|
||||
... weather(提取后的关键字,用空格隔开)...
|
||||
```output
|
||||
|
||||
${{提取后的答案}}
|
||||
```
|
||||
答案: ${{答案}}
|
||||
|
||||
这是两个例子:
|
||||
问题: 上海浦东未来1小时天气情况?
|
||||
|
||||
```text
|
||||
浦东 上海 1
|
||||
```
|
||||
...weather(浦东 上海 1)...
|
||||
|
||||
```output
|
||||
|
||||
预报时间: 1小时后
|
||||
具体时间: 今天 18:00
|
||||
温度: 24°C
|
||||
天气: 多云
|
||||
风向: 西南风
|
||||
风速: 7级
|
||||
湿度: 88%
|
||||
降水概率: 16%
|
||||
|
||||
Answer:
|
||||
预报时间: 1小时后
|
||||
具体时间: 今天 18:00
|
||||
温度: 24°C
|
||||
天气: 多云
|
||||
风向: 西南风
|
||||
风速: 7级
|
||||
湿度: 88%
|
||||
降水概率: 16%
|
||||
|
||||
问题: 北京市朝阳区未来24小时天气如何?
|
||||
```text
|
||||
|
||||
朝阳 北京 24
|
||||
```
|
||||
...weather(朝阳 北京 24)...
|
||||
```output
|
||||
预报时间: 23小时后
|
||||
具体时间: 明天 17:00
|
||||
温度: 26°C
|
||||
天气: 霾
|
||||
风向: 西南风
|
||||
风速: 11级
|
||||
湿度: 65%
|
||||
降水概率: 20%
|
||||
Answer:
|
||||
预报时间: 23小时后
|
||||
具体时间: 明天 17:00
|
||||
温度: 26°C
|
||||
天气: 霾
|
||||
风向: 西南风
|
||||
风速: 11级
|
||||
湿度: 65%
|
||||
降水概率: 20%
|
||||
|
||||
现在,这是我的问题:
|
||||
问题: {question}
|
||||
"""
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["question"],
|
||||
template=_PROMPT_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
def weathercheck(query: str):
|
||||
model = get_ChatOpenAI(
|
||||
streaming=False,
|
||||
model_name=LLM_MODEL,
|
||||
temperature=TEMPERATURE,
|
||||
)
|
||||
model = model_container.MODEL
|
||||
llm_weather = LLMWeatherChain.from_llm(model, verbose=True, prompt=PROMPT)
|
||||
ans = llm_weather.run(query)
|
||||
return ans
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
## 检测api是否能正确返回
|
||||
query = "上海浦东未来1小时天气情况"
|
||||
# ans = weathercheck(query)
|
||||
ans = weather("浦东 上海 1")
|
||||
print(ans)
|
||||
result = weathercheck("苏州姑苏区今晚热不热?")
|
||||
36
server/agent/tools_select.py
Normal file
@ -0,0 +1,36 @@
|
||||
from langchain.tools import Tool
|
||||
from server.agent.tools import *
|
||||
tools = [
|
||||
Tool.from_function(
|
||||
func=calculate,
|
||||
name="计算器工具",
|
||||
description="进行简单的数学运算"
|
||||
),
|
||||
Tool.from_function(
|
||||
func=translate,
|
||||
name="翻译工具",
|
||||
description="如果你无法访问互联网,并且需要翻译各种语言,应该使用这个工具"
|
||||
),
|
||||
Tool.from_function(
|
||||
func=weathercheck,
|
||||
name="天气查询工具",
|
||||
description="无需访问互联网,使用这个工具查询中国各地未来24小时的天气",
|
||||
),
|
||||
Tool.from_function(
|
||||
func=shell,
|
||||
name="shell工具",
|
||||
description="使用命令行工具输出",
|
||||
),
|
||||
Tool.from_function(
|
||||
func=knowledge_search_more,
|
||||
name="知识库查询工具",
|
||||
description="优先访问知识库来获取答案",
|
||||
),
|
||||
Tool.from_function(
|
||||
func=search_internet,
|
||||
name="互联网查询工具",
|
||||
description="如果你无法访问互联网,这个工具可以帮助你访问Bing互联网来解答问题",
|
||||
),
|
||||
]
|
||||
|
||||
tool_names = [tool.name for tool in tools]
|
||||
@ -1,55 +0,0 @@
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.chains import LLMChain
|
||||
import sys
|
||||
import os
|
||||
|
||||
from server.utils import get_ChatOpenAI
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
from langchain.chains.llm_math.prompt import PROMPT
|
||||
from configs.model_config import LLM_MODEL,TEMPERATURE
|
||||
|
||||
_PROMPT_TEMPLATE = '''
|
||||
# 指令
|
||||
接下来,作为一个专业的翻译专家,当我给出句子或段落时,你将提供通顺且具有可读性的对应语言的翻译。注意:
|
||||
1. 确保翻译结果流畅且易于理解
|
||||
2. 无论提供的是陈述句或疑问句,只进行翻译
|
||||
3. 不添加与原文无关的内容
|
||||
|
||||
原文: ${{用户需要翻译的原文和目标语言}}
|
||||
{question}
|
||||
```output
|
||||
${{翻译结果}}
|
||||
```
|
||||
答案: ${{答案}}
|
||||
|
||||
以下是两个例子
|
||||
问题: 翻译13成英语
|
||||
```text
|
||||
13 英语
|
||||
```output
|
||||
thirteen
|
||||
以下是两个例子
|
||||
问题: 翻译 我爱你 成法语
|
||||
```text
|
||||
13 法语
|
||||
```output
|
||||
Je t'aime.
|
||||
'''
|
||||
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["question"],
|
||||
template=_PROMPT_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
def translate(query: str):
|
||||
model = get_ChatOpenAI(
|
||||
streaming=False,
|
||||
model_name=LLM_MODEL,
|
||||
temperature=TEMPERATURE,
|
||||
)
|
||||
llm_translate = LLMChain(llm=model, prompt=PROMPT)
|
||||
ans = llm_translate.run(query)
|
||||
|
||||
return ans
|
||||
@ -16,9 +16,11 @@ from server.chat import (chat, knowledge_base_chat, openai_chat,
|
||||
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
||||
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
|
||||
update_docs, download_doc, recreate_vector_store,
|
||||
search_docs, DocumentWithScore)
|
||||
from server.llm_api import list_running_models, list_config_models, change_llm_model, stop_llm_model
|
||||
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
|
||||
search_docs, DocumentWithScore, update_info)
|
||||
from server.llm_api import (list_running_models, list_config_models,
|
||||
change_llm_model, stop_llm_model,
|
||||
get_model_config, list_search_engines)
|
||||
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, get_server_configs
|
||||
from typing import List
|
||||
|
||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||
@ -113,6 +115,11 @@ def create_app():
|
||||
summary="删除知识库内指定文件"
|
||||
)(delete_docs)
|
||||
|
||||
app.post("/knowledge_base/update_info",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=BaseResponse,
|
||||
summary="更新知识库介绍"
|
||||
)(update_info)
|
||||
app.post("/knowledge_base/update_docs",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=BaseResponse,
|
||||
@ -139,6 +146,11 @@ def create_app():
|
||||
summary="列出configs已配置的模型",
|
||||
)(list_config_models)
|
||||
|
||||
app.post("/llm_model/get_model_config",
|
||||
tags=["LLM Model Management"],
|
||||
summary="获取模型配置(合并后)",
|
||||
)(get_model_config)
|
||||
|
||||
app.post("/llm_model/stop",
|
||||
tags=["LLM Model Management"],
|
||||
summary="停止指定的LLM模型(Model Worker)",
|
||||
@ -149,6 +161,17 @@ def create_app():
|
||||
summary="切换指定的LLM模型(Model Worker)",
|
||||
)(change_llm_model)
|
||||
|
||||
# 服务器相关接口
|
||||
app.post("/server/configs",
|
||||
tags=["Server State"],
|
||||
summary="获取服务器原始配置信息",
|
||||
)(get_server_configs)
|
||||
|
||||
app.post("/server/list_search_engines",
|
||||
tags=["Server State"],
|
||||
summary="获取服务器支持的搜索引擎",
|
||||
)(list_search_engines)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@ -1,32 +1,33 @@
|
||||
from langchain.memory import ConversationBufferWindowMemory
|
||||
from server.agent.tools import tools, tool_names
|
||||
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status, dumps
|
||||
from server.agent.tools_select import tools, tool_names
|
||||
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
|
||||
from langchain.agents import AgentExecutor, LLMSingleActionAgent
|
||||
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
|
||||
from fastapi import Body
|
||||
from fastapi.responses import StreamingResponse
|
||||
from configs.model_config import LLM_MODEL, TEMPERATURE, HISTORY_LEN
|
||||
from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN
|
||||
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
|
||||
from langchain.chains import LLMChain
|
||||
from typing import AsyncIterable, Optional
|
||||
from typing import AsyncIterable, Optional, Dict
|
||||
import asyncio
|
||||
from typing import List
|
||||
from server.chat.utils import History
|
||||
import json
|
||||
|
||||
from server.agent import model_container
|
||||
from server.knowledge_base.kb_service.base import get_kb_details
|
||||
|
||||
async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
|
||||
history: List[History] = Body([],
|
||||
description="历史对话",
|
||||
examples=[[
|
||||
{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
|
||||
{"role": "assistant", "content": "虎头虎脑"}]]
|
||||
{"role": "user", "content": "请使用知识库工具查询今天北京天气"},
|
||||
{"role": "assistant", "content": "使用天气查询工具查询到今天北京多云,10-14摄氏度,东北风2级,易感冒"}]]
|
||||
),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
prompt_name: str = Body("agent_chat",
|
||||
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
|
||||
):
|
||||
history = [History.from_data(h) for h in history]
|
||||
@ -41,25 +42,31 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
|
||||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[callback],
|
||||
)
|
||||
|
||||
prompt_template = CustomPromptTemplate(
|
||||
template=get_prompt_template(prompt_name),
|
||||
## 传入全局变量来实现agent调用
|
||||
kb_list = {x["kb_name"]: x for x in get_kb_details()}
|
||||
model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}
|
||||
model_container.MODEL = model
|
||||
|
||||
prompt_template = get_prompt_template("agent_chat", prompt_name)
|
||||
prompt_template_agent = CustomPromptTemplate(
|
||||
template=prompt_template,
|
||||
tools=tools,
|
||||
input_variables=["input", "intermediate_steps", "history"]
|
||||
)
|
||||
output_parser = CustomOutputParser()
|
||||
llm_chain = LLMChain(llm=model, prompt=prompt_template)
|
||||
llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)
|
||||
agent = LLMSingleActionAgent(
|
||||
llm_chain=llm_chain,
|
||||
output_parser=output_parser,
|
||||
stop=["Observation:", "Observation:\n", "<|im_end|>"], # Qwen模型中使用这个
|
||||
# stop=["Observation:", "Observation:\n"], # 其他模型,注意模板
|
||||
stop=["\nObservation:", "Observation:", "<|im_end|>"], # Qwen模型中使用这个
|
||||
allowed_tools=tool_names,
|
||||
)
|
||||
# 把history转成agent的memory
|
||||
memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
|
||||
|
||||
for message in history:
|
||||
# 检查消息的角色
|
||||
if message.role == 'user':
|
||||
@ -73,50 +80,71 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
|
||||
verbose=True,
|
||||
memory=memory,
|
||||
)
|
||||
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
|
||||
task = asyncio.create_task(wrap_done(
|
||||
agent_executor.acall(query, callbacks=[callback], include_run_info=True),
|
||||
callback.done),
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
task = asyncio.create_task(wrap_done(
|
||||
agent_executor.acall(query, callbacks=[callback], include_run_info=True),
|
||||
callback.done))
|
||||
break
|
||||
except:
|
||||
pass
|
||||
|
||||
if stream:
|
||||
async for chunk in callback.aiter():
|
||||
tools_use = []
|
||||
# Use server-sent-events to stream the response
|
||||
data = json.loads(chunk)
|
||||
if data["status"] == Status.error:
|
||||
tools_use.append("工具调用失败:\n" + data["error"])
|
||||
yield json.dumps({"tools": tools_use}, ensure_ascii=False)
|
||||
yield json.dumps({"answer": "(工具调用失败,请查看工具栏报错) \n\n"}, ensure_ascii=False)
|
||||
if data["status"] == Status.start or data["status"] == Status.complete:
|
||||
continue
|
||||
if data["status"] == Status.agent_action:
|
||||
yield json.dumps({"answer": "(正在使用工具,请注意工具栏变化) \n\n"}, ensure_ascii=False)
|
||||
if data["status"] == Status.agent_finish:
|
||||
elif data["status"] == Status.error:
|
||||
tools_use.append("\n```\n")
|
||||
tools_use.append("工具名称: " + data["tool_name"])
|
||||
tools_use.append("工具状态: " + "调用失败")
|
||||
tools_use.append("错误信息: " + data["error"])
|
||||
tools_use.append("重新开始尝试")
|
||||
tools_use.append("\n```\n")
|
||||
yield json.dumps({"tools": tools_use}, ensure_ascii=False)
|
||||
elif data["status"] == Status.tool_finish:
|
||||
tools_use.append("\n```\n")
|
||||
tools_use.append("工具名称: " + data["tool_name"])
|
||||
tools_use.append("工具状态: " + "调用成功")
|
||||
tools_use.append("工具输入: " + data["input_str"])
|
||||
tools_use.append("工具输出: " + data["output_str"])
|
||||
tools_use.append("\n```\n")
|
||||
yield json.dumps({"tools": tools_use}, ensure_ascii=False)
|
||||
yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)
|
||||
elif data["status"] == Status.agent_finish:
|
||||
yield json.dumps({"final_answer": data["final_answer"]}, ensure_ascii=False)
|
||||
else:
|
||||
yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)
|
||||
|
||||
|
||||
else:
|
||||
pass
|
||||
# agent必须要steram=True,这部分暂时没有完成
|
||||
# result = []
|
||||
# async for chunk in callback.aiter():
|
||||
# data = json.loads(chunk)
|
||||
# status = data["status"]
|
||||
# if status == Status.start:
|
||||
# result.append(chunk)
|
||||
# elif status == Status.running:
|
||||
# result[-1]["llm_token"] += chunk["llm_token"]
|
||||
# elif status == Status.complete:
|
||||
# result[-1]["status"] = Status.complete
|
||||
# elif status == Status.agent_finish:
|
||||
# result.append(chunk)
|
||||
# elif status == Status.agent_finish:
|
||||
# pass
|
||||
# yield dumps(result)
|
||||
answer = ""
|
||||
final_answer = ""
|
||||
async for chunk in callback.aiter():
|
||||
# Use server-sent-events to stream the response
|
||||
data = json.loads(chunk)
|
||||
if data["status"] == Status.start or data["status"] == Status.complete:
|
||||
continue
|
||||
if data["status"] == Status.error:
|
||||
answer += "\n```\n"
|
||||
answer += "工具名称: " + data["tool_name"] + "\n"
|
||||
answer += "工具状态: " + "调用失败" + "\n"
|
||||
answer += "错误信息: " + data["error"] + "\n"
|
||||
answer += "\n```\n"
|
||||
if data["status"] == Status.tool_finish:
|
||||
answer += "\n```\n"
|
||||
answer += "工具名称: " + data["tool_name"] + "\n"
|
||||
answer += "工具状态: " + "调用成功" + "\n"
|
||||
answer += "工具输入: " + data["input_str"] + "\n"
|
||||
answer += "工具输出: " + data["output_str"] + "\n"
|
||||
answer += "\n```\n"
|
||||
if data["status"] == Status.agent_finish:
|
||||
final_answer = data["final_answer"]
|
||||
else:
|
||||
answer += data["llm_token"]
|
||||
|
||||
yield json.dumps({"answer": answer, "final_answer": final_answer}, ensure_ascii=False)
|
||||
await task
|
||||
|
||||
return StreamingResponse(agent_chat_iterator(query=query,
|
||||
|
||||
@ -22,8 +22,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
|
||||
prompt_name: str = Body("llm_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
):
|
||||
history = [History.from_data(h) for h in history]
|
||||
|
||||
@ -36,10 +37,11 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
|
||||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[callback],
|
||||
)
|
||||
|
||||
prompt_template = get_prompt_template(prompt_name)
|
||||
prompt_template = get_prompt_template("llm_chat", prompt_name)
|
||||
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_template() for i in history] + [input_msg])
|
||||
|
||||
@ -31,9 +31,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
request: Request = None,
|
||||
max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
):
|
||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
if kb is None:
|
||||
@ -51,12 +50,13 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
||||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[callback],
|
||||
)
|
||||
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
|
||||
context = "\n".join([doc.page_content for doc in docs])
|
||||
|
||||
prompt_template = get_prompt_template(prompt_name)
|
||||
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
|
||||
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_template() for i in history] + [input_msg])
|
||||
@ -72,14 +72,10 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
||||
source_documents = []
|
||||
for inum, doc in enumerate(docs):
|
||||
filename = os.path.split(doc.metadata["source"])[-1]
|
||||
if local_doc_url:
|
||||
url = "file://" + doc.metadata["source"]
|
||||
else:
|
||||
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
|
||||
url = f"{request.base_url}knowledge_base/download_doc?" + parameters
|
||||
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
|
||||
url = f"/knowledge_base/download_doc?" + parameters
|
||||
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
|
||||
source_documents.append(text)
|
||||
|
||||
if stream:
|
||||
async for token in callback.aiter():
|
||||
# Use server-sent-events to stream the response
|
||||
|
||||
@ -16,7 +16,7 @@ class OpenAiChatMsgIn(BaseModel):
|
||||
messages: List[OpenAiMessage]
|
||||
temperature: float = 0.7
|
||||
n: int = 1
|
||||
max_tokens: int = 1024
|
||||
max_tokens: int = None
|
||||
stop: List[str] = []
|
||||
stream: bool = False
|
||||
presence_penalty: int = 0
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
|
||||
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
|
||||
LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE)
|
||||
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
|
||||
LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE,
|
||||
TEXT_SPLITTER_NAME, OVERLAP_SIZE)
|
||||
from fastapi import Body
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
@ -11,7 +12,7 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from typing import AsyncIterable
|
||||
import asyncio
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Dict
|
||||
from server.chat.utils import History
|
||||
from langchain.docstore.document import Document
|
||||
import json
|
||||
@ -32,8 +33,49 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
|
||||
return search.results(text, result_len)
|
||||
|
||||
|
||||
def metaphor_search(
|
||||
text: str,
|
||||
result_len: int = SEARCH_ENGINE_TOP_K,
|
||||
splitter_name: str = "SpacyTextSplitter",
|
||||
chunk_size: int = 500,
|
||||
chunk_overlap: int = OVERLAP_SIZE,
|
||||
) -> List[Dict]:
|
||||
from metaphor_python import Metaphor
|
||||
from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
|
||||
from server.knowledge_base.utils import make_text_splitter
|
||||
|
||||
if not METAPHOR_API_KEY:
|
||||
return []
|
||||
|
||||
client = Metaphor(METAPHOR_API_KEY)
|
||||
search = client.search(text, num_results=result_len, use_autoprompt=True)
|
||||
contents = search.get_contents().contents
|
||||
|
||||
# metaphor 返回的内容都是长文本,需要分词再检索
|
||||
docs = [Document(page_content=x.extract,
|
||||
metadata={"link": x.url, "title": x.title})
|
||||
for x in contents]
|
||||
text_splitter = make_text_splitter(splitter_name=splitter_name,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap)
|
||||
splitted_docs = text_splitter.split_documents(docs)
|
||||
|
||||
# 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档
|
||||
if len(splitted_docs) > result_len:
|
||||
vs = memo_faiss_pool.new_vector_store()
|
||||
vs.add_documents(splitted_docs)
|
||||
splitted_docs = vs.similarity_search(text, k=result_len, score_threshold=1.0)
|
||||
|
||||
docs = [{"snippet": x.page_content,
|
||||
"link": x.metadata["link"],
|
||||
"title": x.metadata["title"]}
|
||||
for x in splitted_docs]
|
||||
return docs
|
||||
|
||||
|
||||
SEARCH_ENGINES = {"bing": bing_search,
|
||||
"duckduckgo": duckduckgo_search,
|
||||
"metaphor": metaphor_search,
|
||||
}
|
||||
|
||||
|
||||
@ -72,7 +114,8 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
):
|
||||
if search_engine_name not in SEARCH_ENGINES.keys():
|
||||
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
|
||||
@ -93,13 +136,14 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
|
||||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[callback],
|
||||
)
|
||||
|
||||
docs = await lookup_search_engine(query, search_engine_name, top_k)
|
||||
context = "\n".join([doc.page_content for doc in docs])
|
||||
|
||||
prompt_template = get_prompt_template(prompt_name)
|
||||
prompt_template = get_prompt_template("search_engine_chat", prompt_name)
|
||||
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_template() for i in history] + [input_msg])
|
||||
|
||||
@ -10,10 +10,11 @@ class KnowledgeBaseModel(Base):
|
||||
__tablename__ = 'knowledge_base'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True, comment='知识库ID')
|
||||
kb_name = Column(String(50), comment='知识库名称')
|
||||
kb_info = Column(String(200), comment='知识库简介(用于Agent)')
|
||||
vs_type = Column(String(50), comment='向量库类型')
|
||||
embed_model = Column(String(50), comment='嵌入模型名称')
|
||||
file_count = Column(Integer, default=0, comment='文件数量')
|
||||
create_time = Column(DateTime, default=func.now(), comment='创建时间')
|
||||
|
||||
def __repr__(self):
|
||||
return f"<KnowledgeBase(id='{self.id}', kb_name='{self.kb_name}', vs_type='{self.vs_type}', embed_model='{self.embed_model}', file_count='{self.file_count}', create_time='{self.create_time}')>"
|
||||
return f"<KnowledgeBase(id='{self.id}', kb_name='{self.kb_name}',kb_intro='{self.kb_info} vs_type='{self.vs_type}', embed_model='{self.embed_model}', file_count='{self.file_count}', create_time='{self.create_time}')>"
|
||||
|
||||
@ -37,4 +37,4 @@ class FileDocModel(Base):
|
||||
meta_data = Column(JSON, default={})
|
||||
|
||||
def __repr__(self):
|
||||
return f"<FileDoc(id='{self.id}', kb_name='{self.kb_name}', file_name='{self.file_name}', doc_id='{self.doc_id}', metadata='{self.metadata}')>"
|
||||
return f"<FileDoc(id='{self.id}', kb_name='{self.kb_name}', file_name='{self.file_name}', doc_id='{self.doc_id}', metadata='{self.meta_data}')>"
|
||||
|
||||
@ -3,13 +3,14 @@ from server.db.session import with_session
|
||||
|
||||
|
||||
@with_session
|
||||
def add_kb_to_db(session, kb_name, vs_type, embed_model):
|
||||
def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model):
|
||||
# 创建知识库实例
|
||||
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
|
||||
if not kb:
|
||||
kb = KnowledgeBaseModel(kb_name=kb_name, vs_type=vs_type, embed_model=embed_model)
|
||||
kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model)
|
||||
session.add(kb)
|
||||
else: # update kb with new vs_type and embed_model
|
||||
else: # update kb with new vs_type and embed_model
|
||||
kb.kb_info = kb_info
|
||||
kb.vs_type = vs_type
|
||||
kb.embed_model = embed_model
|
||||
return True
|
||||
@ -53,6 +54,7 @@ def get_kb_detail(session, kb_name: str) -> dict:
|
||||
if kb:
|
||||
return {
|
||||
"kb_name": kb.kb_name,
|
||||
"kb_info": kb.kb_info,
|
||||
"vs_type": kb.vs_type,
|
||||
"embed_model": kb.embed_model,
|
||||
"file_count": kb.file_count,
|
||||
|
||||
@ -140,7 +140,7 @@ if __name__ == "__main__":
|
||||
ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
|
||||
pprint(ids)
|
||||
elif r == 2: # search docs
|
||||
docs = vs.similarity_search_with_score(f"{name}", top_k=3, score_threshold=1.0)
|
||||
docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
|
||||
pprint(docs)
|
||||
if r == 3: # delete docs
|
||||
logger.warning(f"清除 {vs_name} by {name}")
|
||||
|
||||
@ -203,6 +203,20 @@ def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files})
|
||||
|
||||
|
||||
def update_info(knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
kb_info:str = Body(..., description="知识库介绍", examples=["这是一个知识库"]),
|
||||
):
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
kb.update_info(kb_info)
|
||||
|
||||
return BaseResponse(code=200, msg=f"知识库介绍修改完成", data={"kb_info": kb_info})
|
||||
|
||||
|
||||
def update_docs(
|
||||
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]),
|
||||
@ -337,8 +351,9 @@ def recreate_vector_store(
|
||||
if not kb.exists() and not allow_empty_kb:
|
||||
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
|
||||
else:
|
||||
if kb.exists():
|
||||
kb.clear_vs()
|
||||
kb.create_kb()
|
||||
kb.clear_vs()
|
||||
files = list_files_from_folder(knowledge_base_name)
|
||||
kb_files = [(file, knowledge_base_name) for file in files]
|
||||
i = 0
|
||||
|
||||
@ -19,7 +19,7 @@ from server.db.repository.knowledge_file_repository import (
|
||||
)
|
||||
|
||||
from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||
EMBEDDING_MODEL)
|
||||
EMBEDDING_MODEL, KB_INFO)
|
||||
from server.knowledge_base.utils import (
|
||||
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
|
||||
list_kbs_from_folder, list_files_from_folder,
|
||||
@ -42,11 +42,11 @@ class KBService(ABC):
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
):
|
||||
self.kb_name = knowledge_base_name
|
||||
self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
|
||||
self.embed_model = embed_model
|
||||
self.kb_path = get_kb_path(self.kb_name)
|
||||
self.doc_path = get_doc_path(self.kb_name)
|
||||
self.do_init()
|
||||
|
||||
def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
|
||||
return load_embeddings(self.embed_model, embed_device)
|
||||
|
||||
@ -63,7 +63,7 @@ class KBService(ABC):
|
||||
if not os.path.exists(self.doc_path):
|
||||
os.makedirs(self.doc_path)
|
||||
self.do_create_kb()
|
||||
status = add_kb_to_db(self.kb_name, self.vs_type(), self.embed_model)
|
||||
status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
|
||||
return status
|
||||
|
||||
def clear_vs(self):
|
||||
@ -116,6 +116,14 @@ class KBService(ABC):
|
||||
os.remove(kb_file.filepath)
|
||||
return status
|
||||
|
||||
def update_info(self, kb_info: str):
|
||||
"""
|
||||
更新知识库介绍
|
||||
"""
|
||||
self.kb_info = kb_info
|
||||
status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
|
||||
return status
|
||||
|
||||
def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
|
||||
"""
|
||||
使用content中的文件更新向量库
|
||||
@ -127,7 +135,7 @@ class KBService(ABC):
|
||||
|
||||
def exist_doc(self, file_name: str):
|
||||
return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name,
|
||||
filename=file_name))
|
||||
filename=file_name))
|
||||
|
||||
def list_files(self):
|
||||
return list_files_from_db(self.kb_name)
|
||||
@ -271,6 +279,7 @@ def get_kb_details() -> List[Dict]:
|
||||
result[kb] = {
|
||||
"kb_name": kb,
|
||||
"vs_type": "",
|
||||
"kb_info": "",
|
||||
"embed_model": "",
|
||||
"file_count": 0,
|
||||
"create_time": None,
|
||||
|
||||
@ -89,7 +89,7 @@ class FaissKBService(KBService):
|
||||
|
||||
def do_clear_vs(self):
|
||||
with kb_faiss_pool.atomic:
|
||||
kb_faiss_pool.pop(self.kb_name)
|
||||
kb_faiss_pool.pop((self.kb_name, self.vector_name))
|
||||
shutil.rmtree(self.vs_path)
|
||||
os.makedirs(self.vs_path)
|
||||
|
||||
|
||||
98
server/knowledge_base/kb_service/zilliz_kb_service.py
Normal file
@ -0,0 +1,98 @@
|
||||
from typing import List, Dict, Optional
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores import Zilliz
|
||||
from configs import kbs_config
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
|
||||
score_threshold_process
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
|
||||
|
||||
class ZillizKBService(KBService):
|
||||
zilliz: Zilliz
|
||||
|
||||
@staticmethod
|
||||
def get_collection(zilliz_name):
|
||||
from pymilvus import Collection
|
||||
return Collection(zilliz_name)
|
||||
|
||||
# def save_vector_store(self):
|
||||
# if self.zilliz.col:
|
||||
# self.zilliz.col.flush()
|
||||
|
||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||
if self.zilliz.col:
|
||||
data_list = self.zilliz.col.query(expr=f'pk == {id}', output_fields=["*"])
|
||||
if len(data_list) > 0:
|
||||
data = data_list[0]
|
||||
text = data.pop("text")
|
||||
return Document(page_content=text, metadata=data)
|
||||
|
||||
@staticmethod
|
||||
def search(zilliz_name, content, limit=3):
|
||||
search_params = {
|
||||
"metric_type": "IP",
|
||||
"params": {},
|
||||
}
|
||||
c = ZillizKBService.get_collection(zilliz_name)
|
||||
return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"])
|
||||
|
||||
def do_create_kb(self):
|
||||
pass
|
||||
|
||||
def vs_type(self) -> str:
|
||||
return SupportedVSType.ZILLIZ
|
||||
|
||||
def _load_zilliz(self, embeddings: Embeddings = None):
|
||||
if embeddings is None:
|
||||
embeddings = self._load_embeddings()
|
||||
zilliz_args = kbs_config.get("zilliz")
|
||||
self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(embeddings),
|
||||
collection_name=self.kb_name, connection_args=zilliz_args)
|
||||
|
||||
|
||||
def do_init(self):
|
||||
self._load_zilliz()
|
||||
|
||||
def do_drop_kb(self):
|
||||
if self.zilliz.col:
|
||||
self.zilliz.col.release()
|
||||
self.zilliz.col.drop()
|
||||
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
|
||||
self._load_zilliz(embeddings=EmbeddingsFunAdapter(embeddings))
|
||||
return score_threshold_process(score_threshold, top_k, self.zilliz.similarity_search_with_score(query, top_k))
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||
for doc in docs:
|
||||
for k, v in doc.metadata.items():
|
||||
doc.metadata[k] = str(v)
|
||||
for field in self.zilliz.fields:
|
||||
doc.metadata.setdefault(field, "")
|
||||
doc.metadata.pop(self.zilliz._text_field, None)
|
||||
doc.metadata.pop(self.zilliz._vector_field, None)
|
||||
|
||||
ids = self.zilliz.add_documents(docs)
|
||||
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
|
||||
return doc_infos
|
||||
|
||||
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||
if self.zilliz.col:
|
||||
filepath = kb_file.filepath.replace('\\', '\\\\')
|
||||
delete_list = [item.get("pk") for item in
|
||||
self.zilliz.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
|
||||
self.zilliz.col.delete(expr=f'pk in {delete_list}')
|
||||
|
||||
def do_clear_vs(self):
|
||||
if self.zilliz.col:
|
||||
self.do_drop_kb()
|
||||
self.do_init()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
from server.db.base import Base, engine
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
zillizService = ZillizKBService("test")
|
||||
|
||||
@ -37,9 +37,10 @@ def folder2db(
|
||||
kb_names: List[str],
|
||||
mode: Literal["recreate_vs", "update_in_db", "increament"],
|
||||
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
|
||||
kb_info: dict[str, Any] = {},
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
chunk_size: int = CHUNK_SIZE,
|
||||
chunk_overlap: int = CHUNK_SIZE,
|
||||
chunk_overlap: int = OVERLAP_SIZE,
|
||||
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
|
||||
):
|
||||
'''
|
||||
@ -72,6 +73,7 @@ def folder2db(
|
||||
# 清除向量库,从本地文件重建
|
||||
if mode == "recreate_vs":
|
||||
kb.clear_vs()
|
||||
kb.create_kb()
|
||||
kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
|
||||
files2vs(kb_name, kb_files)
|
||||
kb.save_vector_store()
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
import os
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from configs import (
|
||||
EMBEDDING_MODEL,
|
||||
KB_ROOT_PATH,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from fastapi import Body
|
||||
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
|
||||
from server.utils import BaseResponse, fschat_controller_address, list_llm_models, get_httpx_client
|
||||
|
||||
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT,LANGCHAIN_LLM_MODEL
|
||||
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
|
||||
get_httpx_client, get_model_worker_config)
|
||||
|
||||
|
||||
def list_running_models(
|
||||
@ -9,19 +9,21 @@ def list_running_models(
|
||||
placeholder: str = Body(None, description="该参数未使用,占位用"),
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
从fastchat controller获取已加载模型列表
|
||||
从fastchat controller获取已加载模型列表及其配置项
|
||||
'''
|
||||
try:
|
||||
controller_address = controller_address or fschat_controller_address()
|
||||
with get_httpx_client() as client:
|
||||
r = client.post(controller_address + "/list_models")
|
||||
return BaseResponse(data=r.json()["models"])
|
||||
models = r.json()["models"]
|
||||
data = {m: get_model_config(m).data for m in models}
|
||||
return BaseResponse(data=data)
|
||||
except Exception as e:
|
||||
logger.error(f'{e.__class__.__name__}: {e}',
|
||||
exc_info=e if log_verbose else None)
|
||||
return BaseResponse(
|
||||
code=500,
|
||||
data=[],
|
||||
data={},
|
||||
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
|
||||
|
||||
|
||||
@ -29,7 +31,36 @@ def list_config_models() -> BaseResponse:
|
||||
'''
|
||||
从本地获取configs中配置的模型列表
|
||||
'''
|
||||
return BaseResponse(data=list_llm_models())
|
||||
configs = list_config_llm_models()
|
||||
# 删除ONLINE_MODEL配置中的敏感信息
|
||||
for config in configs["online"].values():
|
||||
del_keys = set(["worker_class"])
|
||||
for k in config:
|
||||
if "key" in k.lower() or "secret" in k.lower():
|
||||
del_keys.add(k)
|
||||
for k in del_keys:
|
||||
config.pop(k, None)
|
||||
|
||||
return BaseResponse(data=configs)
|
||||
|
||||
|
||||
def get_model_config(
|
||||
model_name: str = Body(description="配置中LLM模型的名称"),
|
||||
placeholder: str = Body(description="占位用,无实际效果")
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
获取LLM模型配置项(合并后的)
|
||||
'''
|
||||
config = get_model_worker_config(model_name=model_name)
|
||||
# 删除ONLINE_MODEL配置中的敏感信息
|
||||
del_keys = set(["worker_class"])
|
||||
for k in config:
|
||||
if "key" in k.lower() or "secret" in k.lower():
|
||||
del_keys.add(k)
|
||||
for k in del_keys:
|
||||
config.pop(k, None)
|
||||
|
||||
return BaseResponse(data=config)
|
||||
|
||||
|
||||
def stop_llm_model(
|
||||
@ -79,3 +110,9 @@ def change_llm_model(
|
||||
return BaseResponse(
|
||||
code=500,
|
||||
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
|
||||
|
||||
|
||||
def list_search_engines() -> BaseResponse:
|
||||
from server.chat.search_engine_chat import SEARCH_ENGINES
|
||||
|
||||
return BaseResponse(data=list(SEARCH_ENGINES))
|
||||
|
||||
@ -65,7 +65,7 @@ def gen_params(appid, domain,question, temperature):
|
||||
"chat": {
|
||||
"domain": domain,
|
||||
"random_threshold": 0.5,
|
||||
"max_tokens": 2048,
|
||||
"max_tokens": None,
|
||||
"auditing": "default",
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
@ -4,3 +4,4 @@ from .xinghuo import XingHuoWorker
|
||||
from .qianfan import QianFanWorker
|
||||
from .fangzhou import FangZhouWorker
|
||||
from .qwen import QwenWorker
|
||||
from .baichuan import BaiChuanWorker
|
||||
|
||||
@ -1,15 +1,15 @@
|
||||
# import os
|
||||
# import sys
|
||||
# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
import hashlib
|
||||
from server.model_workers.base import ApiModelWorker
|
||||
from server.utils import get_model_worker_config, get_httpx_client
|
||||
from fastchat import conversation as conv
|
||||
import sys
|
||||
import json
|
||||
from typing import List, Literal
|
||||
from typing import List, Literal, Dict
|
||||
from configs import TEMPERATURE
|
||||
|
||||
|
||||
@ -20,29 +20,29 @@ def calculate_md5(input_string):
|
||||
return encrypted
|
||||
|
||||
|
||||
def do_request():
|
||||
url = "https://api.baichuan-ai.com/v1/stream/chat"
|
||||
api_key = ""
|
||||
secret_key = ""
|
||||
def request_baichuan_api(
|
||||
messages: List[Dict[str, str]],
|
||||
api_key: str = None,
|
||||
secret_key: str = None,
|
||||
version: str = "Baichuan2-53B",
|
||||
temperature: float = TEMPERATURE,
|
||||
model_name: str = "baichuan-api",
|
||||
):
|
||||
config = get_model_worker_config(model_name)
|
||||
api_key = api_key or config.get("api_key")
|
||||
secret_key = secret_key or config.get("secret_key")
|
||||
version = version or config.get("version")
|
||||
|
||||
url = "https://api.baichuan-ai.com/v1/stream/chat"
|
||||
data = {
|
||||
"model": "Baichuan2-53B",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "世界第一高峰是"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"temperature": 0.1,
|
||||
"top_k": 10
|
||||
}
|
||||
"model": version,
|
||||
"messages": messages,
|
||||
"parameters": {"temperature": temperature}
|
||||
}
|
||||
|
||||
json_data = json.dumps(data)
|
||||
time_stamp = int(time.time())
|
||||
signature = calculate_md5(secret_key + json_data + str(time_stamp))
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + api_key,
|
||||
@ -52,18 +52,17 @@ def do_request():
|
||||
"X-BC-Sign-Algo": "MD5",
|
||||
}
|
||||
|
||||
response = requests.post(url, data=json_data, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
print("请求成功!")
|
||||
print("响应header:", response.headers)
|
||||
print("响应body:", response.text)
|
||||
else:
|
||||
print("请求失败,状态码:", response.status_code)
|
||||
with get_httpx_client() as client:
|
||||
with client.stream("POST", url, headers=headers, json=data) as response:
|
||||
for line in response.iter_lines():
|
||||
if not line.strip():
|
||||
continue
|
||||
resp = json.loads(line)
|
||||
yield resp
|
||||
|
||||
|
||||
class BaiChuanWorker(ApiModelWorker):
|
||||
BASE_URL = "https://api.baichuan-ai.com/v1/chat"
|
||||
BASE_URL = "https://api.baichuan-ai.com/v1/stream/chat"
|
||||
SUPPORT_MODELS = ["Baichuan2-53B"]
|
||||
|
||||
def __init__(
|
||||
@ -95,54 +94,34 @@ class BaiChuanWorker(ApiModelWorker):
|
||||
self.secret_key = config.get("secret_key")
|
||||
|
||||
def generate_stream_gate(self, params):
|
||||
data = {
|
||||
"model": self.version,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": params["prompt"]
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"temperature": params.get("temperature",TEMPERATURE),
|
||||
"top_k": params.get("top_k",1)
|
||||
}
|
||||
}
|
||||
super().generate_stream_gate(params)
|
||||
|
||||
json_data = json.dumps(data)
|
||||
time_stamp = int(time.time())
|
||||
signature = calculate_md5(self.secret_key + json_data + str(time_stamp))
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + self.api_key,
|
||||
"X-BC-Request-Id": "your requestId",
|
||||
"X-BC-Timestamp": str(time_stamp),
|
||||
"X-BC-Signature": signature,
|
||||
"X-BC-Sign-Algo": "MD5",
|
||||
}
|
||||
messages = self.prompt_to_messages(params["prompt"])
|
||||
|
||||
response = requests.post(self.BASE_URL, data=json_data, headers=headers)
|
||||
text = ""
|
||||
for resp in request_baichuan_api(messages=messages,
|
||||
api_key=self.api_key,
|
||||
secret_key=self.secret_key,
|
||||
version=self.version,
|
||||
temperature=params.get("temperature")):
|
||||
if resp["code"] == 0:
|
||||
text += resp["data"]["messages"][-1]["content"]
|
||||
yield json.dumps(
|
||||
{
|
||||
"error_code": resp["code"],
|
||||
"text": text
|
||||
},
|
||||
ensure_ascii=False
|
||||
).encode() + b"\0"
|
||||
else:
|
||||
yield json.dumps(
|
||||
{
|
||||
"error_code": resp["code"],
|
||||
"text": resp["msg"]
|
||||
},
|
||||
ensure_ascii=False
|
||||
).encode() + b"\0"
|
||||
|
||||
if response.status_code == 200:
|
||||
resp = eval(response.text)
|
||||
yield json.dumps(
|
||||
{
|
||||
"error_code": resp["code"],
|
||||
"text": resp["data"]["messages"][-1]["content"]
|
||||
},
|
||||
ensure_ascii=False
|
||||
).encode() + b"\0"
|
||||
else:
|
||||
yield json.dumps(
|
||||
{
|
||||
"error_code": resp["code"],
|
||||
"text": resp["msg"]
|
||||
},
|
||||
ensure_ascii=False
|
||||
).encode() + b"\0"
|
||||
|
||||
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
print("embedding")
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
from configs.basic_config import LOG_PATH
|
||||
import fastchat.constants
|
||||
fastchat.constants.LOGDIR = LOG_PATH
|
||||
from fastchat.serve.model_worker import BaseModelWorker
|
||||
from fastchat.serve.base_model_worker import BaseModelWorker
|
||||
import uuid
|
||||
import json
|
||||
import sys
|
||||
from pydantic import BaseModel
|
||||
import fastchat
|
||||
import threading
|
||||
import asyncio
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
@ -40,12 +40,12 @@ class ApiModelWorker(BaseModelWorker):
|
||||
worker_addr=worker_addr,
|
||||
**kwargs)
|
||||
self.context_len = context_len
|
||||
self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
|
||||
self.init_heart_beat()
|
||||
|
||||
def count_token(self, params):
|
||||
# TODO:需要完善
|
||||
print("count token")
|
||||
print(params)
|
||||
# print("count token")
|
||||
prompt = params["prompt"]
|
||||
return {"count": len(str(prompt)), "error_code": 0}
|
||||
|
||||
@ -59,16 +59,7 @@ class ApiModelWorker(BaseModelWorker):
|
||||
|
||||
def get_embeddings(self, params):
|
||||
print("embedding")
|
||||
print(params)
|
||||
|
||||
# workaround to make program exit with Ctrl+c
|
||||
# it should be deleted after pr is merged by fastchat
|
||||
def init_heart_beat(self):
|
||||
self.register_to_controller()
|
||||
self.heart_beat_thread = threading.Thread(
|
||||
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
|
||||
)
|
||||
self.heart_beat_thread.start()
|
||||
# print(params)
|
||||
|
||||
# help methods
|
||||
def get_config(self):
|
||||
|
||||
@ -26,10 +26,10 @@ class ChatGLMWorker(ApiModelWorker):
|
||||
# 这里的是chatglm api的模板,其它API的conv_template需要定制
|
||||
self.conv = conv.Conversation(
|
||||
name=self.model_names[0],
|
||||
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
|
||||
system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
|
||||
messages=[],
|
||||
roles=["Human", "Assistant"],
|
||||
sep="\n### ",
|
||||
sep="\n###",
|
||||
stop_str="###",
|
||||
)
|
||||
|
||||
@ -57,7 +57,7 @@ class ChatGLMWorker(ApiModelWorker):
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
print("embedding")
|
||||
print(params)
|
||||
# print(params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
177
server/utils.py
@ -5,12 +5,11 @@ from fastapi import FastAPI
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
|
||||
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
|
||||
logger, log_verbose,
|
||||
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, LANGCHAIN_LLM_MODEL, logger, log_verbose,
|
||||
FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI, ChatAnthropic
|
||||
import httpx
|
||||
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
|
||||
|
||||
@ -34,23 +33,70 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
|
||||
def get_ChatOpenAI(
|
||||
model_name: str,
|
||||
temperature: float,
|
||||
max_tokens: int = None,
|
||||
streaming: bool = True,
|
||||
callbacks: List[Callable] = [],
|
||||
verbose: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> ChatOpenAI:
|
||||
config = get_model_worker_config(model_name)
|
||||
model = ChatOpenAI(
|
||||
streaming=streaming,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
openai_api_key=config.get("api_key", "EMPTY"),
|
||||
openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
openai_proxy=config.get("openai_proxy"),
|
||||
**kwargs
|
||||
)
|
||||
## 以下模型是Langchain原生支持的模型,这些模型不会走Fschat封装
|
||||
config_models = list_config_llm_models()
|
||||
if model_name in config_models.get("langchain", {}):
|
||||
config = config_models["langchain"][model_name]
|
||||
if model_name == "Azure-OpenAI":
|
||||
model = AzureChatOpenAI(
|
||||
streaming=streaming,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
deployment_name=config.get("deployment_name"),
|
||||
model_version=config.get("model_version"),
|
||||
openai_api_type=config.get("openai_api_type"),
|
||||
openai_api_base=config.get("api_base_url"),
|
||||
openai_api_version=config.get("api_version"),
|
||||
openai_api_key=config.get("api_key"),
|
||||
openai_proxy=config.get("openai_proxy"),
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
elif model_name == "OpenAI":
|
||||
model = ChatOpenAI(
|
||||
streaming=streaming,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
model_name=config.get("model_name"),
|
||||
openai_api_base=config.get("api_base_url"),
|
||||
openai_api_key=config.get("api_key"),
|
||||
openai_proxy=config.get("openai_proxy"),
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
elif model_name == "Anthropic":
|
||||
model = ChatAnthropic(
|
||||
streaming=streaming,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
model_name=config.get("model_name"),
|
||||
anthropic_api_key=config.get("api_key"),
|
||||
|
||||
)
|
||||
## TODO 支持其他的Langchain原生支持的模型
|
||||
else:
|
||||
## 非Langchain原生支持的模型,走Fschat封装
|
||||
config = get_model_worker_config(model_name)
|
||||
model = ChatOpenAI(
|
||||
streaming=streaming,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
openai_api_key=config.get("api_key", "EMPTY"),
|
||||
openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
openai_proxy=config.get("openai_proxy"),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@ -144,7 +190,7 @@ def run_async(cor):
|
||||
return loop.run_until_complete(cor)
|
||||
|
||||
|
||||
def iter_over_async(ait, loop):
|
||||
def iter_over_async(ait, loop=None):
|
||||
'''
|
||||
将异步生成器封装成同步生成器.
|
||||
'''
|
||||
@ -157,6 +203,12 @@ def iter_over_async(ait, loop):
|
||||
except StopAsyncIteration:
|
||||
return True, None
|
||||
|
||||
if loop is None:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except:
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
while True:
|
||||
done, obj = loop.run_until_complete(get_next())
|
||||
if done:
|
||||
@ -194,7 +246,7 @@ def MakeFastAPIOffline(
|
||||
index = i
|
||||
break
|
||||
if isinstance(index, int):
|
||||
app.routes.pop(i)
|
||||
app.routes.pop(index)
|
||||
|
||||
# Set up static file mount
|
||||
app.mount(
|
||||
@ -241,8 +293,9 @@ def MakeFastAPIOffline(
|
||||
redoc_favicon_url=favicon,
|
||||
)
|
||||
|
||||
# 从model_config中获取模型信息
|
||||
|
||||
|
||||
# 从model_config中获取模型信息
|
||||
def list_embed_models() -> List[str]:
|
||||
'''
|
||||
get names of configured embedding models
|
||||
@ -250,17 +303,18 @@ def list_embed_models() -> List[str]:
|
||||
return list(MODEL_PATH["embed_model"])
|
||||
|
||||
|
||||
def list_llm_models() -> Dict[str, List[str]]:
|
||||
def list_config_llm_models() -> Dict[str, Dict]:
|
||||
'''
|
||||
get names of configured llm models with different types.
|
||||
get configured llm models with different types.
|
||||
return [(model_name, config_type), ...]
|
||||
'''
|
||||
workers = list(FSCHAT_MODEL_WORKERS)
|
||||
if "default" in workers:
|
||||
workers.remove("default")
|
||||
if LLM_MODEL not in workers:
|
||||
workers.insert(0, LLM_MODEL)
|
||||
return {
|
||||
"local": list(MODEL_PATH["llm_model"]),
|
||||
"online": list(ONLINE_LLM_MODEL),
|
||||
"local": MODEL_PATH["llm_model"],
|
||||
"langchain": LANGCHAIN_LLM_MODEL,
|
||||
"online": ONLINE_LLM_MODEL,
|
||||
"worker": workers,
|
||||
}
|
||||
|
||||
@ -293,12 +347,13 @@ def get_model_path(model_name: str, type: str = None) -> Optional[str]:
|
||||
|
||||
|
||||
# 从server_config中获取服务信息
|
||||
|
||||
def get_model_worker_config(model_name: str = None) -> dict:
|
||||
'''
|
||||
加载model worker的配置项。
|
||||
优先级:FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
|
||||
'''
|
||||
from configs.model_config import ONLINE_LLM_MODEL
|
||||
from configs.model_config import ONLINE_LLM_MODEL, MODEL_PATH
|
||||
from configs.server_config import FSCHAT_MODEL_WORKERS
|
||||
from server import model_workers
|
||||
|
||||
@ -307,6 +362,10 @@ def get_model_worker_config(model_name: str = None) -> dict:
|
||||
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
|
||||
|
||||
# 在线模型API
|
||||
if model_name in LANGCHAIN_LLM_MODEL:
|
||||
config["langchain_model"] = True
|
||||
config["worker_class"] = ""
|
||||
|
||||
if model_name in ONLINE_LLM_MODEL:
|
||||
config["online_api"] = True
|
||||
if provider := config.get("provider"):
|
||||
@ -316,9 +375,10 @@ def get_model_worker_config(model_name: str = None) -> dict:
|
||||
msg = f"在线模型 ‘{model_name}’ 的provider没有正确配置"
|
||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||
exc_info=e if log_verbose else None)
|
||||
|
||||
config["model_path"] = get_model_path(model_name)
|
||||
config["device"] = llm_device(config.get("device"))
|
||||
# 本地模型
|
||||
if model_name in MODEL_PATH["llm_model"]:
|
||||
config["model_path"] = get_model_path(model_name)
|
||||
config["device"] = llm_device(config.get("device"))
|
||||
return config
|
||||
|
||||
|
||||
@ -335,6 +395,8 @@ def fschat_controller_address() -> str:
|
||||
from configs.server_config import FSCHAT_CONTROLLER
|
||||
|
||||
host = FSCHAT_CONTROLLER["host"]
|
||||
if host == "0.0.0.0":
|
||||
host = "127.0.0.1"
|
||||
port = FSCHAT_CONTROLLER["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
@ -342,6 +404,8 @@ def fschat_controller_address() -> str:
|
||||
def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
|
||||
if model := get_model_worker_config(model_name):
|
||||
host = model["host"]
|
||||
if host == "0.0.0.0":
|
||||
host = "127.0.0.1"
|
||||
port = model["port"]
|
||||
return f"http://{host}:{port}"
|
||||
return ""
|
||||
@ -351,6 +415,8 @@ def fschat_openai_api_address() -> str:
|
||||
from configs.server_config import FSCHAT_OPENAI_API
|
||||
|
||||
host = FSCHAT_OPENAI_API["host"]
|
||||
if host == "0.0.0.0":
|
||||
host = "127.0.0.1"
|
||||
port = FSCHAT_OPENAI_API["port"]
|
||||
return f"http://{host}:{port}/v1"
|
||||
|
||||
@ -359,6 +425,8 @@ def api_address() -> str:
|
||||
from configs.server_config import API_SERVER
|
||||
|
||||
host = API_SERVER["host"]
|
||||
if host == "0.0.0.0":
|
||||
host = "127.0.0.1"
|
||||
port = API_SERVER["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
@ -371,15 +439,16 @@ def webui_address() -> str:
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def get_prompt_template(name: str) -> Optional[str]:
|
||||
def get_prompt_template(type: str, name: str) -> Optional[str]:
|
||||
'''
|
||||
从prompt_config中加载模板内容
|
||||
type: "llm_chat","agent_chat","knowledge_base_chat","search_engine_chat"的其中一种,如果有新功能,应该进行加入。
|
||||
'''
|
||||
|
||||
from configs import prompt_config
|
||||
import importlib
|
||||
importlib.reload(prompt_config) # TODO: 检查configs/prompt_config.py文件有修改再重新加载
|
||||
|
||||
return prompt_config.PROMPT_TEMPLATES.get(name)
|
||||
return prompt_config.PROMPT_TEMPLATES[type].get(name)
|
||||
|
||||
|
||||
def set_httpx_config(
|
||||
@ -391,6 +460,7 @@ def set_httpx_config(
|
||||
将本项目相关服务加入无代理列表,避免fastchat的服务器请求错误。(windows下无效)
|
||||
对于chatgpt等在线API,如要使用代理需要手动配置。搜索引擎的代理如何处置还需考虑。
|
||||
'''
|
||||
|
||||
import httpx
|
||||
import os
|
||||
|
||||
@ -433,14 +503,15 @@ def set_httpx_config(
|
||||
|
||||
# TODO: 简单的清除系统代理不是个好的选择,影响太多。似乎修改代理服务器的bypass列表更好。
|
||||
# patch requests to use custom proxies instead of system settings
|
||||
# def _get_proxies():
|
||||
# return {}
|
||||
def _get_proxies():
|
||||
return proxies
|
||||
|
||||
# import urllib.request
|
||||
# urllib.request.getproxies = _get_proxies
|
||||
import urllib.request
|
||||
urllib.request.getproxies = _get_proxies
|
||||
|
||||
# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
|
||||
|
||||
|
||||
# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
|
||||
def detect_device() -> Literal["cuda", "mps", "cpu"]:
|
||||
try:
|
||||
import torch
|
||||
@ -541,3 +612,37 @@ def get_httpx_client(
|
||||
return httpx.AsyncClient(**kwargs)
|
||||
else:
|
||||
return httpx.Client(**kwargs)
|
||||
|
||||
|
||||
def get_server_configs() -> Dict:
|
||||
'''
|
||||
获取configs中的原始配置项,供前端使用
|
||||
'''
|
||||
from configs.kb_config import (
|
||||
DEFAULT_KNOWLEDGE_BASE,
|
||||
DEFAULT_SEARCH_ENGINE,
|
||||
DEFAULT_VS_TYPE,
|
||||
CHUNK_SIZE,
|
||||
OVERLAP_SIZE,
|
||||
SCORE_THRESHOLD,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
SEARCH_ENGINE_TOP_K,
|
||||
ZH_TITLE_ENHANCE,
|
||||
text_splitter_dict,
|
||||
TEXT_SPLITTER_NAME,
|
||||
)
|
||||
from configs.model_config import (
|
||||
LLM_MODEL,
|
||||
EMBEDDING_MODEL,
|
||||
HISTORY_LEN,
|
||||
TEMPERATURE,
|
||||
)
|
||||
from configs.prompt_config import PROMPT_TEMPLATES
|
||||
|
||||
_custom = {
|
||||
"controller_address": fschat_controller_address(),
|
||||
"openai_api_address": fschat_openai_api_address(),
|
||||
"api_address": api_address(),
|
||||
}
|
||||
|
||||
return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom}
|
||||
|
||||
50
startup.py
@ -7,6 +7,7 @@ from multiprocessing import Process
|
||||
from datetime import datetime
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
# 设置numexpr最大线程数,默认为CPU核心数
|
||||
try:
|
||||
import numexpr
|
||||
@ -26,6 +27,7 @@ from configs import (
|
||||
TEXT_SPLITTER_NAME,
|
||||
FSCHAT_CONTROLLER,
|
||||
FSCHAT_OPENAI_API,
|
||||
FSCHAT_MODEL_WORKERS,
|
||||
API_SERVER,
|
||||
WEBUI_SERVER,
|
||||
HTTPX_DEFAULT_TIMEOUT,
|
||||
@ -66,7 +68,9 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
|
||||
controller_address:
|
||||
worker_address:
|
||||
|
||||
|
||||
对于Langchain支持的模型:
|
||||
langchain_model:True
|
||||
不会使用fschat
|
||||
对于online_api:
|
||||
online_api:True
|
||||
worker_class: `provider`
|
||||
@ -76,31 +80,34 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
|
||||
"""
|
||||
import fastchat.constants
|
||||
fastchat.constants.LOGDIR = LOG_PATH
|
||||
from fastchat.serve.model_worker import worker_id, logger
|
||||
import argparse
|
||||
logger.setLevel(log_level)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
args = parser.parse_args([])
|
||||
|
||||
for k, v in kwargs.items():
|
||||
setattr(args, k, v)
|
||||
|
||||
if worker_class := kwargs.get("langchain_model"): #Langchian支持的模型不用做操作
|
||||
from fastchat.serve.base_model_worker import app
|
||||
worker = ""
|
||||
# 在线模型API
|
||||
if worker_class := kwargs.get("worker_class"):
|
||||
from fastchat.serve.model_worker import app
|
||||
elif worker_class := kwargs.get("worker_class"):
|
||||
from fastchat.serve.base_model_worker import app
|
||||
|
||||
worker = worker_class(model_names=args.model_names,
|
||||
controller_addr=args.controller_address,
|
||||
worker_addr=args.worker_address)
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
# sys.modules["fastchat.serve.base_model_worker"].worker = worker
|
||||
sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)
|
||||
# 本地模型
|
||||
else:
|
||||
from configs.model_config import VLLM_MODEL_DICT
|
||||
if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
|
||||
import fastchat.serve.vllm_worker
|
||||
from fastchat.serve.vllm_worker import VLLMWorker,app
|
||||
from fastchat.serve.vllm_worker import VLLMWorker, app
|
||||
from vllm import AsyncLLMEngine
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs
|
||||
|
||||
args.tokenizer = args.model_path # 如果tokenizer与model_path不一致在此处添加
|
||||
args.tokenizer_mode = 'auto'
|
||||
args.trust_remote_code= True
|
||||
@ -114,7 +121,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
|
||||
args.block_size = 16
|
||||
args.swap_space = 4 # GiB
|
||||
args.gpu_memory_utilization = 0.90
|
||||
args.max_num_batched_tokens = 2560
|
||||
args.max_num_batched_tokens = 16384 # 一个批次中的最大令牌(tokens)数量,这个取决于你的显卡和大模型设置,设置太大显存会不够
|
||||
args.max_num_seqs = 256
|
||||
args.disable_log_stats = False
|
||||
args.conv_template = None
|
||||
@ -123,6 +130,13 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
|
||||
args.num_gpus = 1 # vllm worker的切分是tensor并行,这里填写显卡的数量
|
||||
args.engine_use_ray = False
|
||||
args.disable_log_requests = False
|
||||
|
||||
# 0.2.0 vllm后要加的参数, 但是这里不需要
|
||||
args.max_model_len = None
|
||||
args.revision = None
|
||||
args.quantization = None
|
||||
args.max_log_len = None
|
||||
|
||||
if args.model_path:
|
||||
args.model = args.model_path
|
||||
if args.num_gpus > 1:
|
||||
@ -146,12 +160,14 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
|
||||
conv_template = args.conv_template,
|
||||
)
|
||||
sys.modules["fastchat.serve.vllm_worker"].engine = engine
|
||||
sys.modules["fastchat.serve.vllm_worker"].worker = worker
|
||||
# sys.modules["fastchat.serve.vllm_worker"].worker = worker
|
||||
sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
|
||||
|
||||
else:
|
||||
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker
|
||||
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
|
||||
|
||||
args.gpus = "0" # GPU的编号,如果有多个GPU,可以设置为"0,1,2,3"
|
||||
args.max_gpu_memory = "20GiB"
|
||||
args.max_gpu_memory = "22GiB"
|
||||
args.num_gpus = 1 # model worker的切分是model并行,这里填写显卡的数量
|
||||
|
||||
args.load_8bit = False
|
||||
@ -163,7 +179,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
|
||||
args.awq_ckpt = None
|
||||
args.awq_wbits = 16
|
||||
args.awq_groupsize = -1
|
||||
args.model_names = []
|
||||
args.model_names = [""]
|
||||
args.conv_template = None
|
||||
args.limit_worker_concurrency = 5
|
||||
args.stream_interval = 2
|
||||
@ -212,8 +228,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
|
||||
)
|
||||
sys.modules["fastchat.serve.model_worker"].args = args
|
||||
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
|
||||
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
# sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
sys.modules["fastchat.serve.model_worker"].logger.setLevel(log_level)
|
||||
|
||||
MakeFastAPIOffline(app)
|
||||
app.title = f"FastChat LLM Server ({args.model_names[0]})"
|
||||
@ -659,7 +675,9 @@ async def start_main_server():
|
||||
if args.api_worker:
|
||||
configs = get_all_model_worker_configs()
|
||||
for model_name, config in configs.items():
|
||||
if config.get("online_api") and config.get("worker_class"):
|
||||
if (config.get("online_api")
|
||||
and config.get("worker_class")
|
||||
and model_name in FSCHAT_MODEL_WORKERS):
|
||||
e = manager.Event()
|
||||
model_worker_started.append(e)
|
||||
process = Process(
|
||||
|
||||
@ -137,6 +137,14 @@ def test_search_docs(api="/knowledge_base/search_docs"):
|
||||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
def test_update_info(api="/knowledge_base/update_info"):
|
||||
url = api_base_url + api
|
||||
print("\n更新知识库介绍")
|
||||
r = requests.post(url, json={"knowledge_base_name": "samples", "kb_info": "你好"})
|
||||
data = r.json()
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
|
||||
def test_update_docs(api="/knowledge_base/update_docs"):
|
||||
url = api_base_url + api
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ from pprint import pprint
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
api: ApiRequest = ApiRequest(api_base_url, no_remote_api=False)
|
||||
api: ApiRequest = ApiRequest(api_base_url)
|
||||
|
||||
|
||||
kb = "kb_for_api_test"
|
||||
|
||||
@ -32,7 +32,7 @@ def get_running_models(api="/llm_model/list_models"):
|
||||
return []
|
||||
|
||||
|
||||
def test_running_models(api="/llm_model/list_models"):
|
||||
def test_running_models(api="/llm_model/list_running_models"):
|
||||
url = api_base_url + api
|
||||
r = requests.post(url)
|
||||
assert r.status_code == 200
|
||||
@ -48,7 +48,7 @@ def test_running_models(api="/llm_model/list_models"):
|
||||
# r = requests.post(url, json={""})
|
||||
|
||||
|
||||
def test_change_model(api="/llm_model/change"):
|
||||
def test_change_model(api="/llm_model/change_model"):
|
||||
url = api_base_url + api
|
||||
|
||||
running_models = get_running_models()
|
||||
|
||||
16
tests/online_api/test_baichuan.py
Normal file
@ -0,0 +1,16 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
|
||||
from server.model_workers.baichuan import request_baichuan_api
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
def test_qwen():
|
||||
messages = [{"role": "user", "content": "hello"}]
|
||||
|
||||
for x in request_baichuan_api(messages):
|
||||
print(type(x))
|
||||
pprint(x)
|
||||
assert x["code"] == 0
|
||||
6
webui.py
@ -21,12 +21,6 @@ if __name__ == "__main__":
|
||||
}
|
||||
)
|
||||
|
||||
if not chat_box.chat_inited:
|
||||
st.toast(
|
||||
f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
|
||||
f"当前使用模型`{LLM_MODEL}`, 您可以开始提问了."
|
||||
)
|
||||
|
||||
pages = {
|
||||
"对话": {
|
||||
"icon": "chat",
|
||||
|
||||
@ -2,10 +2,9 @@ import streamlit as st
|
||||
from webui_pages.utils import *
|
||||
from streamlit_chatbox import *
|
||||
from datetime import datetime
|
||||
from server.chat.search_engine_chat import SEARCH_ENGINES
|
||||
import os
|
||||
from configs import LLM_MODEL, TEMPERATURE
|
||||
from server.utils import get_model_worker_config
|
||||
from configs import (LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
|
||||
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE,LANGCHAIN_LLM_MODEL)
|
||||
from typing import List, Dict
|
||||
|
||||
chat_box = ChatBox(
|
||||
@ -14,8 +13,6 @@ chat_box = ChatBox(
|
||||
"chatchat_icon_blue_square_v2.png"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
|
||||
'''
|
||||
返回消息历史。
|
||||
@ -36,9 +33,32 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
|
||||
return chat_box.filter_history(history_len=history_len, filter=filter)
|
||||
|
||||
|
||||
def dialogue_page(api: ApiRequest):
|
||||
chat_box.init_session()
|
||||
def get_default_llm_model(api: ApiRequest) -> (str, bool):
|
||||
'''
|
||||
从服务器上获取当前运行的LLM模型,如果本机配置的LLM_MODEL属于本地模型且在其中,则优先返回
|
||||
返回类型为(model_name, is_local_model)
|
||||
'''
|
||||
running_models = api.list_running_models()
|
||||
if not running_models:
|
||||
return "", False
|
||||
|
||||
if LLM_MODEL in running_models:
|
||||
return LLM_MODEL, True
|
||||
|
||||
local_models = [k for k, v in running_models.items() if not v.get("online_api")]
|
||||
if local_models:
|
||||
return local_models[0], True
|
||||
return list(running_models)[0], False
|
||||
|
||||
|
||||
def dialogue_page(api: ApiRequest):
|
||||
if not chat_box.chat_inited:
|
||||
default_model = get_default_llm_model(api)[0]
|
||||
st.toast(
|
||||
f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n"
|
||||
f"当前运行的模型`{default_model}`, 您可以开始提问了."
|
||||
)
|
||||
chat_box.init_session()
|
||||
with st.sidebar:
|
||||
# TODO: 对话模型与会话绑定
|
||||
def on_mode_change():
|
||||
@ -49,7 +69,6 @@ def dialogue_page(api: ApiRequest):
|
||||
if cur_kb:
|
||||
text = f"{text} 当前知识库: `{cur_kb}`。"
|
||||
st.toast(text)
|
||||
# sac.alert(text, description="descp", type="success", closable=True, banner=True)
|
||||
|
||||
dialogue_mode = st.selectbox("请选择对话模式:",
|
||||
["LLM 对话",
|
||||
@ -57,31 +76,38 @@ def dialogue_page(api: ApiRequest):
|
||||
"搜索引擎问答",
|
||||
"自定义Agent问答",
|
||||
],
|
||||
index=1,
|
||||
index=0,
|
||||
on_change=on_mode_change,
|
||||
key="dialogue_mode",
|
||||
)
|
||||
|
||||
def on_llm_change():
|
||||
config = get_model_worker_config(llm_model)
|
||||
if not config.get("online_api"): # 只有本地model_worker可以切换模型
|
||||
st.session_state["prev_llm_model"] = llm_model
|
||||
st.session_state["cur_llm_model"] = st.session_state.llm_model
|
||||
if llm_model:
|
||||
config = api.get_model_config(llm_model)
|
||||
if not config.get("online_api"): # 只有本地model_worker可以切换模型
|
||||
st.session_state["prev_llm_model"] = llm_model
|
||||
st.session_state["cur_llm_model"] = st.session_state.llm_model
|
||||
|
||||
def llm_model_format_func(x):
|
||||
if x in running_models:
|
||||
return f"{x} (Running)"
|
||||
return x
|
||||
|
||||
running_models = api.list_running_models()
|
||||
running_models = list(api.list_running_models())
|
||||
running_models += LANGCHAIN_LLM_MODEL.keys()
|
||||
available_models = []
|
||||
config_models = api.list_config_models()
|
||||
for models in config_models.values():
|
||||
for m in models:
|
||||
if m not in running_models:
|
||||
available_models.append(m)
|
||||
worker_models = list(config_models.get("worker", {})) # 仅列出在FSCHAT_MODEL_WORKERS中配置的模型
|
||||
for m in worker_models:
|
||||
if m not in running_models and m != "default":
|
||||
available_models.append(m)
|
||||
for k, v in config_models.get("online", {}).items(): # 列出ONLINE_MODELS中直接访问的模型
|
||||
if not v.get("provider") and k not in running_models:
|
||||
available_models.append(k)
|
||||
for k, v in config_models.get("langchain", {}).items(): # 列出LANGCHAIN_LLM_MODEL支持的模型
|
||||
available_models.append(k)
|
||||
llm_models = running_models + available_models
|
||||
index = llm_models.index(st.session_state.get("cur_llm_model", LLM_MODEL))
|
||||
index = llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0]))
|
||||
llm_model = st.selectbox("选择LLM模型:",
|
||||
llm_models,
|
||||
index,
|
||||
@ -90,7 +116,8 @@ def dialogue_page(api: ApiRequest):
|
||||
key="llm_model",
|
||||
)
|
||||
if (st.session_state.get("prev_llm_model") != llm_model
|
||||
and not get_model_worker_config(llm_model).get("online_api")
|
||||
and not llm_model in config_models.get("online", {})
|
||||
and not llm_model in config_models.get("langchain", {})
|
||||
and llm_model not in running_models):
|
||||
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
|
||||
prev_model = st.session_state.get("prev_llm_model")
|
||||
@ -101,9 +128,29 @@ def dialogue_page(api: ApiRequest):
|
||||
st.success(msg)
|
||||
st.session_state["prev_llm_model"] = llm_model
|
||||
|
||||
temperature = st.slider("Temperature:", 0.0, 1.0, TEMPERATURE, 0.01)
|
||||
index_prompt = {
|
||||
"LLM 对话": "llm_chat",
|
||||
"自定义Agent问答": "agent_chat",
|
||||
"搜索引擎问答": "search_engine_chat",
|
||||
"知识库问答": "knowledge_base_chat",
|
||||
}
|
||||
prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys())
|
||||
prompt_template_name = prompt_templates_kb_list[0]
|
||||
if "prompt_template_select" not in st.session_state:
|
||||
st.session_state.prompt_template_select = prompt_templates_kb_list[0]
|
||||
|
||||
## 部分模型可以超过10抡对话
|
||||
def prompt_change():
|
||||
text = f"已切换为 {prompt_template_name} 模板。"
|
||||
st.toast(text)
|
||||
|
||||
prompt_template_select = st.selectbox(
|
||||
"请选择Prompt模板:",
|
||||
prompt_templates_kb_list,
|
||||
index=0,
|
||||
on_change=prompt_change,
|
||||
key="prompt_template_select",
|
||||
)
|
||||
prompt_template_name = st.session_state.prompt_template_select
|
||||
history_len = st.number_input("历史对话轮数:", 0, 20, HISTORY_LEN)
|
||||
|
||||
def on_kb_change():
|
||||
@ -111,10 +158,14 @@ def dialogue_page(api: ApiRequest):
|
||||
|
||||
if dialogue_mode == "知识库问答":
|
||||
with st.expander("知识库配置", True):
|
||||
kb_list = api.list_knowledge_bases(no_remote_api=True)
|
||||
kb_list = api.list_knowledge_bases()
|
||||
index = 0
|
||||
if DEFAULT_KNOWLEDGE_BASE in kb_list:
|
||||
index = kb_list.index(DEFAULT_KNOWLEDGE_BASE)
|
||||
selected_kb = st.selectbox(
|
||||
"请选择知识库:",
|
||||
kb_list,
|
||||
index=index,
|
||||
on_change=on_kb_change,
|
||||
key="selected_kb",
|
||||
)
|
||||
@ -123,15 +174,17 @@ def dialogue_page(api: ApiRequest):
|
||||
## Bge 模型会超过1
|
||||
score_threshold = st.slider("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
|
||||
|
||||
# chunk_content = st.checkbox("关联上下文", False, disabled=True)
|
||||
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
|
||||
elif dialogue_mode == "搜索引擎问答":
|
||||
search_engine_list = list(SEARCH_ENGINES.keys())
|
||||
search_engine_list = api.list_search_engines()
|
||||
if DEFAULT_SEARCH_ENGINE in search_engine_list:
|
||||
index = search_engine_list.index(DEFAULT_SEARCH_ENGINE)
|
||||
else:
|
||||
index = search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0
|
||||
with st.expander("搜索引擎配置", True):
|
||||
search_engine = st.selectbox(
|
||||
label="请选择搜索引擎",
|
||||
options=search_engine_list,
|
||||
index=search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0,
|
||||
index=index,
|
||||
)
|
||||
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)
|
||||
|
||||
@ -147,7 +200,11 @@ def dialogue_page(api: ApiRequest):
|
||||
if dialogue_mode == "LLM 对话":
|
||||
chat_box.ai_say("正在思考...")
|
||||
text = ""
|
||||
r = api.chat_chat(prompt, history=history, model=llm_model, temperature=temperature)
|
||||
r = api.chat_chat(prompt,
|
||||
history=history,
|
||||
model=llm_model,
|
||||
prompt_name=prompt_template_name,
|
||||
temperature=temperature)
|
||||
for t in r:
|
||||
if error_msg := check_error_msg(t): # check whether error occured
|
||||
st.error(error_msg)
|
||||
@ -157,30 +214,42 @@ def dialogue_page(api: ApiRequest):
|
||||
chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标
|
||||
|
||||
|
||||
|
||||
elif dialogue_mode == "自定义Agent问答":
|
||||
chat_box.ai_say([
|
||||
f"正在思考和寻找工具 ...",])
|
||||
f"正在思考...",
|
||||
Markdown("...", in_expander=True, title="思考过程", state="complete"),
|
||||
|
||||
])
|
||||
text = ""
|
||||
element_index = 0
|
||||
ans = ""
|
||||
support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型
|
||||
if not any(agent in llm_model for agent in support_agent):
|
||||
ans += "正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!</span>\n\n\n"
|
||||
chat_box.update_msg(ans, element_index=0, streaming=False)
|
||||
for d in api.agent_chat(prompt,
|
||||
history=history,
|
||||
model=llm_model,
|
||||
temperature=temperature):
|
||||
prompt_name=prompt_template_name,
|
||||
temperature=temperature,
|
||||
):
|
||||
try:
|
||||
d = json.loads(d)
|
||||
except:
|
||||
pass
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
|
||||
elif chunk := d.get("answer"):
|
||||
if chunk := d.get("answer"):
|
||||
text += chunk
|
||||
chat_box.update_msg(text, element_index=0)
|
||||
elif chunk := d.get("tools"):
|
||||
element_index += 1
|
||||
chat_box.insert_msg(Markdown("...", in_expander=True, title="使用工具...", state="complete"))
|
||||
chat_box.update_msg("\n\n".join(d.get("tools", [])), element_index=element_index, streaming=False)
|
||||
chat_box.update_msg(text, element_index=0, streaming=False)
|
||||
chat_box.update_msg(text, element_index=1)
|
||||
if chunk := d.get("final_answer"):
|
||||
ans += chunk
|
||||
chat_box.update_msg(ans, element_index=0)
|
||||
if chunk := d.get("tools"):
|
||||
text += "\n\n".join(d.get("tools", []))
|
||||
chat_box.update_msg(text, element_index=1)
|
||||
chat_box.update_msg(ans, element_index=0, streaming=False)
|
||||
chat_box.update_msg(text, element_index=1, streaming=False)
|
||||
elif dialogue_mode == "知识库问答":
|
||||
chat_box.ai_say([
|
||||
f"正在查询知识库 `{selected_kb}` ...",
|
||||
@ -193,6 +262,7 @@ def dialogue_page(api: ApiRequest):
|
||||
score_threshold=score_threshold,
|
||||
history=history,
|
||||
model=llm_model,
|
||||
prompt_name=prompt_template_name,
|
||||
temperature=temperature):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
@ -212,6 +282,7 @@ def dialogue_page(api: ApiRequest):
|
||||
top_k=se_top_k,
|
||||
history=history,
|
||||
model=llm_model,
|
||||
prompt_name=prompt_template_name,
|
||||
temperature=temperature):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
@ -239,4 +310,4 @@ def dialogue_page(api: ApiRequest):
|
||||
file_name=f"{now:%Y-%m-%d %H.%M}_对话记录.md",
|
||||
mime="text/markdown",
|
||||
use_container_width=True,
|
||||
)
|
||||
)
|
||||
@ -63,6 +63,9 @@ def knowledge_base_page(api: ApiRequest):
|
||||
else:
|
||||
selected_kb_index = 0
|
||||
|
||||
if "selected_kb_info" not in st.session_state:
|
||||
st.session_state["selected_kb_info"] = ""
|
||||
|
||||
def format_selected_kb(kb_name: str) -> str:
|
||||
if kb := kb_list.get(kb_name):
|
||||
return f"{kb_name} ({kb['vs_type']} @ {kb['embed_model']})"
|
||||
@ -84,6 +87,11 @@ def knowledge_base_page(api: ApiRequest):
|
||||
placeholder="新知识库名称,不支持中文命名",
|
||||
key="kb_name",
|
||||
)
|
||||
kb_info = st.text_input(
|
||||
"知识库简介",
|
||||
placeholder="知识库简介,方便Agent查找",
|
||||
key="kb_info",
|
||||
)
|
||||
|
||||
cols = st.columns(2)
|
||||
|
||||
@ -123,18 +131,23 @@ def knowledge_base_page(api: ApiRequest):
|
||||
)
|
||||
st.toast(ret.get("msg", " "))
|
||||
st.session_state["selected_kb_name"] = kb_name
|
||||
st.session_state["selected_kb_info"] = kb_info
|
||||
st.experimental_rerun()
|
||||
|
||||
elif selected_kb:
|
||||
kb = selected_kb
|
||||
|
||||
|
||||
st.session_state["selected_kb_info"] = kb_list[kb]['kb_info']
|
||||
# 上传文件
|
||||
files = st.file_uploader("上传知识文件:",
|
||||
[i for ls in LOADER_DICT.values() for i in ls],
|
||||
accept_multiple_files=True,
|
||||
)
|
||||
kb_info = st.text_area("请输入知识库介绍:", value=st.session_state["selected_kb_info"], max_chars=None, key=None,
|
||||
help=None, on_change=None, args=None, kwargs=None)
|
||||
|
||||
if kb_info != st.session_state["selected_kb_info"]:
|
||||
st.session_state["selected_kb_info"] = kb_info
|
||||
api.update_kb_info(kb, kb_info)
|
||||
|
||||
# with st.sidebar:
|
||||
with st.expander(
|
||||
|
||||