mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-07 23:43:30 +08:00
merge PR 659: make chat api support streaming
This commit is contained in:
commit
18d453cc18
5
.gitignore
vendored
5
.gitignore
vendored
@ -174,4 +174,7 @@ embedding/*
|
|||||||
|
|
||||||
pyrightconfig.json
|
pyrightconfig.json
|
||||||
loader/tmp_files
|
loader/tmp_files
|
||||||
flagged/*
|
flagged/*
|
||||||
|
ptuning-v2/*.json
|
||||||
|
ptuning-v2/*.bin
|
||||||
|
|
||||||
|
|||||||
32
README.md
32
README.md
@ -23,13 +23,17 @@
|
|||||||
|
|
||||||
🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。
|
🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。
|
||||||
|
|
||||||
|
🐳 Docker镜像:registry.cn-beijing.aliyuncs.com/isafetech/chatmydata:1.0 (感谢 @InkSong🌲 )
|
||||||
|
|
||||||
|
💻 运行方式:docker run -d -p 80:7860 --gpus all registry.cn-beijing.aliyuncs.com/isafetech/chatmydata:1.0
|
||||||
|
|
||||||
🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/imClumsyPanda/langchain-ChatGLM/langchain-ChatGLM)
|
🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/imClumsyPanda/langchain-ChatGLM/langchain-ChatGLM)
|
||||||
|
|
||||||
📓 [ModelWhale 在线运行项目](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59)
|
📓 [ModelWhale 在线运行项目](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59)
|
||||||
|
|
||||||
## 变更日志
|
## 变更日志
|
||||||
|
|
||||||
参见 [变更日志](docs/CHANGELOG.md)。
|
参见 [版本更新日志](https://github.com/imClumsyPanda/langchain-ChatGLM/releases)。
|
||||||
|
|
||||||
## 硬件需求
|
## 硬件需求
|
||||||
|
|
||||||
@ -60,6 +64,23 @@
|
|||||||
|
|
||||||
本项目中默认选用的 Embedding 模型 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main) 约占用显存 3GB,也可修改为在 CPU 中运行。
|
本项目中默认选用的 Embedding 模型 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main) 约占用显存 3GB,也可修改为在 CPU 中运行。
|
||||||
|
|
||||||
|
## Docker 整合包
|
||||||
|
🐳 Docker镜像地址:`registry.cn-beijing.aliyuncs.com/isafetech/chatmydata:1.0 `🌲
|
||||||
|
|
||||||
|
💻 一行命令运行:
|
||||||
|
```shell
|
||||||
|
docker run -d -p 80:7860 --gpus all registry.cn-beijing.aliyuncs.com/isafetech/chatmydata:1.0
|
||||||
|
```
|
||||||
|
|
||||||
|
- 该版本镜像大小`25.2G`,使用[v0.1.16](https://github.com/imClumsyPanda/langchain-ChatGLM/releases/tag/v0.1.16),以`nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04`为基础镜像
|
||||||
|
- 该版本内置两个`embedding`模型:`m3e-base`,`text2vec-large-chinese`,内置`fastchat+chatglm-6b`
|
||||||
|
- 该版本目标为方便一键部署使用,请确保您已经在Linux发行版上安装了NVIDIA驱动程序
|
||||||
|
- 请注意,您不需要在主机系统上安装CUDA工具包,但需要安装`NVIDIA Driver`以及`NVIDIA Container Toolkit`,请参考[安装指南](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
|
||||||
|
- 首次拉取和启动均需要一定时间,首次启动时请参照下图使用`docker logs -f <container id>`查看日志
|
||||||
|
- 如遇到启动过程卡在`Waiting..`步骤,建议使用`docker exec -it <container id> bash`进入`/logs/`目录查看对应阶段日志
|
||||||
|

|
||||||
|
|
||||||
|
|
||||||
## Docker 部署
|
## Docker 部署
|
||||||
为了能让容器使用主机GPU资源,需要在主机上安装 [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-container-toolkit)。具体安装步骤如下:
|
为了能让容器使用主机GPU资源,需要在主机上安装 [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-container-toolkit)。具体安装步骤如下:
|
||||||
```shell
|
```shell
|
||||||
@ -198,12 +219,17 @@ Web UI 可以实现如下功能:
|
|||||||
- [ ] 知识图谱/图数据库接入
|
- [ ] 知识图谱/图数据库接入
|
||||||
- [ ] Agent 实现
|
- [ ] Agent 实现
|
||||||
- [x] 增加更多 LLM 模型支持
|
- [x] 增加更多 LLM 模型支持
|
||||||
|
- [x] [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
|
||||||
- [x] [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
|
- [x] [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
|
||||||
- [x] [THUDM/chatglm-6b-int8](https://huggingface.co/THUDM/chatglm-6b-int8)
|
- [x] [THUDM/chatglm-6b-int8](https://huggingface.co/THUDM/chatglm-6b-int8)
|
||||||
- [x] [THUDM/chatglm-6b-int4](https://huggingface.co/THUDM/chatglm-6b-int4)
|
- [x] [THUDM/chatglm-6b-int4](https://huggingface.co/THUDM/chatglm-6b-int4)
|
||||||
- [x] [THUDM/chatglm-6b-int4-qe](https://huggingface.co/THUDM/chatglm-6b-int4-qe)
|
- [x] [THUDM/chatglm-6b-int4-qe](https://huggingface.co/THUDM/chatglm-6b-int4-qe)
|
||||||
- [x] [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2)
|
- [x] [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2)
|
||||||
- [x] [fnlp/moss-moon-003-sft](https://huggingface.co/fnlp/moss-moon-003-sft)
|
- [x] [fnlp/moss-moon-003-sft](https://huggingface.co/fnlp/moss-moon-003-sft)
|
||||||
|
- [x] [bigscience/bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1)
|
||||||
|
- [x] [bigscience/bloom-3b](https://huggingface.co/bigscience/bloom-3b)
|
||||||
|
- [x] [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
|
||||||
|
- [x] [lmsys/vicuna-13b-delta-v1.1](https://huggingface.co/lmsys/vicuna-13b-delta-v1.1)
|
||||||
- [x] 支持通过调用 [fastchat](https://github.com/lm-sys/FastChat) api 调用 llm
|
- [x] 支持通过调用 [fastchat](https://github.com/lm-sys/FastChat) api 调用 llm
|
||||||
- [x] 增加更多 Embedding 模型支持
|
- [x] 增加更多 Embedding 模型支持
|
||||||
- [x] [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
|
- [x] [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
|
||||||
@ -221,7 +247,7 @@ Web UI 可以实现如下功能:
|
|||||||
- [x] 选择知识库开始问答
|
- [x] 选择知识库开始问答
|
||||||
- [x] 上传文件/文件夹至知识库
|
- [x] 上传文件/文件夹至知识库
|
||||||
- [x] 知识库测试
|
- [x] 知识库测试
|
||||||
- [ ] 删除知识库中文件
|
- [x] 删除知识库中文件
|
||||||
- [x] 支持搜索引擎问答
|
- [x] 支持搜索引擎问答
|
||||||
- [ ] 增加 API 支持
|
- [ ] 增加 API 支持
|
||||||
- [x] 利用 fastapi 实现 API 部署方式
|
- [x] 利用 fastapi 实现 API 部署方式
|
||||||
@ -229,7 +255,7 @@ Web UI 可以实现如下功能:
|
|||||||
- [x] VUE 前端
|
- [x] VUE 前端
|
||||||
|
|
||||||
## 项目交流群
|
## 项目交流群
|
||||||
<img src="img/qr_code_33.jpg" alt="二维码" width="300" height="300" />
|
<img src="img/qr_code_45.jpg" alt="二维码" width="300" height="300" />
|
||||||
|
|
||||||
|
|
||||||
🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
|
🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
|
||||||
|
|||||||
231
api.py
231
api.py
@ -1,10 +1,11 @@
|
|||||||
|
#encoding:utf-8
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
import urllib
|
import urllib
|
||||||
|
import asyncio
|
||||||
import nltk
|
import nltk
|
||||||
import pydantic
|
import pydantic
|
||||||
import uvicorn
|
import uvicorn
|
||||||
@ -55,7 +56,7 @@ class ListDocsResponse(BaseResponse):
|
|||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
question: str = pydantic.Field(..., description="Question text")
|
question: str = pydantic.Field(..., description="Question text")
|
||||||
response: str = pydantic.Field(..., description="Response text")
|
response: str = pydantic.Field(..., description="Response text")
|
||||||
history: List[List[str]] = pydantic.Field(..., description="History text")
|
history: List[List[Optional[str]]] = pydantic.Field(..., description="History text")
|
||||||
source_documents: List[str] = pydantic.Field(
|
source_documents: List[str] = pydantic.Field(
|
||||||
..., description="List of source documents and their scores"
|
..., description="List of source documents and their scores"
|
||||||
)
|
)
|
||||||
@ -80,23 +81,37 @@ class ChatMessage(BaseModel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_folder_path(local_doc_id: str):
|
def get_kb_path(local_doc_id: str):
|
||||||
return os.path.join(KB_ROOT_PATH, local_doc_id, "content")
|
return os.path.join(KB_ROOT_PATH, local_doc_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_doc_path(local_doc_id: str):
|
||||||
|
return os.path.join(get_kb_path(local_doc_id), "content")
|
||||||
|
|
||||||
|
|
||||||
def get_vs_path(local_doc_id: str):
|
def get_vs_path(local_doc_id: str):
|
||||||
return os.path.join(KB_ROOT_PATH, local_doc_id, "vector_store")
|
return os.path.join(get_kb_path(local_doc_id), "vector_store")
|
||||||
|
|
||||||
|
|
||||||
def get_file_path(local_doc_id: str, doc_name: str):
|
def get_file_path(local_doc_id: str, doc_name: str):
|
||||||
return os.path.join(KB_ROOT_PATH, local_doc_id, "content", doc_name)
|
return os.path.join(get_doc_path(local_doc_id), doc_name)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_kb_name(knowledge_base_id: str) -> bool:
|
||||||
|
# 检查是否包含预期外的字符或路径攻击关键字
|
||||||
|
if "../" in knowledge_base_id:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def upload_file(
|
async def upload_file(
|
||||||
file: UploadFile = File(description="A single binary file"),
|
file: UploadFile = File(description="A single binary file"),
|
||||||
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
|
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
|
||||||
):
|
):
|
||||||
saved_path = get_folder_path(knowledge_base_id)
|
if not validate_kb_name(knowledge_base_id):
|
||||||
|
return BaseResponse(code=403, msg="Don't attack me", data=[])
|
||||||
|
|
||||||
|
saved_path = get_doc_path(knowledge_base_id)
|
||||||
if not os.path.exists(saved_path):
|
if not os.path.exists(saved_path):
|
||||||
os.makedirs(saved_path)
|
os.makedirs(saved_path)
|
||||||
|
|
||||||
@ -126,21 +141,25 @@ async def upload_files(
|
|||||||
],
|
],
|
||||||
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
|
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
|
||||||
):
|
):
|
||||||
saved_path = get_folder_path(knowledge_base_id)
|
if not validate_kb_name(knowledge_base_id):
|
||||||
|
return BaseResponse(code=403, msg="Don't attack me", data=[])
|
||||||
|
|
||||||
|
saved_path = get_doc_path(knowledge_base_id)
|
||||||
if not os.path.exists(saved_path):
|
if not os.path.exists(saved_path):
|
||||||
os.makedirs(saved_path)
|
os.makedirs(saved_path)
|
||||||
filelist = []
|
filelist = []
|
||||||
for file in files:
|
for file in files:
|
||||||
file_content = ''
|
file_content = ''
|
||||||
file_path = os.path.join(saved_path, file.filename)
|
file_path = os.path.join(saved_path, file.filename)
|
||||||
file_content = file.file.read()
|
file_content = await file.read()
|
||||||
if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
|
if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
|
||||||
continue
|
continue
|
||||||
with open(file_path, "ab+") as f:
|
with open(file_path, "wb") as f:
|
||||||
f.write(file_content)
|
f.write(file_content)
|
||||||
filelist.append(file_path)
|
filelist.append(file_path)
|
||||||
if filelist:
|
if filelist:
|
||||||
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, get_vs_path(knowledge_base_id))
|
vs_path = get_vs_path(knowledge_base_id)
|
||||||
|
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path)
|
||||||
if len(loaded_files):
|
if len(loaded_files):
|
||||||
file_status = f"documents {', '.join([os.path.split(i)[-1] for i in loaded_files])} upload success"
|
file_status = f"documents {', '.join([os.path.split(i)[-1] for i in loaded_files])} upload success"
|
||||||
return BaseResponse(code=200, msg=file_status)
|
return BaseResponse(code=200, msg=file_status)
|
||||||
@ -164,16 +183,24 @@ async def list_kbs():
|
|||||||
|
|
||||||
|
|
||||||
async def list_docs(
|
async def list_docs(
|
||||||
knowledge_base_id: Optional[str] = Query(default=None, description="Knowledge Base Name", example="kb1")
|
knowledge_base_id: str = Query(..., description="Knowledge Base Name", example="kb1")
|
||||||
):
|
):
|
||||||
local_doc_folder = get_folder_path(knowledge_base_id)
|
if not validate_kb_name(knowledge_base_id):
|
||||||
|
return ListDocsResponse(code=403, msg="Don't attack me", data=[])
|
||||||
|
|
||||||
|
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
|
||||||
|
kb_path = get_kb_path(knowledge_base_id)
|
||||||
|
local_doc_folder = get_doc_path(knowledge_base_id)
|
||||||
|
if not os.path.exists(kb_path):
|
||||||
|
return ListDocsResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found", data=[])
|
||||||
if not os.path.exists(local_doc_folder):
|
if not os.path.exists(local_doc_folder):
|
||||||
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
all_doc_names = []
|
||||||
all_doc_names = [
|
else:
|
||||||
doc
|
all_doc_names = [
|
||||||
for doc in os.listdir(local_doc_folder)
|
doc
|
||||||
if os.path.isfile(os.path.join(local_doc_folder, doc))
|
for doc in os.listdir(local_doc_folder)
|
||||||
]
|
if os.path.isfile(os.path.join(local_doc_folder, doc))
|
||||||
|
]
|
||||||
return ListDocsResponse(data=all_doc_names)
|
return ListDocsResponse(data=all_doc_names)
|
||||||
|
|
||||||
|
|
||||||
@ -182,11 +209,15 @@ async def delete_kb(
|
|||||||
description="Knowledge Base Name",
|
description="Knowledge Base Name",
|
||||||
example="kb1"),
|
example="kb1"),
|
||||||
):
|
):
|
||||||
|
if not validate_kb_name(knowledge_base_id):
|
||||||
|
return BaseResponse(code=403, msg="Don't attack me")
|
||||||
|
|
||||||
# TODO: 确认是否支持批量删除知识库
|
# TODO: 确认是否支持批量删除知识库
|
||||||
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
|
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
|
||||||
if not os.path.exists(get_folder_path(knowledge_base_id)):
|
kb_path = get_kb_path(knowledge_base_id)
|
||||||
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
if not os.path.exists(kb_path):
|
||||||
shutil.rmtree(get_folder_path(knowledge_base_id))
|
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
|
||||||
|
shutil.rmtree(kb_path)
|
||||||
return BaseResponse(code=200, msg=f"Knowledge Base {knowledge_base_id} delete success")
|
return BaseResponse(code=200, msg=f"Knowledge Base {knowledge_base_id} delete success")
|
||||||
|
|
||||||
|
|
||||||
@ -195,27 +226,30 @@ async def delete_doc(
|
|||||||
description="Knowledge Base Name",
|
description="Knowledge Base Name",
|
||||||
example="kb1"),
|
example="kb1"),
|
||||||
doc_name: str = Query(
|
doc_name: str = Query(
|
||||||
None, description="doc name", example="doc_name_1.pdf"
|
..., description="doc name", example="doc_name_1.pdf"
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
|
if not validate_kb_name(knowledge_base_id):
|
||||||
|
return BaseResponse(code=403, msg="Don't attack me")
|
||||||
|
|
||||||
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
|
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
|
||||||
if not os.path.exists(get_folder_path(knowledge_base_id)):
|
if not os.path.exists(get_kb_path(knowledge_base_id)):
|
||||||
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
|
||||||
doc_path = get_file_path(knowledge_base_id, doc_name)
|
doc_path = get_file_path(knowledge_base_id, doc_name)
|
||||||
if os.path.exists(doc_path):
|
if os.path.exists(doc_path):
|
||||||
os.remove(doc_path)
|
os.remove(doc_path)
|
||||||
remain_docs = await list_docs(knowledge_base_id)
|
remain_docs = await list_docs(knowledge_base_id)
|
||||||
if len(remain_docs.data) == 0:
|
if len(remain_docs.data) == 0:
|
||||||
shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True)
|
shutil.rmtree(get_kb_path(knowledge_base_id), ignore_errors=True)
|
||||||
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
|
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
|
||||||
else:
|
else:
|
||||||
status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id))
|
status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id))
|
||||||
if "success" in status:
|
if "success" in status:
|
||||||
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
|
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
|
||||||
else:
|
else:
|
||||||
return BaseResponse(code=1, msg=f"document {doc_name} delete fail")
|
return BaseResponse(code=500, msg=f"document {doc_name} delete fail")
|
||||||
else:
|
else:
|
||||||
return BaseResponse(code=1, msg=f"document {doc_name} not found")
|
return BaseResponse(code=404, msg=f"document {doc_name} not found")
|
||||||
|
|
||||||
|
|
||||||
async def update_doc(
|
async def update_doc(
|
||||||
@ -223,23 +257,26 @@ async def update_doc(
|
|||||||
description="知识库名",
|
description="知识库名",
|
||||||
example="kb1"),
|
example="kb1"),
|
||||||
old_doc: str = Query(
|
old_doc: str = Query(
|
||||||
None, description="待删除文件名,已存储在知识库中", example="doc_name_1.pdf"
|
..., description="待删除文件名,已存储在知识库中", example="doc_name_1.pdf"
|
||||||
),
|
),
|
||||||
new_doc: UploadFile = File(description="待上传文件"),
|
new_doc: UploadFile = File(description="待上传文件"),
|
||||||
):
|
):
|
||||||
|
if not validate_kb_name(knowledge_base_id):
|
||||||
|
return BaseResponse(code=403, msg="Don't attack me")
|
||||||
|
|
||||||
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
|
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
|
||||||
if not os.path.exists(get_folder_path(knowledge_base_id)):
|
if not os.path.exists(get_kb_path(knowledge_base_id)):
|
||||||
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
|
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
|
||||||
doc_path = get_file_path(knowledge_base_id, old_doc)
|
doc_path = get_file_path(knowledge_base_id, old_doc)
|
||||||
if not os.path.exists(doc_path):
|
if not os.path.exists(doc_path):
|
||||||
return BaseResponse(code=1, msg=f"document {old_doc} not found")
|
return BaseResponse(code=404, msg=f"document {old_doc} not found")
|
||||||
else:
|
else:
|
||||||
os.remove(doc_path)
|
os.remove(doc_path)
|
||||||
delete_status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id))
|
delete_status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id))
|
||||||
if "fail" in delete_status:
|
if "fail" in delete_status:
|
||||||
return BaseResponse(code=1, msg=f"document {old_doc} delete failed")
|
return BaseResponse(code=500, msg=f"document {old_doc} delete failed")
|
||||||
else:
|
else:
|
||||||
saved_path = get_folder_path(knowledge_base_id)
|
saved_path = get_doc_path(knowledge_base_id)
|
||||||
if not os.path.exists(saved_path):
|
if not os.path.exists(saved_path):
|
||||||
os.makedirs(saved_path)
|
os.makedirs(saved_path)
|
||||||
|
|
||||||
@ -267,8 +304,8 @@ async def update_doc(
|
|||||||
async def local_doc_chat(
|
async def local_doc_chat(
|
||||||
knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
|
knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
|
||||||
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
||||||
stream: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"),
|
streaming: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"),
|
||||||
history: List[List[str]] = Body(
|
history: List[List[Optional[str]]] = Body(
|
||||||
[],
|
[],
|
||||||
description="History of previous questions and answers",
|
description="History of previous questions and answers",
|
||||||
example=[
|
example=[
|
||||||
@ -281,7 +318,7 @@ async def local_doc_chat(
|
|||||||
):
|
):
|
||||||
vs_path = get_vs_path(knowledge_base_id)
|
vs_path = get_vs_path(knowledge_base_id)
|
||||||
if not os.path.exists(vs_path):
|
if not os.path.exists(vs_path):
|
||||||
# return BaseResponse(code=1, msg=f"Knowledge base {knowledge_base_id} not found")
|
# return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
|
||||||
return ChatMessage(
|
return ChatMessage(
|
||||||
question=question,
|
question=question,
|
||||||
response=f"Knowledge base {knowledge_base_id} not found",
|
response=f"Knowledge base {knowledge_base_id} not found",
|
||||||
@ -289,7 +326,7 @@ async def local_doc_chat(
|
|||||||
source_documents=[],
|
source_documents=[],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if (stream):
|
if (streaming):
|
||||||
def generate_answer ():
|
def generate_answer ():
|
||||||
last_print_len = 0
|
last_print_len = 0
|
||||||
for resp, next_history in local_doc_qa.get_knowledge_based_answer(
|
for resp, next_history in local_doc_qa.get_knowledge_based_answer(
|
||||||
@ -300,7 +337,7 @@ async def local_doc_chat(
|
|||||||
|
|
||||||
return StreamingResponse(generate_answer())
|
return StreamingResponse(generate_answer())
|
||||||
else:
|
else:
|
||||||
for resp, next_history in local_doc_qa.get_knowledge_based_answer(
|
for resp, history in local_doc_qa.get_knowledge_based_answer(
|
||||||
query=question, vs_path=vs_path, chat_history=history, streaming=True
|
query=question, vs_path=vs_path, chat_history=history, streaming=True
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
@ -314,14 +351,14 @@ async def local_doc_chat(
|
|||||||
return ChatMessage(
|
return ChatMessage(
|
||||||
question=question,
|
question=question,
|
||||||
response=resp["result"],
|
response=resp["result"],
|
||||||
history=next_history,
|
history=history,
|
||||||
source_documents=source_documents,
|
source_documents=source_documents,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def bing_search_chat(
|
async def bing_search_chat(
|
||||||
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
||||||
history: Optional[List[List[str]]] = Body(
|
history: Optional[List[List[Optional[str]]]] = Body(
|
||||||
[],
|
[],
|
||||||
description="History of previous questions and answers",
|
description="History of previous questions and answers",
|
||||||
example=[
|
example=[
|
||||||
@ -351,8 +388,8 @@ async def bing_search_chat(
|
|||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
||||||
stream: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"),
|
streaming: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"),
|
||||||
history: List[List[str]] = Body(
|
history: List[List[Optional[str]]] = Body(
|
||||||
[],
|
[],
|
||||||
description="History of previous questions and answers",
|
description="History of previous questions and answers",
|
||||||
example=[
|
example=[
|
||||||
@ -363,19 +400,20 @@ async def chat(
|
|||||||
],
|
],
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
|
if (streaming):
|
||||||
if (stream):
|
|
||||||
def generate_answer ():
|
def generate_answer ():
|
||||||
last_print_len = 0
|
last_print_len = 0
|
||||||
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
|
answer_result_stream_result = local_doc_qa.llm_model_chain(
|
||||||
streaming=True):
|
{"prompt": question, "history": history, "streaming": True})
|
||||||
|
for answer_result in answer_result_stream_result['answer_result_stream']:
|
||||||
yield answer_result.llm_output["answer"][last_print_len:]
|
yield answer_result.llm_output["answer"][last_print_len:]
|
||||||
last_print_len = len(answer_result.llm_output["answer"])
|
last_print_len = len(answer_result.llm_output["answer"])
|
||||||
|
|
||||||
return StreamingResponse(generate_answer())
|
return StreamingResponse(generate_answer())
|
||||||
else:
|
else:
|
||||||
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
|
answer_result_stream_result = local_doc_qa.llm_model_chain(
|
||||||
streaming=True):
|
{"prompt": question, "history": history, "streaming": True})
|
||||||
|
for answer_result in answer_result_stream_result['answer_result_stream']:
|
||||||
resp = answer_result.llm_output["answer"]
|
resp = answer_result.llm_output["answer"]
|
||||||
history = answer_result.history
|
history = answer_result.history
|
||||||
pass
|
pass
|
||||||
@ -386,9 +424,22 @@ async def chat(
|
|||||||
history=history,
|
history=history,
|
||||||
source_documents=[],
|
source_documents=[],
|
||||||
)
|
)
|
||||||
|
answer_result_stream_result = local_doc_qa.llm_model_chain(
|
||||||
|
{"prompt": question, "history": history, "streaming": True})
|
||||||
|
|
||||||
|
for answer_result in answer_result_stream_result['answer_result_stream']:
|
||||||
|
resp = answer_result.llm_output["answer"]
|
||||||
|
history = answer_result.history
|
||||||
|
pass
|
||||||
|
return ChatMessage(
|
||||||
|
question=question,
|
||||||
|
response=resp,
|
||||||
|
history=history,
|
||||||
|
source_documents=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
|
async def stream_chat(websocket: WebSocket):
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
turn = 1
|
turn = 1
|
||||||
while True:
|
while True:
|
||||||
@ -408,6 +459,7 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
|
|||||||
for resp, history in local_doc_qa.get_knowledge_based_answer(
|
for resp, history in local_doc_qa.get_knowledge_based_answer(
|
||||||
query=question, vs_path=vs_path, chat_history=history, streaming=True
|
query=question, vs_path=vs_path, chat_history=history, streaming=True
|
||||||
):
|
):
|
||||||
|
await asyncio.sleep(0)
|
||||||
await websocket.send_text(resp["result"][last_print_len:])
|
await websocket.send_text(resp["result"][last_print_len:])
|
||||||
last_print_len = len(resp["result"])
|
last_print_len = len(resp["result"])
|
||||||
|
|
||||||
@ -430,17 +482,51 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
|
|||||||
)
|
)
|
||||||
turn += 1
|
turn += 1
|
||||||
|
|
||||||
|
async def stream_chat_bing(websocket: WebSocket):
|
||||||
|
"""
|
||||||
|
基于bing搜索的流式问答
|
||||||
|
"""
|
||||||
|
await websocket.accept()
|
||||||
|
turn = 1
|
||||||
|
while True:
|
||||||
|
input_json = await websocket.receive_json()
|
||||||
|
question, history = input_json["question"], input_json["history"]
|
||||||
|
|
||||||
|
await websocket.send_json({"question": question, "turn": turn, "flag": "start"})
|
||||||
|
|
||||||
|
last_print_len = 0
|
||||||
|
for resp, history in local_doc_qa.get_search_result_based_answer(question, chat_history=history, streaming=True):
|
||||||
|
await websocket.send_text(resp["result"][last_print_len:])
|
||||||
|
last_print_len = len(resp["result"])
|
||||||
|
|
||||||
|
source_documents = [
|
||||||
|
f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
|
||||||
|
f"""相关度:{doc.metadata['score']}\n\n"""
|
||||||
|
for inum, doc in enumerate(resp["source_documents"])
|
||||||
|
]
|
||||||
|
|
||||||
|
await websocket.send_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"question": question,
|
||||||
|
"turn": turn,
|
||||||
|
"flag": "end",
|
||||||
|
"sources_documents": source_documents,
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
turn += 1
|
||||||
|
|
||||||
async def document():
|
async def document():
|
||||||
return RedirectResponse(url="/docs")
|
return RedirectResponse(url="/docs")
|
||||||
|
|
||||||
|
|
||||||
def api_start(host, port):
|
def api_start(host, port, **kwargs):
|
||||||
global app
|
global app
|
||||||
global local_doc_qa
|
global local_doc_qa
|
||||||
|
|
||||||
llm_model_ins = shared.loaderLLM()
|
llm_model_ins = shared.loaderLLM()
|
||||||
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
# Add CORS middleware to allow all origins
|
# Add CORS middleware to allow all origins
|
||||||
@ -454,21 +540,28 @@ def api_start(host, port):
|
|||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
app.websocket("/local_doc_qa/stream-chat/{knowledge_base_id}")(stream_chat)
|
# 修改了stream_chat的接口,直接通过ws://localhost:7861/local_doc_qa/stream_chat建立连接,在请求体中选择knowledge_base_id
|
||||||
|
app.websocket("/local_doc_qa/stream_chat")(stream_chat)
|
||||||
|
|
||||||
app.get("/", response_model=BaseResponse)(document)
|
app.get("/", response_model=BaseResponse, summary="swagger 文档")(document)
|
||||||
|
|
||||||
app.post("/chat", response_model=ChatMessage)(chat)
|
# 增加基于bing搜索的流式问答
|
||||||
|
# 需要说明的是,如果想测试websocket的流式问答,需要使用支持websocket的测试工具,如postman,insomnia
|
||||||
|
# 强烈推荐开源的insomnia
|
||||||
|
# 在测试时选择new websocket request,并将url的协议改为ws,如ws://localhost:7861/local_doc_qa/stream_chat_bing
|
||||||
|
app.websocket("/local_doc_qa/stream_chat_bing")(stream_chat_bing)
|
||||||
|
|
||||||
app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file)
|
app.post("/chat", response_model=ChatMessage, summary="与模型对话")(chat)
|
||||||
app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
|
|
||||||
app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
|
app.post("/local_doc_qa/upload_file", response_model=BaseResponse, summary="上传文件到知识库")(upload_file)
|
||||||
app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage)(bing_search_chat)
|
app.post("/local_doc_qa/upload_files", response_model=BaseResponse, summary="批量上传文件到知识库")(upload_files)
|
||||||
app.get("/local_doc_qa/list_knowledge_base", response_model=ListDocsResponse)(list_kbs)
|
app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage, summary="与知识库对话")(local_doc_chat)
|
||||||
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
|
app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage, summary="与必应搜索对话")(bing_search_chat)
|
||||||
app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse)(delete_kb)
|
app.get("/local_doc_qa/list_knowledge_base", response_model=ListDocsResponse, summary="获取知识库列表")(list_kbs)
|
||||||
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_doc)
|
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse, summary="获取知识库内的文件列表")(list_docs)
|
||||||
app.post("/local_doc_qa/update_file", response_model=BaseResponse)(update_doc)
|
app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse, summary="删除知识库")(delete_kb)
|
||||||
|
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse, summary="删除知识库内的文件")(delete_doc)
|
||||||
|
app.post("/local_doc_qa/update_file", response_model=BaseResponse, summary="上传文件到知识库,并删除另一个文件")(update_doc)
|
||||||
|
|
||||||
local_doc_qa = LocalDocQA()
|
local_doc_qa = LocalDocQA()
|
||||||
local_doc_qa.init_cfg(
|
local_doc_qa.init_cfg(
|
||||||
@ -477,15 +570,21 @@ def api_start(host, port):
|
|||||||
embedding_device=EMBEDDING_DEVICE,
|
embedding_device=EMBEDDING_DEVICE,
|
||||||
top_k=VECTOR_SEARCH_TOP_K,
|
top_k=VECTOR_SEARCH_TOP_K,
|
||||||
)
|
)
|
||||||
uvicorn.run(app, host=host, port=port)
|
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
|
||||||
|
uvicorn.run(app, host=host, port=port, ssl_keyfile=kwargs.get("ssl_keyfile"),
|
||||||
|
ssl_certfile=kwargs.get("ssl_certfile"))
|
||||||
|
else:
|
||||||
|
uvicorn.run(app, host=host, port=port)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||||
parser.add_argument("--port", type=int, default=7861)
|
parser.add_argument("--port", type=int, default=7861)
|
||||||
|
parser.add_argument("--ssl_keyfile", type=str)
|
||||||
|
parser.add_argument("--ssl_certfile", type=str)
|
||||||
# 初始化消息
|
# 初始化消息
|
||||||
args = None
|
args = None
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args_dict = vars(args)
|
args_dict = vars(args)
|
||||||
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
||||||
api_start(args.host, args.port)
|
api_start(args.host, args.port, ssl_keyfile=args.ssl_keyfile, ssl_certfile=args.ssl_certfile)
|
||||||
|
|||||||
@ -8,7 +8,6 @@ from typing import List
|
|||||||
from utils import torch_gc
|
from utils import torch_gc
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from pypinyin import lazy_pinyin
|
from pypinyin import lazy_pinyin
|
||||||
from loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader
|
|
||||||
from models.base import (BaseAnswer,
|
from models.base import (BaseAnswer,
|
||||||
AnswerResult)
|
AnswerResult)
|
||||||
from models.loader.args import parser
|
from models.loader.args import parser
|
||||||
@ -18,6 +17,7 @@ from agent import bing_search
|
|||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from textsplitter.zh_title_enhance import zh_title_enhance
|
from textsplitter.zh_title_enhance import zh_title_enhance
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
|
||||||
|
|
||||||
# patch HuggingFaceEmbeddings to make it hashable
|
# patch HuggingFaceEmbeddings to make it hashable
|
||||||
@ -58,6 +58,7 @@ def tree(filepath, ignore_dir_names=None, ignore_file_names=None):
|
|||||||
|
|
||||||
|
|
||||||
def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE):
|
def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE):
|
||||||
|
|
||||||
if filepath.lower().endswith(".md"):
|
if filepath.lower().endswith(".md"):
|
||||||
loader = UnstructuredFileLoader(filepath, mode="elements")
|
loader = UnstructuredFileLoader(filepath, mode="elements")
|
||||||
docs = loader.load()
|
docs = loader.load()
|
||||||
@ -66,10 +67,14 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_T
|
|||||||
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
|
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
|
||||||
docs = loader.load_and_split(textsplitter)
|
docs = loader.load_and_split(textsplitter)
|
||||||
elif filepath.lower().endswith(".pdf"):
|
elif filepath.lower().endswith(".pdf"):
|
||||||
|
# 暂且将paddle相关的loader改为动态加载,可以在不上传pdf/image知识文件的前提下使用protobuf=4.x
|
||||||
|
from loader import UnstructuredPaddlePDFLoader
|
||||||
loader = UnstructuredPaddlePDFLoader(filepath)
|
loader = UnstructuredPaddlePDFLoader(filepath)
|
||||||
textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
|
textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
|
||||||
docs = loader.load_and_split(textsplitter)
|
docs = loader.load_and_split(textsplitter)
|
||||||
elif filepath.lower().endswith(".jpg") or filepath.lower().endswith(".png"):
|
elif filepath.lower().endswith(".jpg") or filepath.lower().endswith(".png"):
|
||||||
|
# 暂且将paddle相关的loader改为动态加载,可以在不上传pdf/image知识文件的前提下使用protobuf=4.x
|
||||||
|
from loader import UnstructuredPaddleImageLoader
|
||||||
loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
|
loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
|
||||||
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
|
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
|
||||||
docs = loader.load_and_split(text_splitter=textsplitter)
|
docs = loader.load_and_split(text_splitter=textsplitter)
|
||||||
@ -119,7 +124,7 @@ def search_result2docs(search_results):
|
|||||||
|
|
||||||
|
|
||||||
class LocalDocQA:
|
class LocalDocQA:
|
||||||
llm: BaseAnswer = None
|
llm_model_chain: Chain = None
|
||||||
embeddings: object = None
|
embeddings: object = None
|
||||||
top_k: int = VECTOR_SEARCH_TOP_K
|
top_k: int = VECTOR_SEARCH_TOP_K
|
||||||
chunk_size: int = CHUNK_SIZE
|
chunk_size: int = CHUNK_SIZE
|
||||||
@ -129,10 +134,10 @@ class LocalDocQA:
|
|||||||
def init_cfg(self,
|
def init_cfg(self,
|
||||||
embedding_model: str = EMBEDDING_MODEL,
|
embedding_model: str = EMBEDDING_MODEL,
|
||||||
embedding_device=EMBEDDING_DEVICE,
|
embedding_device=EMBEDDING_DEVICE,
|
||||||
llm_model: BaseAnswer = None,
|
llm_model: Chain = None,
|
||||||
top_k=VECTOR_SEARCH_TOP_K,
|
top_k=VECTOR_SEARCH_TOP_K,
|
||||||
):
|
):
|
||||||
self.llm = llm_model
|
self.llm_model_chain = llm_model
|
||||||
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
|
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
|
||||||
model_kwargs={'device': embedding_device})
|
model_kwargs={'device': embedding_device})
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
@ -200,6 +205,7 @@ class LocalDocQA:
|
|||||||
return vs_path, loaded_files
|
return vs_path, loaded_files
|
||||||
else:
|
else:
|
||||||
logger.info("文件均未成功加载,请检查依赖包或替换为其他文件再次上传。")
|
logger.info("文件均未成功加载,请检查依赖包或替换为其他文件再次上传。")
|
||||||
|
|
||||||
return None, loaded_files
|
return None, loaded_files
|
||||||
|
|
||||||
def one_knowledge_add(self, vs_path, one_title, one_conent, one_content_segmentation, sentence_size):
|
def one_knowledge_add(self, vs_path, one_title, one_conent, one_content_segmentation, sentence_size):
|
||||||
@ -235,8 +241,10 @@ class LocalDocQA:
|
|||||||
else:
|
else:
|
||||||
prompt = query
|
prompt = query
|
||||||
|
|
||||||
for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
|
answer_result_stream_result = self.llm_model_chain(
|
||||||
streaming=streaming):
|
{"prompt": prompt, "history": chat_history, "streaming": streaming})
|
||||||
|
|
||||||
|
for answer_result in answer_result_stream_result['answer_result_stream']:
|
||||||
resp = answer_result.llm_output["answer"]
|
resp = answer_result.llm_output["answer"]
|
||||||
history = answer_result.history
|
history = answer_result.history
|
||||||
history[-1][0] = query
|
history[-1][0] = query
|
||||||
@ -275,8 +283,10 @@ class LocalDocQA:
|
|||||||
result_docs = search_result2docs(results)
|
result_docs = search_result2docs(results)
|
||||||
prompt = generate_prompt(result_docs, query)
|
prompt = generate_prompt(result_docs, query)
|
||||||
|
|
||||||
for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
|
answer_result_stream_result = self.llm_model_chain(
|
||||||
streaming=streaming):
|
{"prompt": prompt, "history": chat_history, "streaming": streaming})
|
||||||
|
|
||||||
|
for answer_result in answer_result_stream_result['answer_result_stream']:
|
||||||
resp = answer_result.llm_output["answer"]
|
resp = answer_result.llm_output["answer"]
|
||||||
history = answer_result.history
|
history = answer_result.history
|
||||||
history[-1][0] = query
|
history[-1][0] = query
|
||||||
@ -295,7 +305,7 @@ class LocalDocQA:
|
|||||||
def update_file_from_vector_store(self,
|
def update_file_from_vector_store(self,
|
||||||
filepath: str or List[str],
|
filepath: str or List[str],
|
||||||
vs_path,
|
vs_path,
|
||||||
docs: List[Document],):
|
docs: List[Document], ):
|
||||||
vector_store = load_vector_store(vs_path, self.embeddings)
|
vector_store = load_vector_store(vs_path, self.embeddings)
|
||||||
status = vector_store.update_doc(filepath, docs)
|
status = vector_store.update_doc(filepath, docs)
|
||||||
return status
|
return status
|
||||||
@ -319,7 +329,6 @@ if __name__ == "__main__":
|
|||||||
args_dict = vars(args)
|
args_dict = vars(args)
|
||||||
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
||||||
llm_model_ins = shared.loaderLLM()
|
llm_model_ins = shared.loaderLLM()
|
||||||
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
|
|
||||||
|
|
||||||
local_doc_qa = LocalDocQA()
|
local_doc_qa = LocalDocQA()
|
||||||
local_doc_qa.init_cfg(llm_model=llm_model_ins)
|
local_doc_qa.init_cfg(llm_model=llm_model_ins)
|
||||||
|
|||||||
@ -1,34 +0,0 @@
|
|||||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
|
||||||
|
|
||||||
from typing import Any, List
|
|
||||||
|
|
||||||
|
|
||||||
class MyEmbeddings(HuggingFaceEmbeddings):
|
|
||||||
def __init__(self, **kwargs: Any):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
||||||
"""Compute doc embeddings using a HuggingFace transformer model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: The list of texts to embed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of embeddings, one for each text.
|
|
||||||
"""
|
|
||||||
texts = list(map(lambda x: x.replace("\n", " "), texts))
|
|
||||||
embeddings = self.client.encode(texts, normalize_embeddings=True)
|
|
||||||
return embeddings.tolist()
|
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
|
||||||
"""Compute query embeddings using a HuggingFace transformer model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: The text to embed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Embeddings for the text.
|
|
||||||
"""
|
|
||||||
text = text.replace("\n", " ")
|
|
||||||
embedding = self.client.encode(text, normalize_embeddings=True)
|
|
||||||
return embedding.tolist()
|
|
||||||
@ -1,121 +0,0 @@
|
|||||||
from langchain.vectorstores import FAISS
|
|
||||||
from typing import Any, Callable, List, Optional, Tuple, Dict
|
|
||||||
from langchain.docstore.document import Document
|
|
||||||
from langchain.docstore.base import Docstore
|
|
||||||
|
|
||||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
|
||||||
from langchain.embeddings.base import Embeddings
|
|
||||||
import uuid
|
|
||||||
from langchain.docstore.in_memory import InMemoryDocstore
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
def dependable_faiss_import() -> Any:
|
|
||||||
"""Import faiss if available, otherwise raise error."""
|
|
||||||
try:
|
|
||||||
import faiss
|
|
||||||
except ImportError:
|
|
||||||
raise ValueError(
|
|
||||||
"Could not import faiss python package. "
|
|
||||||
"Please install it with `pip install faiss` "
|
|
||||||
"or `pip install faiss-cpu` (depending on Python version)."
|
|
||||||
)
|
|
||||||
return faiss
|
|
||||||
|
|
||||||
class FAISSVS(FAISS):
|
|
||||||
def __init__(self,
|
|
||||||
embedding_function: Callable[..., Any],
|
|
||||||
index: Any,
|
|
||||||
docstore: Docstore,
|
|
||||||
index_to_docstore_id: Dict[int, str]):
|
|
||||||
super().__init__(embedding_function, index, docstore, index_to_docstore_id)
|
|
||||||
|
|
||||||
def max_marginal_relevance_search_by_vector(
|
|
||||||
self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
|
|
||||||
) -> List[Tuple[Document, float]]:
|
|
||||||
"""Return docs selected using the maximal marginal relevance.
|
|
||||||
|
|
||||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
|
||||||
among selected documents.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
embedding: Embedding to look up documents similar to.
|
|
||||||
k: Number of Documents to return. Defaults to 4.
|
|
||||||
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of Documents with scores selected by maximal marginal relevance.
|
|
||||||
"""
|
|
||||||
scores, indices = self.index.search(np.array([embedding], dtype=np.float32), fetch_k)
|
|
||||||
# -1 happens when not enough docs are returned.
|
|
||||||
embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
|
|
||||||
mmr_selected = maximal_marginal_relevance(
|
|
||||||
np.array([embedding], dtype=np.float32), embeddings, k=k
|
|
||||||
)
|
|
||||||
selected_indices = [indices[0][i] for i in mmr_selected]
|
|
||||||
selected_scores = [scores[0][i] for i in mmr_selected]
|
|
||||||
docs = []
|
|
||||||
for i, score in zip(selected_indices, selected_scores):
|
|
||||||
if i == -1:
|
|
||||||
# This happens when not enough docs are returned.
|
|
||||||
continue
|
|
||||||
_id = self.index_to_docstore_id[i]
|
|
||||||
doc = self.docstore.search(_id)
|
|
||||||
if not isinstance(doc, Document):
|
|
||||||
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
|
||||||
docs.append((doc, score))
|
|
||||||
return docs
|
|
||||||
|
|
||||||
def max_marginal_relevance_search(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
k: int = 4,
|
|
||||||
fetch_k: int = 20,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Tuple[Document, float]]:
|
|
||||||
"""Return docs selected using the maximal marginal relevance.
|
|
||||||
|
|
||||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
|
||||||
among selected documents.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Text to look up documents similar to.
|
|
||||||
k: Number of Documents to return. Defaults to 4.
|
|
||||||
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of Documents with scores selected by maximal marginal relevance.
|
|
||||||
"""
|
|
||||||
embedding = self.embedding_function(query)
|
|
||||||
docs = self.max_marginal_relevance_search_by_vector(embedding, k, fetch_k)
|
|
||||||
return docs
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __from(
|
|
||||||
cls,
|
|
||||||
texts: List[str],
|
|
||||||
embeddings: List[List[float]],
|
|
||||||
embedding: Embeddings,
|
|
||||||
metadatas: Optional[List[dict]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> FAISS:
|
|
||||||
faiss = dependable_faiss_import()
|
|
||||||
index = faiss.IndexFlatIP(len(embeddings[0]))
|
|
||||||
index.add(np.array(embeddings, dtype=np.float32))
|
|
||||||
|
|
||||||
# # my code, for speeding up search
|
|
||||||
# quantizer = faiss.IndexFlatL2(len(embeddings[0]))
|
|
||||||
# index = faiss.IndexIVFFlat(quantizer, len(embeddings[0]), 100)
|
|
||||||
# index.train(np.array(embeddings, dtype=np.float32))
|
|
||||||
# index.add(np.array(embeddings, dtype=np.float32))
|
|
||||||
|
|
||||||
documents = []
|
|
||||||
for i, text in enumerate(texts):
|
|
||||||
metadata = metadatas[i] if metadatas else {}
|
|
||||||
documents.append(Document(page_content=text, metadata=metadata))
|
|
||||||
index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
|
|
||||||
docstore = InMemoryDocstore(
|
|
||||||
{index_to_id[i]: doc for i, doc in enumerate(documents)}
|
|
||||||
)
|
|
||||||
return cls(embedding.embed_query, index, docstore, index_to_id)
|
|
||||||
|
|
||||||
6
cli.py
6
cli.py
@ -42,7 +42,9 @@ def start():
|
|||||||
@start.command(name="api", context_settings=dict(help_option_names=['-h', '--help']))
|
@start.command(name="api", context_settings=dict(help_option_names=['-h', '--help']))
|
||||||
@click.option('-i', '--ip', default='0.0.0.0', show_default=True, type=str, help='api_server listen address.')
|
@click.option('-i', '--ip', default='0.0.0.0', show_default=True, type=str, help='api_server listen address.')
|
||||||
@click.option('-p', '--port', default=7861, show_default=True, type=int, help='api_server listen port.')
|
@click.option('-p', '--port', default=7861, show_default=True, type=int, help='api_server listen port.')
|
||||||
def start_api(ip, port):
|
@click.option('-k', '--ssl_keyfile', type=int, help='enable api https/wss service, specify the ssl keyfile path.')
|
||||||
|
@click.option('-c', '--ssl_certfile', type=int, help='enable api https/wss service, specify the ssl certificate file path.')
|
||||||
|
def start_api(ip, port, **kwargs):
|
||||||
# 调用api_start之前需要先loadCheckPoint,并传入加载检查点的参数,
|
# 调用api_start之前需要先loadCheckPoint,并传入加载检查点的参数,
|
||||||
# 理论上可以用click包进行包装,但过于繁琐,改动较大,
|
# 理论上可以用click包进行包装,但过于繁琐,改动较大,
|
||||||
# 此处仍用parser包,并以models.loader.args.DEFAULT_ARGS的参数为默认参数
|
# 此处仍用parser包,并以models.loader.args.DEFAULT_ARGS的参数为默认参数
|
||||||
@ -51,7 +53,7 @@ def start_api(ip, port):
|
|||||||
from models.loader import LoaderCheckPoint
|
from models.loader import LoaderCheckPoint
|
||||||
from models.loader.args import DEFAULT_ARGS
|
from models.loader.args import DEFAULT_ARGS
|
||||||
shared.loaderCheckPoint = LoaderCheckPoint(DEFAULT_ARGS)
|
shared.loaderCheckPoint = LoaderCheckPoint(DEFAULT_ARGS)
|
||||||
api_start(host=ip, port=port)
|
api_start(host=ip, port=port, **kwargs)
|
||||||
|
|
||||||
# # 通过cli.py调用cli_demo时需要在cli.py里初始化模型,否则会报错:
|
# # 通过cli.py调用cli_demo时需要在cli.py里初始化模型,否则会报错:
|
||||||
# langchain-ChatGLM: error: unrecognized arguments: start cli
|
# langchain-ChatGLM: error: unrecognized arguments: start cli
|
||||||
|
|||||||
24
cli_demo.py
24
cli_demo.py
@ -23,11 +23,33 @@ def main():
|
|||||||
top_k=VECTOR_SEARCH_TOP_K)
|
top_k=VECTOR_SEARCH_TOP_K)
|
||||||
vs_path = None
|
vs_path = None
|
||||||
while not vs_path:
|
while not vs_path:
|
||||||
|
print("注意输入的路径是完整的文件路径,例如knowledge_base/`knowledge_base_id`/content/file.md,多个路径用英文逗号分割")
|
||||||
filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
|
filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
|
||||||
|
|
||||||
# 判断 filepath 是否为空,如果为空的话,重新让用户输入,防止用户误触回车
|
# 判断 filepath 是否为空,如果为空的话,重新让用户输入,防止用户误触回车
|
||||||
if not filepath:
|
if not filepath:
|
||||||
continue
|
continue
|
||||||
vs_path, _ = local_doc_qa.init_knowledge_vector_store(filepath)
|
|
||||||
|
# 支持加载多个文件
|
||||||
|
filepath = filepath.split(",")
|
||||||
|
# filepath错误的返回为None, 如果直接用原先的vs_path,_ = local_doc_qa.init_knowledge_vector_store(filepath)
|
||||||
|
# 会直接导致TypeError: cannot unpack non-iterable NoneType object而使得程序直接退出
|
||||||
|
# 因此需要先加一层判断,保证程序能继续运行
|
||||||
|
temp,loaded_files = local_doc_qa.init_knowledge_vector_store(filepath)
|
||||||
|
if temp is not None:
|
||||||
|
vs_path = temp
|
||||||
|
# 如果loaded_files和len(filepath)不一致,则说明部分文件没有加载成功
|
||||||
|
# 如果是路径错误,则应该支持重新加载
|
||||||
|
if len(loaded_files) != len(filepath):
|
||||||
|
reload_flag = eval(input("部分文件加载失败,若提示路径不存在,可重新加载,是否重新加载,输入True或False: "))
|
||||||
|
if reload_flag:
|
||||||
|
vs_path = None
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"the loaded vs_path is 加载的vs_path为: {vs_path}")
|
||||||
|
else:
|
||||||
|
print("load file failed, re-input your local knowledge file path 请重新输入本地知识文件路径")
|
||||||
|
|
||||||
history = []
|
history = []
|
||||||
while True:
|
while True:
|
||||||
query = input("Input your question 请输入问题:")
|
query = input("Input your question 请输入问题:")
|
||||||
|
|||||||
19
docs/FAQ.md
19
docs/FAQ.md
@ -177,3 +177,22 @@ download_with_progressbar(url, tmp_path)
|
|||||||
Q14 调用api中的 `bing_search_chat`接口时,报出 `Failed to establish a new connection: [Errno 110] Connection timed out`
|
Q14 调用api中的 `bing_search_chat`接口时,报出 `Failed to establish a new connection: [Errno 110] Connection timed out`
|
||||||
|
|
||||||
这是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG--!
|
这是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG--!
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Q15 加载chatglm-6b-int8或chatglm-6b-int4抛出 `RuntimeError: Only Tensors of floating point andcomplex dtype can require gradients`
|
||||||
|
|
||||||
|
疑为chatglm的quantization的问题或torch版本差异问题,针对已经变为Parameter的torch.zeros矩阵也执行Parameter操作,从而抛出 `RuntimeError: Only Tensors of floating point andcomplex dtype can require gradients`。解决办法是在chatglm-项目的原始文件中的quantization.py文件374行改为:
|
||||||
|
|
||||||
|
```
|
||||||
|
try:
|
||||||
|
self.weight =Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
||||||
|
如果上述方式不起作用,则在.cache/hugggingface/modules/目录下针对chatglm项目的原始文件中的quantization.py文件执行上述操作,若软链接不止一个,按照错误提示选择正确的路径。
|
||||||
|
|
||||||
|
注:虽然模型可以顺利加载但在cpu上仍存在推理失败的可能:即针对每个问题,模型一直输出gugugugu。
|
||||||
|
|
||||||
|
因此,最好不要试图用cpu加载量化模型,原因可能是目前python主流量化包的量化操作是在gpu上执行的,会天然地存在gap。
|
||||||
|
|||||||
@ -44,4 +44,12 @@ $ pip install -r requirements.txt
|
|||||||
$ python loader/image_loader.py
|
$ python loader/image_loader.py
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
注:使用 `langchain.document_loaders.UnstructuredFileLoader` 进行非结构化文件接入时,可能需要依据文档进行其他依赖包的安装,请参考 [langchain 文档](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html)。
|
注:使用 `langchain.document_loaders.UnstructuredFileLoader` 进行非结构化文件接入时,可能需要依据文档进行其他依赖包的安装,请参考 [langchain 文档](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html)。
|
||||||
|
|
||||||
|
## llama-cpp模型调用的说明
|
||||||
|
|
||||||
|
1. 首先从huggingface hub中下载对应的模型,如 [https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/) 的 [ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin),建议使用huggingface_hub库的snapshot_download下载。
|
||||||
|
2. 将下载的模型重命名。通过huggingface_hub下载的模型会被重命名为随机序列,因此需要重命名为原始文件名,如[ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin)。
|
||||||
|
3. 基于下载模型的ggml的加载时间,推测对应的llama-cpp版本,下载对应的llama-cpp-python库的wheel文件,实测[ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin)与llama-cpp-python库兼容,然后手动安装wheel文件。
|
||||||
|
4. 将下载的模型信息写入configs/model_config.py文件里 `llm_model_dict`中,注意保证参数的兼容性,一些参数组合可能会报错.
|
||||||
|
|||||||
37
docs/启动API服务.md
Normal file
37
docs/启动API服务.md
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
# 启动API服务
|
||||||
|
|
||||||
|
## 通过py文件启动
|
||||||
|
可以通过直接执行`api.py`文件启动API服务,默认以ip:0.0.0.0和port:7861启动http和ws服务。
|
||||||
|
```shell
|
||||||
|
python api.py
|
||||||
|
```
|
||||||
|
同时,启动时支持StartOption所列的模型加载参数,同时还支持IP和端口设置。
|
||||||
|
```shell
|
||||||
|
python api.py --model-name chatglm-6b-int8 --port 7862
|
||||||
|
```
|
||||||
|
|
||||||
|
## 通过cli.bat/cli.sh启动
|
||||||
|
也可以通过命令行控制文件继续启动。
|
||||||
|
```shell
|
||||||
|
cli.sh api --help
|
||||||
|
```
|
||||||
|
其他可设置参数和上述py文件启动方式相同。
|
||||||
|
|
||||||
|
|
||||||
|
# 以https、wss启动API服务
|
||||||
|
## 本地创建ssl相关证书文件
|
||||||
|
如果没有正式签发的CA证书,可以[安装mkcert](https://github.com/FiloSottile/mkcert#installation)工具, 然后用如下指令生成本地CA证书:
|
||||||
|
```shell
|
||||||
|
mkcert -install
|
||||||
|
mkcert api.example.com 47.123.123.123 localhost 127.0.0.1 ::1
|
||||||
|
```
|
||||||
|
默认回车保存在当前目录下,会有以生成指令第一个域名命名为前缀命名的两个pem文件。
|
||||||
|
|
||||||
|
附带两个文件参数启动即可。
|
||||||
|
````shell
|
||||||
|
python api --port 7862 --ssl_keyfile api.example.com+4-key.pem --ssl_certfile api.example.com+4.pem
|
||||||
|
|
||||||
|
./cli.sh api --port 7862 --ssl_keyfile api.example.com+4-key.pem --ssl_certfile api.example.com+4.pem
|
||||||
|
````
|
||||||
|
|
||||||
|
此外可以通过前置Nginx转发实现类似效果,可另行查阅相关资料。
|
||||||
BIN
img/docker_logs.png
Normal file
BIN
img/docker_logs.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 69 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 143 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 154 KiB |
BIN
img/qr_code_45.jpg
Normal file
BIN
img/qr_code_45.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 192 KiB |
@ -5,9 +5,6 @@ from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
|||||||
from paddleocr import PaddleOCR
|
from paddleocr import PaddleOCR
|
||||||
import os
|
import os
|
||||||
import nltk
|
import nltk
|
||||||
from configs.model_config import NLTK_DATA_PATH
|
|
||||||
|
|
||||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
|
||||||
|
|
||||||
class UnstructuredPaddleImageLoader(UnstructuredFileLoader):
|
class UnstructuredPaddleImageLoader(UnstructuredFileLoader):
|
||||||
"""Loader that uses unstructured to load image files, such as PNGs and JPGs."""
|
"""Loader that uses unstructured to load image files, such as PNGs and JPGs."""
|
||||||
@ -35,6 +32,10 @@ class UnstructuredPaddleImageLoader(UnstructuredFileLoader):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
import sys
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||||
|
|
||||||
|
from configs.model_config import NLTK_DATA_PATH
|
||||||
|
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||||
|
|
||||||
filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base", "samples", "content", "test.jpg")
|
filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base", "samples", "content", "test.jpg")
|
||||||
loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
|
loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
|
||||||
docs = loader.load()
|
docs = loader.load()
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from .chatglm_llm import ChatGLM
|
from .chatglm_llm import ChatGLMLLMChain
|
||||||
from .llama_llm import LLamaLLM
|
from .llama_llm import LLamaLLMChain
|
||||||
from .moss_llm import MOSSLLM
|
from .fastchat_openai_llm import FastChatOpenAILLMChain
|
||||||
from .fastchat_openai_llm import FastChatOpenAILLM
|
from .moss_llm import MOSSLLMChain
|
||||||
|
|||||||
@ -1,13 +1,15 @@
|
|||||||
from models.base.base import (
|
from models.base.base import (
|
||||||
AnswerResult,
|
AnswerResult,
|
||||||
BaseAnswer
|
BaseAnswer,
|
||||||
)
|
AnswerResultStream,
|
||||||
|
AnswerResultQueueSentinelTokenListenerQueue)
|
||||||
from models.base.remote_rpc_model import (
|
from models.base.remote_rpc_model import (
|
||||||
RemoteRpcModel
|
RemoteRpcModel
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AnswerResult",
|
"AnswerResult",
|
||||||
"BaseAnswer",
|
"BaseAnswer",
|
||||||
"RemoteRpcModel",
|
"RemoteRpcModel",
|
||||||
|
"AnswerResultStream",
|
||||||
|
"AnswerResultQueueSentinelTokenListenerQueue"
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,16 +1,30 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional, List
|
from typing import Any, Dict, List, Optional, Generator
|
||||||
import traceback
|
import traceback
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
|
from models.loader import LoaderCheckPoint
|
||||||
|
from pydantic import BaseModel
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from models.loader import LoaderCheckPoint
|
|
||||||
|
|
||||||
|
|
||||||
class AnswerResult:
|
class ListenerToken:
|
||||||
|
"""
|
||||||
|
观测结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
input_ids: torch.LongTensor
|
||||||
|
_scores: torch.FloatTensor
|
||||||
|
|
||||||
|
def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
|
||||||
|
self.input_ids = input_ids
|
||||||
|
self._scores = _scores
|
||||||
|
|
||||||
|
|
||||||
|
class AnswerResult(BaseModel):
|
||||||
"""
|
"""
|
||||||
消息实体
|
消息实体
|
||||||
"""
|
"""
|
||||||
@ -18,6 +32,122 @@ class AnswerResult:
|
|||||||
llm_output: Optional[dict] = None
|
llm_output: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AnswerResultStream:
|
||||||
|
def __init__(self, callback_func=None):
|
||||||
|
self.callback_func = callback_func
|
||||||
|
|
||||||
|
def __call__(self, answerResult: AnswerResult):
|
||||||
|
if self.callback_func is not None:
|
||||||
|
self.callback_func(answerResult)
|
||||||
|
|
||||||
|
|
||||||
|
class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
|
||||||
|
"""
|
||||||
|
定义模型stopping_criteria 监听者,在每次响应时将队列数据同步到AnswerResult
|
||||||
|
实现此监听器的目的是,不同模型的预测输出可能不是矢量信息,hf框架可以自定义transformers.StoppingCriteria入参来接收每次预测的Tensor和损失函数,
|
||||||
|
通过给 StoppingCriteriaList指定模型生成答案时停止的条件。每个 StoppingCriteria 对象表示一个停止条件
|
||||||
|
当每轮预测任务开始时,StoppingCriteria都会收到相同的预测结果,最终由下层实现类确认是否结束
|
||||||
|
输出值可用于 generatorAnswer generate_with_streaming的自定义参数观测,以实现更加精细的控制
|
||||||
|
"""
|
||||||
|
|
||||||
|
listenerQueue: deque = deque(maxlen=1)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
transformers.StoppingCriteria.__init__(self)
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
|
||||||
|
"""
|
||||||
|
每次响应时将数据添加到响应队列
|
||||||
|
:param input_ids:
|
||||||
|
:param _scores:
|
||||||
|
:param kwargs:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores))
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class Iteratorize:
|
||||||
|
"""
|
||||||
|
Transforms a function that takes a callback
|
||||||
|
into a lazy iterator (generator).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, func, kwargs={}):
|
||||||
|
self.mfunc = func
|
||||||
|
self.q = Queue()
|
||||||
|
self.sentinel = object()
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.stop_now = False
|
||||||
|
|
||||||
|
def _callback(val):
|
||||||
|
"""
|
||||||
|
模型输出预测结果收集
|
||||||
|
通过定义generate_with_callback收集器AnswerResultStream,收集模型预测的AnswerResult响应结果,最终由下层实现类确认是否结束
|
||||||
|
结束条件包含如下
|
||||||
|
1、模型预测结束、收集器self.q队列收到 self.sentinel标识
|
||||||
|
2、在处理迭代器队列消息时返回了break跳出迭代器,触发了StopIteration事件
|
||||||
|
3、模型预测出错
|
||||||
|
因为当前类是迭代器,所以在for in 中执行了break后 __exit__ 方法会被调用,最终stop_now属性会被更新,然后抛出异常结束预测行为
|
||||||
|
迭代器收集的行为如下
|
||||||
|
创建Iteratorize迭代对象,
|
||||||
|
定义generate_with_callback收集器AnswerResultStream
|
||||||
|
启动一个线程异步预测结果来调用上游checkpoint的实现方法_generate_answer
|
||||||
|
_generate_answer通过generate_with_callback定义的收集器,收集上游checkpoint包装的AnswerResult消息体
|
||||||
|
由于self.q是阻塞模式,每次预测后会被消费后才会执行下次预测
|
||||||
|
这时generate_with_callback会被阻塞
|
||||||
|
主线程Iteratorize对象的__next__方法调用获取阻塞消息并消费
|
||||||
|
1、消息为上游checkpoint包装的AnswerResult消息体,返回下游处理
|
||||||
|
2、消息为self.sentinel标识,抛出StopIteration异常
|
||||||
|
主线程Iteratorize对象__exit__收到消息,最终stop_now属性会被更新
|
||||||
|
异步线程检测stop_now属性被更新,抛出异常结束预测行为
|
||||||
|
迭代行为结束
|
||||||
|
:param val:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if self.stop_now:
|
||||||
|
raise ValueError
|
||||||
|
self.q.put(val)
|
||||||
|
|
||||||
|
def gen():
|
||||||
|
try:
|
||||||
|
ret = self.mfunc(callback=_callback, **self.kwargs)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
except:
|
||||||
|
traceback.print_exc()
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.q.put(self.sentinel)
|
||||||
|
|
||||||
|
self.thread = Thread(target=gen)
|
||||||
|
self.thread.start()
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
obj = self.q.get(True, None)
|
||||||
|
if obj is self.sentinel:
|
||||||
|
raise StopIteration
|
||||||
|
else:
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""
|
||||||
|
暂无实现
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
""" break 后会执行 """
|
||||||
|
self.stop_now = True
|
||||||
|
|
||||||
|
|
||||||
class BaseAnswer(ABC):
|
class BaseAnswer(ABC):
|
||||||
"""上层业务包装器.用于结果生成统一api调用"""
|
"""上层业务包装器.用于结果生成统一api调用"""
|
||||||
|
|
||||||
@ -25,17 +155,23 @@ class BaseAnswer(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _check_point(self) -> LoaderCheckPoint:
|
def _check_point(self) -> LoaderCheckPoint:
|
||||||
"""Return _check_point of llm."""
|
"""Return _check_point of llm."""
|
||||||
|
def generatorAnswer(self,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
run_manager: Optional[CallbackManagerForChainRun] = None,) -> Generator[Any, str, bool]:
|
||||||
|
def generate_with_callback(callback=None, **kwargs):
|
||||||
|
kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback)
|
||||||
|
self._generate_answer(**kwargs)
|
||||||
|
|
||||||
@property
|
def generate_with_streaming(**kwargs):
|
||||||
@abstractmethod
|
return Iteratorize(generate_with_callback, kwargs)
|
||||||
def _history_len(self) -> int:
|
|
||||||
"""Return _history_len of llm."""
|
with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator:
|
||||||
|
for answerResult in generator:
|
||||||
|
yield answerResult
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_history_len(self, history_len: int) -> None:
|
def _generate_answer(self,
|
||||||
"""Return _history_len of llm."""
|
inputs: Dict[str, Any],
|
||||||
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
def generatorAnswer(self, prompt: str,
|
generate_with_callback: AnswerResultStream = None) -> None:
|
||||||
history: List[List[str]] = [],
|
|
||||||
streaming: bool = False):
|
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -1,83 +1,117 @@
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
from langchain.llms.base import LLM
|
from langchain.chains.base import Chain
|
||||||
from typing import Optional, List
|
from typing import Any, Dict, List, Optional, Generator
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
|
# from transformers.generation.logits_process import LogitsProcessor
|
||||||
|
# from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
|
||||||
from models.loader import LoaderCheckPoint
|
from models.loader import LoaderCheckPoint
|
||||||
from models.base import (BaseAnswer,
|
from models.base import (BaseAnswer,
|
||||||
AnswerResult)
|
AnswerResult,
|
||||||
|
AnswerResultStream,
|
||||||
|
AnswerResultQueueSentinelTokenListenerQueue)
|
||||||
|
# import torch
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
|
||||||
class ChatGLM(BaseAnswer, LLM, ABC):
|
class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
|
||||||
max_token: int = 10000
|
max_token: int = 10000
|
||||||
temperature: float = 0.01
|
temperature: float = 0.01
|
||||||
top_p = 0.9
|
# 相关度
|
||||||
|
top_p = 0.4
|
||||||
|
# 候选词数量
|
||||||
|
top_k = 10
|
||||||
checkPoint: LoaderCheckPoint = None
|
checkPoint: LoaderCheckPoint = None
|
||||||
# history = []
|
# history = []
|
||||||
history_len: int = 10
|
history_len: int = 10
|
||||||
|
streaming_key: str = "streaming" #: :meta private:
|
||||||
|
history_key: str = "history" #: :meta private:
|
||||||
|
prompt_key: str = "prompt" #: :meta private:
|
||||||
|
output_key: str = "answer_result_stream" #: :meta private:
|
||||||
|
|
||||||
def __init__(self, checkPoint: LoaderCheckPoint = None):
|
def __init__(self, checkPoint: LoaderCheckPoint = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.checkPoint = checkPoint
|
self.checkPoint = checkPoint
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _chain_type(self) -> str:
|
||||||
return "ChatGLM"
|
return "ChatGLMLLMChain"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _check_point(self) -> LoaderCheckPoint:
|
def _check_point(self) -> LoaderCheckPoint:
|
||||||
return self.checkPoint
|
return self.checkPoint
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _history_len(self) -> int:
|
def input_keys(self) -> List[str]:
|
||||||
return self.history_len
|
"""Will be whatever keys the prompt expects.
|
||||||
|
|
||||||
def set_history_len(self, history_len: int = 10) -> None:
|
:meta private:
|
||||||
self.history_len = history_len
|
"""
|
||||||
|
return [self.prompt_key]
|
||||||
|
|
||||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
@property
|
||||||
|
def output_keys(self) -> List[str]:
|
||||||
|
"""Will always return text key.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return [self.output_key]
|
||||||
|
|
||||||
|
def _call(
|
||||||
|
self,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
|
) -> Dict[str, Generator]:
|
||||||
|
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
|
||||||
|
return {self.output_key: generator}
|
||||||
|
|
||||||
|
def _generate_answer(self,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
|
generate_with_callback: AnswerResultStream = None) -> None:
|
||||||
|
history = inputs[self.history_key]
|
||||||
|
streaming = inputs[self.streaming_key]
|
||||||
|
prompt = inputs[self.prompt_key]
|
||||||
print(f"__call:{prompt}")
|
print(f"__call:{prompt}")
|
||||||
response, _ = self.checkPoint.model.chat(
|
# Create the StoppingCriteriaList with the stopping strings
|
||||||
self.checkPoint.tokenizer,
|
stopping_criteria_list = transformers.StoppingCriteriaList()
|
||||||
prompt,
|
# 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
|
||||||
history=[],
|
listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
|
||||||
max_length=self.max_token,
|
stopping_criteria_list.append(listenerQueue)
|
||||||
temperature=self.temperature
|
|
||||||
)
|
|
||||||
print(f"response:{response}")
|
|
||||||
print(f"+++++++++++++++++++++++++++++++++++")
|
|
||||||
return response
|
|
||||||
|
|
||||||
def generatorAnswer(self, prompt: str,
|
|
||||||
history: List[List[str]] = [],
|
|
||||||
streaming: bool = False):
|
|
||||||
|
|
||||||
if streaming:
|
if streaming:
|
||||||
history += [[]]
|
history += [[]]
|
||||||
for inum, (stream_resp, _) in enumerate(self.checkPoint.model.stream_chat(
|
for inum, (stream_resp, _) in enumerate(self.checkPoint.model.stream_chat(
|
||||||
self.checkPoint.tokenizer,
|
self.checkPoint.tokenizer,
|
||||||
prompt,
|
prompt,
|
||||||
history=history[-self.history_len:-1] if self.history_len > 1 else [],
|
history=history[-self.history_len:-1] if self.history_len > 0 else [],
|
||||||
max_length=self.max_token,
|
max_length=self.max_token,
|
||||||
temperature=self.temperature
|
temperature=self.temperature,
|
||||||
|
top_p=self.top_p,
|
||||||
|
top_k=self.top_k,
|
||||||
|
stopping_criteria=stopping_criteria_list
|
||||||
)):
|
)):
|
||||||
# self.checkPoint.clear_torch_cache()
|
# self.checkPoint.clear_torch_cache()
|
||||||
history[-1] = [prompt, stream_resp]
|
history[-1] = [prompt, stream_resp]
|
||||||
answer_result = AnswerResult()
|
answer_result = AnswerResult()
|
||||||
answer_result.history = history
|
answer_result.history = history
|
||||||
answer_result.llm_output = {"answer": stream_resp}
|
answer_result.llm_output = {"answer": stream_resp}
|
||||||
yield answer_result
|
generate_with_callback(answer_result)
|
||||||
|
self.checkPoint.clear_torch_cache()
|
||||||
else:
|
else:
|
||||||
response, _ = self.checkPoint.model.chat(
|
response, _ = self.checkPoint.model.chat(
|
||||||
self.checkPoint.tokenizer,
|
self.checkPoint.tokenizer,
|
||||||
prompt,
|
prompt,
|
||||||
history=history[-self.history_len:] if self.history_len > 0 else [],
|
history=history[-self.history_len:] if self.history_len > 0 else [],
|
||||||
max_length=self.max_token,
|
max_length=self.max_token,
|
||||||
temperature=self.temperature
|
temperature=self.temperature,
|
||||||
|
top_p=self.top_p,
|
||||||
|
top_k=self.top_k,
|
||||||
|
stopping_criteria=stopping_criteria_list
|
||||||
)
|
)
|
||||||
self.checkPoint.clear_torch_cache()
|
self.checkPoint.clear_torch_cache()
|
||||||
history += [[prompt, response]]
|
history += [[prompt, response]]
|
||||||
answer_result = AnswerResult()
|
answer_result = AnswerResult()
|
||||||
answer_result.history = history
|
answer_result.history = history
|
||||||
answer_result.llm_output = {"answer": response}
|
answer_result.llm_output = {"answer": response}
|
||||||
yield answer_result
|
|
||||||
|
|
||||||
|
generate_with_callback(answer_result)
|
||||||
|
|
||||||
|
|||||||
@ -1,15 +1,37 @@
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
import requests
|
from langchain.chains.base import Chain
|
||||||
from typing import Optional, List
|
from typing import (
|
||||||
from langchain.llms.base import LLM
|
Any, Dict, List, Optional, Generator, Collection, Set,
|
||||||
|
Callable,
|
||||||
|
Tuple,
|
||||||
|
Union)
|
||||||
|
|
||||||
from models.loader import LoaderCheckPoint
|
from models.loader import LoaderCheckPoint
|
||||||
from models.base import (RemoteRpcModel,
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
AnswerResult)
|
from models.base import (BaseAnswer,
|
||||||
from typing import (
|
RemoteRpcModel,
|
||||||
Collection,
|
AnswerResult,
|
||||||
Dict
|
AnswerResultStream,
|
||||||
|
AnswerResultQueueSentinelTokenListenerQueue)
|
||||||
|
from tenacity import (
|
||||||
|
before_sleep_log,
|
||||||
|
retry,
|
||||||
|
retry_if_exception_type,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
)
|
)
|
||||||
|
from pydantic import Extra, Field, root_validator
|
||||||
|
|
||||||
|
from openai import (
|
||||||
|
ChatCompletion
|
||||||
|
)
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _build_message_template() -> Dict[str, str]:
|
def _build_message_template() -> Dict[str, str]:
|
||||||
@ -22,34 +44,88 @@ def _build_message_template() -> Dict[str, str]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
|
# 将历史对话数组转换为文本格式
|
||||||
|
def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
|
||||||
|
build_messages: Collection[Dict[str, str]] = []
|
||||||
|
|
||||||
|
system_build_message = _build_message_template()
|
||||||
|
system_build_message['role'] = 'system'
|
||||||
|
system_build_message['content'] = "You are a helpful assistant."
|
||||||
|
build_messages.append(system_build_message)
|
||||||
|
if history:
|
||||||
|
for i, (user, assistant) in enumerate(history):
|
||||||
|
if user:
|
||||||
|
|
||||||
|
user_build_message = _build_message_template()
|
||||||
|
user_build_message['role'] = 'user'
|
||||||
|
user_build_message['content'] = user
|
||||||
|
build_messages.append(user_build_message)
|
||||||
|
|
||||||
|
if not assistant:
|
||||||
|
raise RuntimeError("历史数据结构不正确")
|
||||||
|
system_build_message = _build_message_template()
|
||||||
|
system_build_message['role'] = 'assistant'
|
||||||
|
system_build_message['content'] = assistant
|
||||||
|
build_messages.append(system_build_message)
|
||||||
|
|
||||||
|
user_build_message = _build_message_template()
|
||||||
|
user_build_message['role'] = 'user'
|
||||||
|
user_build_message['content'] = query
|
||||||
|
build_messages.append(user_build_message)
|
||||||
|
return build_messages
|
||||||
|
|
||||||
|
|
||||||
|
class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
|
||||||
|
client: Any
|
||||||
|
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
|
||||||
|
max_retries: int = 6
|
||||||
api_base_url: str = "http://localhost:8000/v1"
|
api_base_url: str = "http://localhost:8000/v1"
|
||||||
model_name: str = "chatglm-6b"
|
model_name: str = "chatglm-6b"
|
||||||
max_token: int = 10000
|
max_token: int = 10000
|
||||||
temperature: float = 0.01
|
temperature: float = 0.01
|
||||||
top_p = 0.9
|
top_p = 0.9
|
||||||
checkPoint: LoaderCheckPoint = None
|
checkPoint: LoaderCheckPoint = None
|
||||||
history = []
|
# history = []
|
||||||
history_len: int = 10
|
history_len: int = 10
|
||||||
|
api_key: str = ""
|
||||||
|
|
||||||
def __init__(self, checkPoint: LoaderCheckPoint = None):
|
streaming_key: str = "streaming" #: :meta private:
|
||||||
|
history_key: str = "history" #: :meta private:
|
||||||
|
prompt_key: str = "prompt" #: :meta private:
|
||||||
|
output_key: str = "answer_result_stream" #: :meta private:
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
checkPoint: LoaderCheckPoint = None,
|
||||||
|
# api_base_url:str="http://localhost:8000/v1",
|
||||||
|
# model_name:str="chatglm-6b",
|
||||||
|
# api_key:str=""
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.checkPoint = checkPoint
|
self.checkPoint = checkPoint
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _chain_type(self) -> str:
|
||||||
return "FastChat"
|
return "LLamaLLMChain"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _check_point(self) -> LoaderCheckPoint:
|
def _check_point(self) -> LoaderCheckPoint:
|
||||||
return self.checkPoint
|
return self.checkPoint
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _history_len(self) -> int:
|
def input_keys(self) -> List[str]:
|
||||||
return self.history_len
|
"""Will be whatever keys the prompt expects.
|
||||||
|
|
||||||
def set_history_len(self, history_len: int = 10) -> None:
|
:meta private:
|
||||||
self.history_len = history_len
|
"""
|
||||||
|
return [self.prompt_key]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_keys(self) -> List[str]:
|
||||||
|
"""Will always return text key.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return [self.output_key]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _api_key(self) -> str:
|
def _api_key(self) -> str:
|
||||||
@ -60,7 +136,7 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
|
|||||||
return self.api_base_url
|
return self.api_base_url
|
||||||
|
|
||||||
def set_api_key(self, api_key: str):
|
def set_api_key(self, api_key: str):
|
||||||
pass
|
self.api_key = api_key
|
||||||
|
|
||||||
def set_api_base_url(self, api_base_url: str):
|
def set_api_base_url(self, api_base_url: str):
|
||||||
self.api_base_url = api_base_url
|
self.api_base_url = api_base_url
|
||||||
@ -68,70 +144,116 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
|
|||||||
def call_model_name(self, model_name):
|
def call_model_name(self, model_name):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
def _create_retry_decorator(self) -> Callable[[Any], Any]:
|
||||||
|
min_seconds = 1
|
||||||
|
max_seconds = 60
|
||||||
|
# Wait 2^x * 1 second between each retry starting with
|
||||||
|
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||||
|
return retry(
|
||||||
|
reraise=True,
|
||||||
|
stop=stop_after_attempt(self.max_retries),
|
||||||
|
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
|
||||||
|
retry=(
|
||||||
|
retry_if_exception_type(openai.error.Timeout)
|
||||||
|
| retry_if_exception_type(openai.error.APIError)
|
||||||
|
| retry_if_exception_type(openai.error.APIConnectionError)
|
||||||
|
| retry_if_exception_type(openai.error.RateLimitError)
|
||||||
|
| retry_if_exception_type(openai.error.ServiceUnavailableError)
|
||||||
|
),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||||
|
)
|
||||||
|
|
||||||
|
def completion_with_retry(self, **kwargs: Any) -> Any:
|
||||||
|
"""Use tenacity to retry the completion call."""
|
||||||
|
retry_decorator = self._create_retry_decorator()
|
||||||
|
|
||||||
|
@retry_decorator
|
||||||
|
def _completion_with_retry(**kwargs: Any) -> Any:
|
||||||
|
return self.client.create(**kwargs)
|
||||||
|
|
||||||
|
return _completion_with_retry(**kwargs)
|
||||||
|
|
||||||
|
def _call(
|
||||||
|
self,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
|
) -> Dict[str, Generator]:
|
||||||
|
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
|
||||||
|
return {self.output_key: generator}
|
||||||
|
|
||||||
|
def _generate_answer(self,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
|
generate_with_callback: AnswerResultStream = None) -> None:
|
||||||
|
|
||||||
|
history = inputs.get(self.history_key, [])
|
||||||
|
streaming = inputs.get(self.streaming_key, False)
|
||||||
|
prompt = inputs[self.prompt_key]
|
||||||
|
stop = inputs.get("stop", "stop")
|
||||||
print(f"__call:{prompt}")
|
print(f"__call:{prompt}")
|
||||||
try:
|
try:
|
||||||
import openai
|
|
||||||
# Not support yet
|
# Not support yet
|
||||||
openai.api_key = "EMPTY"
|
# openai.api_key = "EMPTY"
|
||||||
|
openai.api_key = self.api_key
|
||||||
openai.api_base = self.api_base_url
|
openai.api_base = self.api_base_url
|
||||||
except ImportError:
|
self.client = openai.ChatCompletion
|
||||||
|
except AttributeError:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Could not import openai python package. "
|
"`openai` has no `ChatCompletion` attribute, this is likely "
|
||||||
"Please install it with `pip install openai`."
|
"due to an old version of the openai package. Try upgrading it "
|
||||||
|
"with `pip install --upgrade openai`."
|
||||||
)
|
)
|
||||||
# create a chat completion
|
msg = build_message_list(prompt, history=history)
|
||||||
completion = openai.ChatCompletion.create(
|
|
||||||
model=self.model_name,
|
|
||||||
messages=self.build_message_list(prompt)
|
|
||||||
)
|
|
||||||
print(f"response:{completion.choices[0].message.content}")
|
|
||||||
print(f"+++++++++++++++++++++++++++++++++++")
|
|
||||||
return completion.choices[0].message.content
|
|
||||||
|
|
||||||
# 将历史对话数组转换为文本格式
|
if streaming:
|
||||||
def build_message_list(self, query) -> Collection[Dict[str, str]]:
|
params = {"stream": streaming,
|
||||||
build_message_list: Collection[Dict[str, str]] = []
|
"model": self.model_name,
|
||||||
history = self.history[-self.history_len:] if self.history_len > 0 else []
|
"stop": stop}
|
||||||
for i, (old_query, response) in enumerate(history):
|
out_str = ""
|
||||||
user_build_message = _build_message_template()
|
for stream_resp in self.completion_with_retry(
|
||||||
user_build_message['role'] = 'user'
|
messages=msg,
|
||||||
user_build_message['content'] = old_query
|
**params
|
||||||
system_build_message = _build_message_template()
|
):
|
||||||
system_build_message['role'] = 'system'
|
role = stream_resp["choices"][0]["delta"].get("role", "")
|
||||||
system_build_message['content'] = response
|
token = stream_resp["choices"][0]["delta"].get("content", "")
|
||||||
build_message_list.append(user_build_message)
|
out_str += token
|
||||||
build_message_list.append(system_build_message)
|
history[-1] = [prompt, out_str]
|
||||||
|
answer_result = AnswerResult()
|
||||||
|
answer_result.history = history
|
||||||
|
answer_result.llm_output = {"answer": out_str}
|
||||||
|
generate_with_callback(answer_result)
|
||||||
|
else:
|
||||||
|
|
||||||
user_build_message = _build_message_template()
|
params = {"stream": streaming,
|
||||||
user_build_message['role'] = 'user'
|
"model": self.model_name,
|
||||||
user_build_message['content'] = query
|
"stop": stop}
|
||||||
build_message_list.append(user_build_message)
|
response = self.completion_with_retry(
|
||||||
return build_message_list
|
messages=msg,
|
||||||
|
**params
|
||||||
def generatorAnswer(self, prompt: str,
|
|
||||||
history: List[List[str]] = [],
|
|
||||||
streaming: bool = False):
|
|
||||||
|
|
||||||
try:
|
|
||||||
import openai
|
|
||||||
# Not support yet
|
|
||||||
openai.api_key = "EMPTY"
|
|
||||||
openai.api_base = self.api_base_url
|
|
||||||
except ImportError:
|
|
||||||
raise ValueError(
|
|
||||||
"Could not import openai python package. "
|
|
||||||
"Please install it with `pip install openai`."
|
|
||||||
)
|
)
|
||||||
# create a chat completion
|
role = response["choices"][0]["message"].get("role", "")
|
||||||
completion = openai.ChatCompletion.create(
|
content = response["choices"][0]["message"].get("content", "")
|
||||||
model=self.model_name,
|
history += [[prompt, content]]
|
||||||
messages=self.build_message_list(prompt)
|
answer_result = AnswerResult()
|
||||||
)
|
answer_result.history = history
|
||||||
|
answer_result.llm_output = {"answer": content}
|
||||||
|
generate_with_callback(answer_result)
|
||||||
|
|
||||||
history += [[prompt, completion.choices[0].message.content]]
|
|
||||||
answer_result = AnswerResult()
|
|
||||||
answer_result.history = history
|
|
||||||
answer_result.llm_output = {"answer": completion.choices[0].message.content}
|
|
||||||
|
|
||||||
yield answer_result
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
chain = FastChatOpenAILLMChain()
|
||||||
|
|
||||||
|
chain.set_api_key("EMPTY")
|
||||||
|
# chain.set_api_base_url("https://api.openai.com/v1")
|
||||||
|
# chain.call_model_name("gpt-3.5-turbo")
|
||||||
|
|
||||||
|
answer_result_stream_result = chain({"streaming": True,
|
||||||
|
"prompt": "你好",
|
||||||
|
"history": []
|
||||||
|
})
|
||||||
|
|
||||||
|
for answer_result in answer_result_stream_result['answer_result_stream']:
|
||||||
|
resp = answer_result.llm_output["answer"]
|
||||||
|
print(resp)
|
||||||
|
|||||||
@ -1,26 +1,32 @@
|
|||||||
from abc import ABC
|
|
||||||
|
|
||||||
from langchain.llms.base import LLM
|
from abc import ABC
|
||||||
import random
|
from langchain.chains.base import Chain
|
||||||
import torch
|
from typing import Any, Dict, List, Optional, Generator, Union
|
||||||
import transformers
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
from transformers.generation.logits_process import LogitsProcessor
|
from transformers.generation.logits_process import LogitsProcessor
|
||||||
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
|
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
|
||||||
from typing import Optional, List, Dict, Any
|
|
||||||
from models.loader import LoaderCheckPoint
|
from models.loader import LoaderCheckPoint
|
||||||
from models.base import (BaseAnswer,
|
from models.base import (BaseAnswer,
|
||||||
AnswerResult)
|
AnswerResult,
|
||||||
|
AnswerResultStream,
|
||||||
|
AnswerResultQueueSentinelTokenListenerQueue)
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
|
||||||
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: Union[torch.LongTensor, list],
|
||||||
|
scores: Union[torch.FloatTensor, list]) -> torch.FloatTensor:
|
||||||
|
# llama-cpp模型返回的是list,为兼容性考虑,需要判断input_ids和scores的类型,将list转换为torch.Tensor
|
||||||
|
input_ids = torch.tensor(input_ids) if isinstance(input_ids, list) else input_ids
|
||||||
|
scores = torch.tensor(scores) if isinstance(scores, list) else scores
|
||||||
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
||||||
scores.zero_()
|
scores.zero_()
|
||||||
scores[..., 5] = 5e4
|
scores[..., 5] = 5e4
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
class LLamaLLM(BaseAnswer, LLM, ABC):
|
class LLamaLLMChain(BaseAnswer, Chain, ABC):
|
||||||
checkPoint: LoaderCheckPoint = None
|
checkPoint: LoaderCheckPoint = None
|
||||||
# history = []
|
# history = []
|
||||||
history_len: int = 3
|
history_len: int = 3
|
||||||
@ -34,32 +40,34 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
|
|||||||
min_length: int = 0
|
min_length: int = 0
|
||||||
logits_processor: LogitsProcessorList = None
|
logits_processor: LogitsProcessorList = None
|
||||||
stopping_criteria: Optional[StoppingCriteriaList] = None
|
stopping_criteria: Optional[StoppingCriteriaList] = None
|
||||||
eos_token_id: Optional[int] = [2]
|
streaming_key: str = "streaming" #: :meta private:
|
||||||
|
history_key: str = "history" #: :meta private:
|
||||||
state: object = {'max_new_tokens': 50,
|
prompt_key: str = "prompt" #: :meta private:
|
||||||
'seed': 1,
|
output_key: str = "answer_result_stream" #: :meta private:
|
||||||
'temperature': 0, 'top_p': 0.1,
|
|
||||||
'top_k': 40, 'typical_p': 1,
|
|
||||||
'repetition_penalty': 1.2,
|
|
||||||
'encoder_repetition_penalty': 1,
|
|
||||||
'no_repeat_ngram_size': 0,
|
|
||||||
'min_length': 0,
|
|
||||||
'penalty_alpha': 0,
|
|
||||||
'num_beams': 1,
|
|
||||||
'length_penalty': 1,
|
|
||||||
'early_stopping': False, 'add_bos_token': True, 'ban_eos_token': False,
|
|
||||||
'truncation_length': 2048, 'custom_stopping_strings': '',
|
|
||||||
'cpu_memory': 0, 'auto_devices': False, 'disk': False, 'cpu': False, 'bf16': False,
|
|
||||||
'load_in_8bit': False, 'wbits': 'None', 'groupsize': 'None', 'model_type': 'None',
|
|
||||||
'pre_layer': 0, 'gpu_memory_0': 0}
|
|
||||||
|
|
||||||
def __init__(self, checkPoint: LoaderCheckPoint = None):
|
def __init__(self, checkPoint: LoaderCheckPoint = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.checkPoint = checkPoint
|
self.checkPoint = checkPoint
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _chain_type(self) -> str:
|
||||||
return "LLamaLLM"
|
return "LLamaLLMChain"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Will be whatever keys the prompt expects.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return [self.prompt_key]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_keys(self) -> List[str]:
|
||||||
|
"""Will always return text key.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return [self.output_key]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _check_point(self) -> LoaderCheckPoint:
|
def _check_point(self) -> LoaderCheckPoint:
|
||||||
@ -104,35 +112,31 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
|
|||||||
formatted_history += "### Human:{}\n### Assistant:".format(query)
|
formatted_history += "### Human:{}\n### Assistant:".format(query)
|
||||||
return formatted_history
|
return formatted_history
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self,
|
def _call(
|
||||||
input_ids: torch.LongTensor):
|
self,
|
||||||
"""
|
inputs: Dict[str, Any],
|
||||||
预生成注意力掩码和 输入序列中每个位置的索引的张量
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
# TODO 没有思路
|
) -> Dict[str, Generator]:
|
||||||
:return:
|
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
|
||||||
"""
|
return {self.output_key: generator}
|
||||||
|
|
||||||
mask_positions = torch.zeros((1, input_ids.shape[1]), dtype=input_ids.dtype).to(self.checkPoint.model.device)
|
def _generate_answer(self,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
|
generate_with_callback: AnswerResultStream = None) -> None:
|
||||||
|
|
||||||
attention_mask = self.get_masks(input_ids, input_ids.device)
|
history = inputs[self.history_key]
|
||||||
|
streaming = inputs[self.streaming_key]
|
||||||
position_ids = self.get_position_ids(
|
prompt = inputs[self.prompt_key]
|
||||||
input_ids,
|
|
||||||
device=input_ids.device,
|
|
||||||
mask_positions=mask_positions
|
|
||||||
)
|
|
||||||
|
|
||||||
return input_ids, position_ids, attention_mask
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _history_len(self) -> int:
|
|
||||||
return self.history_len
|
|
||||||
|
|
||||||
def set_history_len(self, history_len: int = 10) -> None:
|
|
||||||
self.history_len = history_len
|
|
||||||
|
|
||||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
|
||||||
print(f"__call:{prompt}")
|
print(f"__call:{prompt}")
|
||||||
|
|
||||||
|
# Create the StoppingCriteriaList with the stopping strings
|
||||||
|
self.stopping_criteria = transformers.StoppingCriteriaList()
|
||||||
|
# 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
|
||||||
|
listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
|
||||||
|
self.stopping_criteria.append(listenerQueue)
|
||||||
|
# TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现
|
||||||
|
soft_prompt = self.history_to_text(query=prompt, history=history)
|
||||||
if self.logits_processor is None:
|
if self.logits_processor is None:
|
||||||
self.logits_processor = LogitsProcessorList()
|
self.logits_processor = LogitsProcessorList()
|
||||||
self.logits_processor.append(InvalidScoreLogitsProcessor())
|
self.logits_processor.append(InvalidScoreLogitsProcessor())
|
||||||
@ -151,35 +155,36 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
|
|||||||
"logits_processor": self.logits_processor}
|
"logits_processor": self.logits_processor}
|
||||||
|
|
||||||
# 向量转换
|
# 向量转换
|
||||||
input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'], truncation_length=self.max_new_tokens)
|
input_ids = self.encode(soft_prompt, add_bos_token=self.checkPoint.tokenizer.add_bos_token,
|
||||||
# input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids)
|
truncation_length=self.max_new_tokens)
|
||||||
|
|
||||||
|
|
||||||
gen_kwargs.update({'inputs': input_ids})
|
gen_kwargs.update({'inputs': input_ids})
|
||||||
# 注意力掩码
|
|
||||||
# gen_kwargs.update({'attention_mask': attention_mask})
|
|
||||||
# gen_kwargs.update({'position_ids': position_ids})
|
|
||||||
if self.stopping_criteria is None:
|
|
||||||
self.stopping_criteria = transformers.StoppingCriteriaList()
|
|
||||||
# 观测输出
|
# 观测输出
|
||||||
gen_kwargs.update({'stopping_criteria': self.stopping_criteria})
|
gen_kwargs.update({'stopping_criteria': self.stopping_criteria})
|
||||||
|
# llama-cpp模型的参数与transformers的参数字段有较大差异,直接调用会返回不支持的字段错误
|
||||||
|
# 因此需要先判断模型是否是llama-cpp模型,然后取gen_kwargs与模型generate方法字段的交集
|
||||||
|
# 仅将交集字段传给模型以保证兼容性
|
||||||
|
# todo llama-cpp模型在本框架下兼容性较差,后续可以考虑重写一个llama_cpp_llm.py模块
|
||||||
|
if "llama_cpp" in self.checkPoint.model.__str__():
|
||||||
|
import inspect
|
||||||
|
|
||||||
output_ids = self.checkPoint.model.generate(**gen_kwargs)
|
common_kwargs_keys = set(inspect.getfullargspec(self.checkPoint.model.generate).args) & set(
|
||||||
|
gen_kwargs.keys())
|
||||||
|
common_kwargs = {key: gen_kwargs[key] for key in common_kwargs_keys}
|
||||||
|
# ? llama-cpp模型的generate方法似乎只接受.cpu类型的输入,响应很慢,慢到哭泣
|
||||||
|
# ?为什么会不支持GPU呢,不应该啊?
|
||||||
|
output_ids = torch.tensor(
|
||||||
|
[list(self.checkPoint.model.generate(input_id_i.cpu(), **common_kwargs)) for input_id_i in input_ids])
|
||||||
|
|
||||||
|
else:
|
||||||
|
output_ids = self.checkPoint.model.generate(**gen_kwargs)
|
||||||
new_tokens = len(output_ids[0]) - len(input_ids[0])
|
new_tokens = len(output_ids[0]) - len(input_ids[0])
|
||||||
reply = self.decode(output_ids[0][-new_tokens:])
|
reply = self.decode(output_ids[0][-new_tokens:])
|
||||||
print(f"response:{reply}")
|
print(f"response:{reply}")
|
||||||
print(f"+++++++++++++++++++++++++++++++++++")
|
print(f"+++++++++++++++++++++++++++++++++++")
|
||||||
return reply
|
|
||||||
|
|
||||||
def generatorAnswer(self, prompt: str,
|
|
||||||
history: List[List[str]] = [],
|
|
||||||
streaming: bool = False):
|
|
||||||
|
|
||||||
# TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现
|
|
||||||
softprompt = self.history_to_text(prompt,history=history)
|
|
||||||
response = self._call(prompt=softprompt, stop=['\n###'])
|
|
||||||
|
|
||||||
answer_result = AnswerResult()
|
answer_result = AnswerResult()
|
||||||
answer_result.history = history + [[prompt, response]]
|
history += [[prompt, reply]]
|
||||||
answer_result.llm_output = {"answer": response}
|
answer_result.history = history
|
||||||
yield answer_result
|
answer_result.llm_output = {"answer": reply}
|
||||||
|
generate_with_callback(answer_result)
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
from configs.model_config import *
|
from configs.model_config import *
|
||||||
@ -43,7 +44,8 @@ parser.add_argument('--no-remote-model', action='store_true', help='remote in th
|
|||||||
parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
|
parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
|
||||||
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
|
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
|
||||||
parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
|
parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
|
||||||
|
parser.add_argument('--use-ptuning-v2',action='store_true',help="whether use ptuning-v2 checkpoint")
|
||||||
|
parser.add_argument("--ptuning-dir",type=str,default=PTUNING_DIR,help="the dir of ptuning-v2 checkpoint")
|
||||||
# Accelerate/transformers
|
# Accelerate/transformers
|
||||||
parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
|
parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
|
||||||
help='Load the model with 8-bit precision.')
|
help='Load the model with 8-bit precision.')
|
||||||
|
|||||||
@ -20,6 +20,7 @@ class LoaderCheckPoint:
|
|||||||
no_remote_model: bool = False
|
no_remote_model: bool = False
|
||||||
# 模型名称
|
# 模型名称
|
||||||
model_name: str = None
|
model_name: str = None
|
||||||
|
pretrained_model_name: str = None
|
||||||
tokenizer: object = None
|
tokenizer: object = None
|
||||||
# 模型全路径
|
# 模型全路径
|
||||||
model_path: str = None
|
model_path: str = None
|
||||||
@ -35,11 +36,11 @@ class LoaderCheckPoint:
|
|||||||
# 因此主要的解决思路是清理环境变量里PATH下的不匹配的cuda版本,一劳永逸的方法是:
|
# 因此主要的解决思路是清理环境变量里PATH下的不匹配的cuda版本,一劳永逸的方法是:
|
||||||
# 0. 在终端执行`pip uninstall bitsandbytes`
|
# 0. 在终端执行`pip uninstall bitsandbytes`
|
||||||
# 1. 删除.bashrc文件下关于PATH的条目
|
# 1. 删除.bashrc文件下关于PATH的条目
|
||||||
# 2. 在终端执行 `echo $PATH >> .bashrc`
|
# 2. 在终端执行 `echo $PATH >> .bashrc`
|
||||||
# 3. 删除.bashrc文件下PATH中关于不匹配的cuda版本路径
|
# 3. 删除.bashrc文件下PATH中关于不匹配的cuda版本路径
|
||||||
# 4. 在终端执行`source .bashrc`
|
# 4. 在终端执行`source .bashrc`
|
||||||
# 5. 再执行`pip install bitsandbytes`
|
# 5. 再执行`pip install bitsandbytes`
|
||||||
|
|
||||||
load_in_8bit: bool = False
|
load_in_8bit: bool = False
|
||||||
is_llamacpp: bool = False
|
is_llamacpp: bool = False
|
||||||
bf16: bool = False
|
bf16: bool = False
|
||||||
@ -67,43 +68,49 @@ class LoaderCheckPoint:
|
|||||||
self.load_in_8bit = params.get('load_in_8bit', False)
|
self.load_in_8bit = params.get('load_in_8bit', False)
|
||||||
self.bf16 = params.get('bf16', False)
|
self.bf16 = params.get('bf16', False)
|
||||||
|
|
||||||
def _load_model_config(self, model_name):
|
def _load_model_config(self):
|
||||||
|
|
||||||
if self.model_path:
|
if self.model_path:
|
||||||
|
self.model_path = re.sub("\s", "", self.model_path)
|
||||||
checkpoint = Path(f'{self.model_path}')
|
checkpoint = Path(f'{self.model_path}')
|
||||||
else:
|
else:
|
||||||
if not self.no_remote_model:
|
if self.no_remote_model:
|
||||||
checkpoint = model_name
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"本地模型local_model_path未配置路径"
|
"本地模型local_model_path未配置路径"
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
checkpoint = self.pretrained_model_name
|
||||||
|
|
||||||
model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
|
print(f"load_model_config {checkpoint}...")
|
||||||
|
try:
|
||||||
|
|
||||||
return model_config
|
model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
|
||||||
|
return model_config
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
def _load_model(self, model_name):
|
def _load_model(self):
|
||||||
"""
|
"""
|
||||||
加载自定义位置的model
|
加载自定义位置的model
|
||||||
:param model_name:
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
print(f"Loading {model_name}...")
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
if self.model_path:
|
if self.model_path:
|
||||||
|
self.model_path = re.sub("\s", "", self.model_path)
|
||||||
checkpoint = Path(f'{self.model_path}')
|
checkpoint = Path(f'{self.model_path}')
|
||||||
else:
|
else:
|
||||||
if not self.no_remote_model:
|
if self.no_remote_model:
|
||||||
checkpoint = model_name
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"本地模型local_model_path未配置路径"
|
"本地模型local_model_path未配置路径"
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
checkpoint = self.pretrained_model_name
|
||||||
|
|
||||||
|
print(f"Loading {checkpoint}...")
|
||||||
self.is_llamacpp = len(list(Path(f'{checkpoint}').glob('ggml*.bin'))) > 0
|
self.is_llamacpp = len(list(Path(f'{checkpoint}').glob('ggml*.bin'))) > 0
|
||||||
if 'chatglm' in model_name.lower():
|
if 'chatglm' in self.model_name.lower() or "chatyuan" in self.model_name.lower():
|
||||||
LoaderClass = AutoModel
|
LoaderClass = AutoModel
|
||||||
else:
|
else:
|
||||||
LoaderClass = AutoModelForCausalLM
|
LoaderClass = AutoModelForCausalLM
|
||||||
@ -126,8 +133,14 @@ class LoaderCheckPoint:
|
|||||||
.half()
|
.half()
|
||||||
.cuda()
|
.cuda()
|
||||||
)
|
)
|
||||||
|
# 支持自定义cuda设备
|
||||||
|
elif ":" in self.llm_device:
|
||||||
|
model = LoaderClass.from_pretrained(checkpoint,
|
||||||
|
config=self.model_config,
|
||||||
|
torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
|
||||||
|
trust_remote_code=True).half().to(self.llm_device)
|
||||||
else:
|
else:
|
||||||
from accelerate import dispatch_model
|
from accelerate import dispatch_model, infer_auto_device_map
|
||||||
|
|
||||||
model = LoaderClass.from_pretrained(checkpoint,
|
model = LoaderClass.from_pretrained(checkpoint,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
@ -135,12 +148,22 @@ class LoaderCheckPoint:
|
|||||||
trust_remote_code=True).half()
|
trust_remote_code=True).half()
|
||||||
# 可传入device_map自定义每张卡的部署情况
|
# 可传入device_map自定义每张卡的部署情况
|
||||||
if self.device_map is None:
|
if self.device_map is None:
|
||||||
if 'chatglm' in model_name.lower():
|
if 'chatglm' in self.model_name.lower() and not "chatglm2" in self.model_name.lower():
|
||||||
self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
|
self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
|
||||||
elif 'moss' in model_name.lower():
|
elif 'moss' in self.model_name.lower():
|
||||||
self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
|
self.device_map = self.moss_auto_configure_device_map(num_gpus, checkpoint)
|
||||||
else:
|
else:
|
||||||
self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
|
# 基于如下方式作为默认的多卡加载方案针对新模型基本不会失败
|
||||||
|
# 在chatglm2-6b,bloom-3b,blooz-7b1上进行了测试,GPU负载也相对均衡
|
||||||
|
from accelerate.utils import get_balanced_memory
|
||||||
|
max_memory = get_balanced_memory(model,
|
||||||
|
dtype=torch.int8 if self.load_in_8bit else None,
|
||||||
|
low_zero=False,
|
||||||
|
no_split_module_classes=model._no_split_modules)
|
||||||
|
self.device_map = infer_auto_device_map(model,
|
||||||
|
dtype=torch.float16 if not self.load_in_8bit else torch.int8,
|
||||||
|
max_memory=max_memory,
|
||||||
|
no_split_module_classes=model._no_split_modules)
|
||||||
|
|
||||||
model = dispatch_model(model, device_map=self.device_map)
|
model = dispatch_model(model, device_map=self.device_map)
|
||||||
else:
|
else:
|
||||||
@ -156,7 +179,7 @@ class LoaderCheckPoint:
|
|||||||
elif self.is_llamacpp:
|
elif self.is_llamacpp:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from models.extensions.llamacpp_model_alternative import LlamaCppModel
|
from llama_cpp import Llama
|
||||||
|
|
||||||
except ImportError as exc:
|
except ImportError as exc:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -167,7 +190,16 @@ class LoaderCheckPoint:
|
|||||||
model_file = list(checkpoint.glob('ggml*.bin'))[0]
|
model_file = list(checkpoint.glob('ggml*.bin'))[0]
|
||||||
print(f"llama.cpp weights detected: {model_file}\n")
|
print(f"llama.cpp weights detected: {model_file}\n")
|
||||||
|
|
||||||
model, tokenizer = LlamaCppModel.from_pretrained(model_file)
|
model = Llama(model_path=model_file._str)
|
||||||
|
|
||||||
|
# 实测llama-cpp-vicuna13b-q5_1的AutoTokenizer加载tokenizer的速度极慢,应存在优化空间
|
||||||
|
# 但需要对huggingface的AutoTokenizer进行优化
|
||||||
|
|
||||||
|
# tokenizer = model.tokenizer
|
||||||
|
# todo 此处调用AutoTokenizer的tokenizer,但后续可以测试自带tokenizer是不是兼容
|
||||||
|
# * -> 自带的tokenizer不与transoformers的tokenizer兼容,无法使用
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
elif self.load_in_8bit:
|
elif self.load_in_8bit:
|
||||||
@ -194,7 +226,7 @@ class LoaderCheckPoint:
|
|||||||
llm_int8_enable_fp32_cpu_offload=False)
|
llm_int8_enable_fp32_cpu_offload=False)
|
||||||
|
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model = LoaderClass.from_config(self.model_config,trust_remote_code = True)
|
model = LoaderClass.from_config(self.model_config, trust_remote_code=True)
|
||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
if self.device_map is not None:
|
if self.device_map is not None:
|
||||||
params['device_map'] = self.device_map
|
params['device_map'] = self.device_map
|
||||||
@ -257,10 +289,21 @@ class LoaderCheckPoint:
|
|||||||
# 在调用chat或者stream_chat时,input_ids会被放到model.device上
|
# 在调用chat或者stream_chat时,input_ids会被放到model.device上
|
||||||
# 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError
|
# 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError
|
||||||
# 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上
|
# 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上
|
||||||
device_map = {f'{layer_prefix}.word_embeddings': 0,
|
|
||||||
f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0,
|
|
||||||
f'base_model.model.lm_head': 0, }
|
|
||||||
|
|
||||||
|
encode = ""
|
||||||
|
if 'chatglm2' in self.model_name:
|
||||||
|
device_map = {
|
||||||
|
f"{layer_prefix}.embedding.word_embeddings": 0,
|
||||||
|
f"{layer_prefix}.rotary_pos_emb": 0,
|
||||||
|
f"{layer_prefix}.output_layer": 0,
|
||||||
|
f"{layer_prefix}.encoder.final_layernorm": 0,
|
||||||
|
f"base_model.model.output_layer": 0
|
||||||
|
}
|
||||||
|
encode = ".encoder"
|
||||||
|
else:
|
||||||
|
device_map = {f'{layer_prefix}.word_embeddings': 0,
|
||||||
|
f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0,
|
||||||
|
f'base_model.model.lm_head': 0, }
|
||||||
used = 2
|
used = 2
|
||||||
gpu_target = 0
|
gpu_target = 0
|
||||||
for i in range(num_trans_layers):
|
for i in range(num_trans_layers):
|
||||||
@ -268,12 +311,12 @@ class LoaderCheckPoint:
|
|||||||
gpu_target += 1
|
gpu_target += 1
|
||||||
used = 0
|
used = 0
|
||||||
assert gpu_target < num_gpus
|
assert gpu_target < num_gpus
|
||||||
device_map[f'{layer_prefix}.layers.{i}'] = gpu_target
|
device_map[f'{layer_prefix}{encode}.layers.{i}'] = gpu_target
|
||||||
used += 1
|
used += 1
|
||||||
|
|
||||||
return device_map
|
return device_map
|
||||||
|
|
||||||
def moss_auto_configure_device_map(self, num_gpus: int, model_name) -> Dict[str, int]:
|
def moss_auto_configure_device_map(self, num_gpus: int, checkpoint) -> Dict[str, int]:
|
||||||
try:
|
try:
|
||||||
|
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
@ -288,16 +331,6 @@ class LoaderCheckPoint:
|
|||||||
"`pip install bitsandbytes``pip install accelerate`."
|
"`pip install bitsandbytes``pip install accelerate`."
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
if self.model_path:
|
|
||||||
checkpoint = Path(f'{self.model_path}')
|
|
||||||
else:
|
|
||||||
if not self.no_remote_model:
|
|
||||||
checkpoint = model_name
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"本地模型local_model_path未配置路径"
|
|
||||||
)
|
|
||||||
|
|
||||||
cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
|
cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
|
||||||
pretrained_model_name_or_path=checkpoint)
|
pretrained_model_name_or_path=checkpoint)
|
||||||
|
|
||||||
@ -385,7 +418,7 @@ class LoaderCheckPoint:
|
|||||||
print(
|
print(
|
||||||
"如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")
|
"如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")
|
||||||
elif torch.has_cuda:
|
elif torch.has_cuda:
|
||||||
device_id = "0" if torch.cuda.is_available() else None
|
device_id = "0" if torch.cuda.is_available() and (":" not in self.llm_device) else None
|
||||||
CUDA_DEVICE = f"{self.llm_device}:{device_id}" if device_id else self.llm_device
|
CUDA_DEVICE = f"{self.llm_device}:{device_id}" if device_id else self.llm_device
|
||||||
with torch.cuda.device(CUDA_DEVICE):
|
with torch.cuda.device(CUDA_DEVICE):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@ -404,33 +437,37 @@ class LoaderCheckPoint:
|
|||||||
|
|
||||||
def reload_model(self):
|
def reload_model(self):
|
||||||
self.unload_model()
|
self.unload_model()
|
||||||
self.model_config = self._load_model_config(self.model_name)
|
self.model_config = self._load_model_config()
|
||||||
|
|
||||||
if self.use_ptuning_v2:
|
if self.use_ptuning_v2:
|
||||||
try:
|
try:
|
||||||
prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r')
|
prefix_encoder_file = open(Path(f'{os.path.abspath(self.ptuning_dir)}/config.json'), 'r')
|
||||||
prefix_encoder_config = json.loads(prefix_encoder_file.read())
|
prefix_encoder_config = json.loads(prefix_encoder_file.read())
|
||||||
prefix_encoder_file.close()
|
prefix_encoder_file.close()
|
||||||
self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
|
self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
|
||||||
self.model_config.prefix_projection = prefix_encoder_config['prefix_projection']
|
self.model_config.prefix_projection = prefix_encoder_config['prefix_projection']
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
print("加载PrefixEncoder config.json失败")
|
print("加载PrefixEncoder config.json失败")
|
||||||
|
|
||||||
self.model, self.tokenizer = self._load_model(self.model_name)
|
self.model, self.tokenizer = self._load_model()
|
||||||
|
|
||||||
if self.lora:
|
if self.lora:
|
||||||
self._add_lora_to_model([self.lora])
|
self._add_lora_to_model([self.lora])
|
||||||
|
|
||||||
if self.use_ptuning_v2:
|
if self.use_ptuning_v2:
|
||||||
try:
|
try:
|
||||||
prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin'))
|
prefix_state_dict = torch.load(Path(f'{os.path.abspath(self.ptuning_dir)}/pytorch_model.bin'))
|
||||||
new_prefix_state_dict = {}
|
new_prefix_state_dict = {}
|
||||||
for k, v in prefix_state_dict.items():
|
for k, v in prefix_state_dict.items():
|
||||||
if k.startswith("transformer.prefix_encoder."):
|
if k.startswith("transformer.prefix_encoder."):
|
||||||
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
|
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
|
||||||
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
|
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
|
||||||
self.model.transformer.prefix_encoder.float()
|
self.model.transformer.prefix_encoder.float()
|
||||||
|
print("加载ptuning检查点成功!")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
print("加载PrefixEncoder模型参数失败")
|
print("加载PrefixEncoder模型参数失败")
|
||||||
|
# llama-cpp模型(至少vicuna-13b)的eval方法就是自身,其没有eval方法
|
||||||
self.model = self.model.eval()
|
if not self.is_llamacpp:
|
||||||
|
self.model = self.model.eval()
|
||||||
|
|||||||
@ -1,12 +1,20 @@
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
from langchain.llms.base import LLM
|
from langchain.chains.base import Chain
|
||||||
from typing import Optional, List
|
from typing import Any, Dict, List, Optional, Generator, Union
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
|
from transformers.generation.logits_process import LogitsProcessor
|
||||||
|
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
|
||||||
from models.loader import LoaderCheckPoint
|
from models.loader import LoaderCheckPoint
|
||||||
from models.base import (BaseAnswer,
|
from models.base import (BaseAnswer,
|
||||||
AnswerResult)
|
AnswerResult,
|
||||||
|
AnswerResultStream,
|
||||||
|
AnswerResultQueueSentinelTokenListenerQueue)
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
# todo 建议重写instruction,在该instruction下,各模型的表现比较差
|
||||||
META_INSTRUCTION = \
|
META_INSTRUCTION = \
|
||||||
"""You are an AI assistant whose name is MOSS.
|
"""You are an AI assistant whose name is MOSS.
|
||||||
- MOSS is a conversational language model that is developed by Fudan University. It is designed to be helpful, honest, and harmless.
|
- MOSS is a conversational language model that is developed by Fudan University. It is designed to be helpful, honest, and harmless.
|
||||||
@ -21,49 +29,76 @@ META_INSTRUCTION = \
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class MOSSLLM(BaseAnswer, LLM, ABC):
|
# todo 在MOSSLLM类下,各模型的响应速度很慢,后续要检查一下原因
|
||||||
|
class MOSSLLMChain(BaseAnswer, Chain, ABC):
|
||||||
max_token: int = 2048
|
max_token: int = 2048
|
||||||
temperature: float = 0.7
|
temperature: float = 0.7
|
||||||
top_p = 0.8
|
top_p = 0.8
|
||||||
# history = []
|
# history = []
|
||||||
checkPoint: LoaderCheckPoint = None
|
checkPoint: LoaderCheckPoint = None
|
||||||
history_len: int = 10
|
history_len: int = 10
|
||||||
|
streaming_key: str = "streaming" #: :meta private:
|
||||||
|
history_key: str = "history" #: :meta private:
|
||||||
|
prompt_key: str = "prompt" #: :meta private:
|
||||||
|
output_key: str = "answer_result_stream" #: :meta private:
|
||||||
|
|
||||||
def __init__(self, checkPoint: LoaderCheckPoint = None):
|
def __init__(self, checkPoint: LoaderCheckPoint = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.checkPoint = checkPoint
|
self.checkPoint = checkPoint
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _chain_type(self) -> str:
|
||||||
return "MOSS"
|
return "MOSSLLMChain"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Will be whatever keys the prompt expects.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return [self.prompt_key]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_keys(self) -> List[str]:
|
||||||
|
"""Will always return text key.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return [self.output_key]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _check_point(self) -> LoaderCheckPoint:
|
def _check_point(self) -> LoaderCheckPoint:
|
||||||
return self.checkPoint
|
return self.checkPoint
|
||||||
|
|
||||||
@property
|
def _call(
|
||||||
def set_history_len(self) -> int:
|
self,
|
||||||
return self.history_len
|
inputs: Dict[str, Any],
|
||||||
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
|
) -> Dict[str, Generator]:
|
||||||
|
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
|
||||||
|
return {self.output_key: generator}
|
||||||
|
|
||||||
def _set_history_len(self, history_len: int) -> None:
|
def _generate_answer(self,
|
||||||
self.history_len = history_len
|
inputs: Dict[str, Any],
|
||||||
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
|
generate_with_callback: AnswerResultStream = None) -> None:
|
||||||
|
|
||||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
history = inputs[self.history_key]
|
||||||
pass
|
streaming = inputs[self.streaming_key]
|
||||||
|
prompt = inputs[self.prompt_key]
|
||||||
def generatorAnswer(self, prompt: str,
|
print(f"__call:{prompt}")
|
||||||
history: List[List[str]] = [],
|
|
||||||
streaming: bool = False):
|
|
||||||
if len(history) > 0:
|
if len(history) > 0:
|
||||||
history = history[-self.history_len:] if self.history_len > 0 else []
|
history = history[-self.history_len:] if self.history_len > 0 else []
|
||||||
prompt_w_history = str(history)
|
prompt_w_history = str(history)
|
||||||
prompt_w_history += '<|Human|>: ' + prompt + '<eoh>'
|
prompt_w_history += '<|Human|>: ' + prompt + '<eoh>'
|
||||||
else:
|
else:
|
||||||
prompt_w_history = META_INSTRUCTION
|
prompt_w_history = META_INSTRUCTION.replace("MOSS", self.checkPoint.model_name.split("/")[-1])
|
||||||
prompt_w_history += '<|Human|>: ' + prompt + '<eoh>'
|
prompt_w_history += '<|Human|>: ' + prompt + '<eoh>'
|
||||||
|
|
||||||
inputs = self.checkPoint.tokenizer(prompt_w_history, return_tensors="pt")
|
inputs = self.checkPoint.tokenizer(prompt_w_history, return_tensors="pt")
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
# max_length似乎可以设的小一些,而repetion_penalty应大一些,否则chatyuan,bloom等模型为满足max会重复输出
|
||||||
|
#
|
||||||
outputs = self.checkPoint.model.generate(
|
outputs = self.checkPoint.model.generate(
|
||||||
inputs.input_ids.cuda(),
|
inputs.input_ids.cuda(),
|
||||||
attention_mask=inputs.attention_mask.cuda(),
|
attention_mask=inputs.attention_mask.cuda(),
|
||||||
@ -76,13 +111,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
|
|||||||
num_return_sequences=1,
|
num_return_sequences=1,
|
||||||
eos_token_id=106068,
|
eos_token_id=106068,
|
||||||
pad_token_id=self.checkPoint.tokenizer.pad_token_id)
|
pad_token_id=self.checkPoint.tokenizer.pad_token_id)
|
||||||
response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:],
|
||||||
|
skip_special_tokens=True)
|
||||||
self.checkPoint.clear_torch_cache()
|
self.checkPoint.clear_torch_cache()
|
||||||
history += [[prompt, response]]
|
history += [[prompt, response]]
|
||||||
answer_result = AnswerResult()
|
answer_result = AnswerResult()
|
||||||
answer_result.history = history
|
answer_result.history = history
|
||||||
answer_result.llm_output = {"answer": response}
|
answer_result.llm_output = {"answer": response}
|
||||||
|
|
||||||
yield answer_result
|
generate_with_callback(answer_result)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -24,13 +24,12 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
|
|||||||
if use_ptuning_v2:
|
if use_ptuning_v2:
|
||||||
loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2
|
loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2
|
||||||
|
|
||||||
|
# 如果指定了参数,则使用参数的配置
|
||||||
if llm_model:
|
if llm_model:
|
||||||
llm_model_info = llm_model_dict[llm_model]
|
llm_model_info = llm_model_dict[llm_model]
|
||||||
|
|
||||||
if loaderCheckPoint.no_remote_model:
|
loaderCheckPoint.model_name = llm_model_info['name']
|
||||||
loaderCheckPoint.model_name = llm_model_info['name']
|
loaderCheckPoint.pretrained_model_name = llm_model_info['pretrained_model_name']
|
||||||
else:
|
|
||||||
loaderCheckPoint.model_name = llm_model_info['pretrained_model_name']
|
|
||||||
|
|
||||||
loaderCheckPoint.model_path = llm_model_info["local_model_path"]
|
loaderCheckPoint.model_path = llm_model_info["local_model_path"]
|
||||||
|
|
||||||
@ -44,4 +43,5 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
|
|||||||
if 'FastChatOpenAILLM' in llm_model_info["provides"]:
|
if 'FastChatOpenAILLM' in llm_model_info["provides"]:
|
||||||
modelInsLLM.set_api_base_url(llm_model_info['api_base_url'])
|
modelInsLLM.set_api_base_url(llm_model_info['api_base_url'])
|
||||||
modelInsLLM.call_model_name(llm_model_info['name'])
|
modelInsLLM.call_model_name(llm_model_info['name'])
|
||||||
|
modelInsLLM.set_api_key(llm_model_info['api_key'])
|
||||||
return modelInsLLM
|
return modelInsLLM
|
||||||
|
|||||||
@ -11,7 +11,7 @@ beautifulsoup4
|
|||||||
icetk
|
icetk
|
||||||
cpm_kernels
|
cpm_kernels
|
||||||
faiss-cpu
|
faiss-cpu
|
||||||
gradio==3.28.3
|
gradio==3.37.0
|
||||||
fastapi~=0.95.0
|
fastapi~=0.95.0
|
||||||
uvicorn~=0.21.1
|
uvicorn~=0.21.1
|
||||||
pypinyin~=0.48.0
|
pypinyin~=0.48.0
|
||||||
@ -23,9 +23,13 @@ openai
|
|||||||
#accelerate~=0.18.0
|
#accelerate~=0.18.0
|
||||||
#peft~=0.3.0
|
#peft~=0.3.0
|
||||||
#bitsandbytes; platform_system != "Windows"
|
#bitsandbytes; platform_system != "Windows"
|
||||||
#llama-cpp-python==0.1.34; platform_system != "Windows"
|
|
||||||
#https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
|
||||||
|
|
||||||
|
# 要调用llama-cpp模型,如vicuma-13b量化模型需要安装llama-cpp-python库
|
||||||
|
# but!!! 实测pip install 不好使,需要手动从ttps://github.com/abetlen/llama-cpp-python/releases/下载
|
||||||
|
# 而且注意不同时期的ggml格式并不!兼!容!!!因此需要安装的llama-cpp-python版本也不一致,需要手动测试才能确定
|
||||||
|
# 实测ggml-vicuna-13b-1.1在llama-cpp-python 0.1.63上可正常兼容
|
||||||
|
# 不过!!!本项目模型加载的方式控制的比较严格,与llama-cpp-python的兼容性较差,很多参数设定不能使用,
|
||||||
|
# 建议如非必要还是不要使用llama-cpp
|
||||||
torch~=2.0.0
|
torch~=2.0.0
|
||||||
pydantic~=1.10.7
|
pydantic~=1.10.7
|
||||||
starlette~=0.26.1
|
starlette~=0.26.1
|
||||||
|
|||||||
@ -1,39 +0,0 @@
|
|||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../../')
|
|
||||||
import asyncio
|
|
||||||
from argparse import Namespace
|
|
||||||
from models.loader.args import parser
|
|
||||||
from models.loader import LoaderCheckPoint
|
|
||||||
|
|
||||||
|
|
||||||
import models.shared as shared
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def dispatch(args: Namespace):
|
|
||||||
args_dict = vars(args)
|
|
||||||
|
|
||||||
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
|
||||||
|
|
||||||
llm_model_ins = shared.loaderLLM()
|
|
||||||
|
|
||||||
history = [
|
|
||||||
("which city is this?", "tokyo"),
|
|
||||||
("why?", "she's japanese"),
|
|
||||||
|
|
||||||
]
|
|
||||||
for answer_result in llm_model_ins.generatorAnswer(prompt="你好? ", history=history,
|
|
||||||
streaming=False):
|
|
||||||
resp = answer_result.llm_output["answer"]
|
|
||||||
|
|
||||||
print(resp)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
args = None
|
|
||||||
args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'fastchat-chatglm-6b', '--no-remote-model'])
|
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
loop.run_until_complete(dispatch(args))
|
|
||||||
@ -16,6 +16,24 @@ export const chatfile = (params: any) => {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const getKbsList = () => {
|
||||||
|
return api({
|
||||||
|
url: '/local_doc_qa/list_knowledge_base',
|
||||||
|
method: 'get',
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export const deleteKb = (knowledge_base_id: any) => {
|
||||||
|
return api({
|
||||||
|
url: '/local_doc_qa/delete_knowledge_base',
|
||||||
|
method: 'delete',
|
||||||
|
params: {
|
||||||
|
knowledge_base_id,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
export const getfilelist = (knowledge_base_id: any) => {
|
export const getfilelist = (knowledge_base_id: any) => {
|
||||||
return api({
|
return api({
|
||||||
url: '/local_doc_qa/list_files',
|
url: '/local_doc_qa/list_files',
|
||||||
@ -35,8 +53,8 @@ export const bing_search = (params: any) => {
|
|||||||
export const deletefile = (params: any) => {
|
export const deletefile = (params: any) => {
|
||||||
return api({
|
return api({
|
||||||
url: '/local_doc_qa/delete_file',
|
url: '/local_doc_qa/delete_file',
|
||||||
method: 'post',
|
method: 'delete',
|
||||||
data: JSON.stringify(params),
|
params,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
export const web_url = () => {
|
export const web_url = () => {
|
||||||
|
|||||||
@ -555,7 +555,7 @@ const options = computed(() => {
|
|||||||
|
|
||||||
return common
|
return common
|
||||||
})
|
})
|
||||||
function handleSelect(key: 'copyText' | 'delete' | 'toggleRenderType') {
|
function handleSelect(key: string) {
|
||||||
if (key == '清除会话') {
|
if (key == '清除会话') {
|
||||||
handleClear()
|
handleClear()
|
||||||
}
|
}
|
||||||
@ -658,7 +658,6 @@ function searchfun() {
|
|||||||
<NDropdown
|
<NDropdown
|
||||||
v-if="isMobile"
|
v-if="isMobile"
|
||||||
:trigger="isMobile ? 'click' : 'hover'"
|
:trigger="isMobile ? 'click' : 'hover'"
|
||||||
:placement="!inversion ? 'right' : 'left'"
|
|
||||||
:options="options"
|
:options="options"
|
||||||
@select="handleSelect"
|
@select="handleSelect"
|
||||||
>
|
>
|
||||||
|
|||||||
@ -3,15 +3,16 @@ import { NButton, NForm, NFormItem, NInput, NPopconfirm } from 'naive-ui'
|
|||||||
import { onMounted, ref } from 'vue'
|
import { onMounted, ref } from 'vue'
|
||||||
import filelist from './filelist.vue'
|
import filelist from './filelist.vue'
|
||||||
import { SvgIcon } from '@/components/common'
|
import { SvgIcon } from '@/components/common'
|
||||||
import { deletekb, getkblist } from '@/api/chat'
|
import { deleteKb, getKbsList } from '@/api/chat'
|
||||||
import { idStore } from '@/store/modules/knowledgebaseid/id'
|
import { idStore } from '@/store/modules/knowledgebaseid/id'
|
||||||
|
|
||||||
const items = ref<any>([])
|
const items = ref<any>([])
|
||||||
const choice = ref('')
|
const choice = ref('')
|
||||||
const store = idStore()
|
const store = idStore()
|
||||||
|
|
||||||
onMounted(async () => {
|
onMounted(async () => {
|
||||||
choice.value = store.knowledgeid
|
choice.value = store.knowledgeid
|
||||||
const res = await getkblist({})
|
const res = await getKbsList()
|
||||||
res.data.data.forEach((item: any) => {
|
res.data.data.forEach((item: any) => {
|
||||||
items.value.push({
|
items.value.push({
|
||||||
value: item,
|
value: item,
|
||||||
@ -52,8 +53,8 @@ const handleClick = () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
async function handleDelete(item: any) {
|
async function handleDelete(item: any) {
|
||||||
await deletekb(item.value)
|
await deleteKb(item.value)
|
||||||
const res = await getkblist({})
|
const res = await getKbsList()
|
||||||
items.value = []
|
items.value = []
|
||||||
res.data.data.forEach((item: any) => {
|
res.data.data.forEach((item: any) => {
|
||||||
items.value.push({
|
items.value.push({
|
||||||
|
|||||||
53
webui.py
53
webui.py
@ -85,8 +85,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
|
|||||||
yield history + [[query,
|
yield history + [[query,
|
||||||
"请选择知识库后进行测试,当前未选择知识库。"]], ""
|
"请选择知识库后进行测试,当前未选择知识库。"]], ""
|
||||||
else:
|
else:
|
||||||
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
|
|
||||||
streaming=streaming):
|
answer_result_stream_result = local_doc_qa.llm_model_chain(
|
||||||
|
{"prompt": query, "history": history, "streaming": streaming})
|
||||||
|
|
||||||
|
for answer_result in answer_result_stream_result['answer_result_stream']:
|
||||||
resp = answer_result.llm_output["answer"]
|
resp = answer_result.llm_output["answer"]
|
||||||
history = answer_result.history
|
history = answer_result.history
|
||||||
history[-1][-1] = resp
|
history[-1][-1] = resp
|
||||||
@ -101,11 +104,13 @@ def init_model():
|
|||||||
args_dict = vars(args)
|
args_dict = vars(args)
|
||||||
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
||||||
llm_model_ins = shared.loaderLLM()
|
llm_model_ins = shared.loaderLLM()
|
||||||
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
|
llm_model_ins.history_len = LLM_HISTORY_LEN
|
||||||
try:
|
try:
|
||||||
local_doc_qa.init_cfg(llm_model=llm_model_ins)
|
local_doc_qa.init_cfg(llm_model=llm_model_ins)
|
||||||
generator = local_doc_qa.llm.generatorAnswer("你好")
|
answer_result_stream_result = local_doc_qa.llm_model_chain(
|
||||||
for answer_result in generator:
|
{"prompt": "你好", "history": [], "streaming": False})
|
||||||
|
|
||||||
|
for answer_result in answer_result_stream_result['answer_result_stream']:
|
||||||
print(answer_result.llm_output)
|
print(answer_result.llm_output)
|
||||||
reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
|
reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
|
||||||
logger.info(reply)
|
logger.info(reply)
|
||||||
@ -141,7 +146,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u
|
|||||||
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
|
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
|
||||||
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
|
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
|
||||||
filelist = []
|
filelist = []
|
||||||
if local_doc_qa.llm and local_doc_qa.embeddings:
|
if local_doc_qa.llm_model_chain and local_doc_qa.embeddings:
|
||||||
if isinstance(files, list):
|
if isinstance(files, list):
|
||||||
for file in files:
|
for file in files:
|
||||||
filename = os.path.split(file.name)[-1]
|
filename = os.path.split(file.name)[-1]
|
||||||
@ -165,8 +170,8 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte
|
|||||||
|
|
||||||
def change_vs_name_input(vs_id, history):
|
def change_vs_name_input(vs_id, history):
|
||||||
if vs_id == "新建知识库":
|
if vs_id == "新建知识库":
|
||||||
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history,\
|
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history, \
|
||||||
gr.update(choices=[]), gr.update(visible=False)
|
gr.update(choices=[]), gr.update(visible=False)
|
||||||
else:
|
else:
|
||||||
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
|
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
|
||||||
if "index.faiss" in os.listdir(vs_path):
|
if "index.faiss" in os.listdir(vs_path):
|
||||||
@ -218,7 +223,12 @@ def change_chunk_conent(mode, label_conent, history):
|
|||||||
|
|
||||||
|
|
||||||
def add_vs_name(vs_name, chatbot):
|
def add_vs_name(vs_name, chatbot):
|
||||||
if vs_name in get_vs_list():
|
if vs_name is None or vs_name.strip() == "":
|
||||||
|
vs_status = "知识库名称不能为空,请重新填写知识库名称"
|
||||||
|
chatbot = chatbot + [[None, vs_status]]
|
||||||
|
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
|
||||||
|
visible=False), chatbot, gr.update(visible=False)
|
||||||
|
elif vs_name in get_vs_list():
|
||||||
vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交"
|
vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交"
|
||||||
chatbot = chatbot + [[None, vs_status]]
|
chatbot = chatbot + [[None, vs_status]]
|
||||||
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
|
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
|
||||||
@ -257,6 +267,7 @@ def reinit_vector_store(vs_id, history):
|
|||||||
def refresh_vs_list():
|
def refresh_vs_list():
|
||||||
return gr.update(choices=get_vs_list()), gr.update(choices=get_vs_list())
|
return gr.update(choices=get_vs_list()), gr.update(choices=get_vs_list())
|
||||||
|
|
||||||
|
|
||||||
def delete_file(vs_id, files_to_delete, chatbot):
|
def delete_file(vs_id, files_to_delete, chatbot):
|
||||||
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
|
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
|
||||||
content_path = os.path.join(KB_ROOT_PATH, vs_id, "content")
|
content_path = os.path.join(KB_ROOT_PATH, vs_id, "content")
|
||||||
@ -270,11 +281,11 @@ def delete_file(vs_id, files_to_delete, chatbot):
|
|||||||
rested_files = local_doc_qa.list_file_from_vector_store(vs_path)
|
rested_files = local_doc_qa.list_file_from_vector_store(vs_path)
|
||||||
if "fail" in status:
|
if "fail" in status:
|
||||||
vs_status = "文件删除失败。"
|
vs_status = "文件删除失败。"
|
||||||
elif len(rested_files)>0:
|
elif len(rested_files) > 0:
|
||||||
vs_status = "文件删除成功。"
|
vs_status = "文件删除成功。"
|
||||||
else:
|
else:
|
||||||
vs_status = f"文件删除成功,知识库{vs_id}中无已上传文件,请先上传文件后,再开始提问。"
|
vs_status = f"文件删除成功,知识库{vs_id}中无已上传文件,请先上传文件后,再开始提问。"
|
||||||
logger.info(",".join(files_to_delete)+vs_status)
|
logger.info(",".join(files_to_delete) + vs_status)
|
||||||
chatbot = chatbot + [[None, vs_status]]
|
chatbot = chatbot + [[None, vs_status]]
|
||||||
return gr.update(choices=local_doc_qa.list_file_from_vector_store(vs_path), value=[]), chatbot
|
return gr.update(choices=local_doc_qa.list_file_from_vector_store(vs_path), value=[]), chatbot
|
||||||
|
|
||||||
@ -285,7 +296,8 @@ def delete_vs(vs_id, chatbot):
|
|||||||
status = f"成功删除知识库{vs_id}"
|
status = f"成功删除知识库{vs_id}"
|
||||||
logger.info(status)
|
logger.info(status)
|
||||||
chatbot = chatbot + [[None, status]]
|
chatbot = chatbot + [[None, status]]
|
||||||
return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(visible=True), \
|
return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(
|
||||||
|
visible=True), \
|
||||||
gr.update(visible=False), chatbot, gr.update(visible=False)
|
gr.update(visible=False), chatbot, gr.update(visible=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
@ -328,7 +340,8 @@ default_theme_args = dict(
|
|||||||
|
|
||||||
with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo:
|
with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo:
|
||||||
vs_path, file_status, model_status = gr.State(
|
vs_path, file_status, model_status = gr.State(
|
||||||
os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State(
|
os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(
|
||||||
|
""), gr.State(
|
||||||
model_status)
|
model_status)
|
||||||
gr.Markdown(webui_title)
|
gr.Markdown(webui_title)
|
||||||
with gr.Tab("对话"):
|
with gr.Tab("对话"):
|
||||||
@ -381,8 +394,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
|
|||||||
load_folder_button = gr.Button("上传文件夹并加载知识库")
|
load_folder_button = gr.Button("上传文件夹并加载知识库")
|
||||||
with gr.Tab("删除文件"):
|
with gr.Tab("删除文件"):
|
||||||
files_to_delete = gr.CheckboxGroup(choices=[],
|
files_to_delete = gr.CheckboxGroup(choices=[],
|
||||||
label="请从知识库已有文件中选择要删除的文件",
|
label="请从知识库已有文件中选择要删除的文件",
|
||||||
interactive=True)
|
interactive=True)
|
||||||
delete_file_button = gr.Button("从知识库中删除选中文件")
|
delete_file_button = gr.Button("从知识库中删除选中文件")
|
||||||
vs_refresh.click(fn=refresh_vs_list,
|
vs_refresh.click(fn=refresh_vs_list,
|
||||||
inputs=[],
|
inputs=[],
|
||||||
@ -450,9 +463,9 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
|
|||||||
with vs_setting:
|
with vs_setting:
|
||||||
vs_refresh = gr.Button("更新已有知识库选项")
|
vs_refresh = gr.Button("更新已有知识库选项")
|
||||||
select_vs_test = gr.Dropdown(get_vs_list(),
|
select_vs_test = gr.Dropdown(get_vs_list(),
|
||||||
label="请选择要加载的知识库",
|
label="请选择要加载的知识库",
|
||||||
interactive=True,
|
interactive=True,
|
||||||
value=get_vs_list()[0] if len(get_vs_list()) > 0 else None)
|
value=get_vs_list()[0] if len(get_vs_list()) > 0 else None)
|
||||||
vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文",
|
vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文",
|
||||||
lines=1,
|
lines=1,
|
||||||
interactive=True,
|
interactive=True,
|
||||||
@ -492,8 +505,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
|
|||||||
inputs=[vs_name, chatbot],
|
inputs=[vs_name, chatbot],
|
||||||
outputs=[select_vs_test, vs_name, vs_add, file2vs, chatbot])
|
outputs=[select_vs_test, vs_name, vs_add, file2vs, chatbot])
|
||||||
select_vs_test.change(fn=change_vs_name_input,
|
select_vs_test.change(fn=change_vs_name_input,
|
||||||
inputs=[select_vs_test, chatbot],
|
inputs=[select_vs_test, chatbot],
|
||||||
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
|
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
|
||||||
load_file_button.click(get_vector_store,
|
load_file_button.click(get_vector_store,
|
||||||
show_progress=True,
|
show_progress=True,
|
||||||
inputs=[select_vs_test, files, sentence_size, chatbot, vs_add, vs_add],
|
inputs=[select_vs_test, files, sentence_size, chatbot, vs_add, vs_add],
|
||||||
|
|||||||
417
webui_st.py
417
webui_st.py
@ -1,5 +1,5 @@
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
# from st_btn_select import st_btn_select
|
from streamlit_chatbox import st_chatbox
|
||||||
import tempfile
|
import tempfile
|
||||||
###### 从webui借用的代码 #####
|
###### 从webui借用的代码 #####
|
||||||
###### 做了少量修改 #####
|
###### 做了少量修改 #####
|
||||||
@ -23,6 +23,7 @@ def get_vs_list():
|
|||||||
if not os.path.exists(KB_ROOT_PATH):
|
if not os.path.exists(KB_ROOT_PATH):
|
||||||
return lst_default
|
return lst_default
|
||||||
lst = os.listdir(KB_ROOT_PATH)
|
lst = os.listdir(KB_ROOT_PATH)
|
||||||
|
lst = [x for x in lst if os.path.isdir(os.path.join(KB_ROOT_PATH, x))]
|
||||||
if not lst:
|
if not lst:
|
||||||
return lst_default
|
return lst_default
|
||||||
lst.sort()
|
lst.sort()
|
||||||
@ -31,7 +32,6 @@ def get_vs_list():
|
|||||||
|
|
||||||
embedding_model_dict_list = list(embedding_model_dict.keys())
|
embedding_model_dict_list = list(embedding_model_dict.keys())
|
||||||
llm_model_dict_list = list(llm_model_dict.keys())
|
llm_model_dict_list = list(llm_model_dict.keys())
|
||||||
# flag_csv_logger = gr.CSVLogger()
|
|
||||||
|
|
||||||
|
|
||||||
def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
|
def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
|
||||||
@ -50,6 +50,9 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
|
|||||||
history[-1][-1] += source
|
history[-1][-1] += source
|
||||||
yield history, ""
|
yield history, ""
|
||||||
elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path):
|
elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path):
|
||||||
|
local_doc_qa.top_k = vector_search_top_k
|
||||||
|
local_doc_qa.chunk_conent = chunk_conent
|
||||||
|
local_doc_qa.chunk_size = chunk_size
|
||||||
for resp, history in local_doc_qa.get_knowledge_based_answer(
|
for resp, history in local_doc_qa.get_knowledge_based_answer(
|
||||||
query=query, vs_path=vs_path, chat_history=history, streaming=streaming):
|
query=query, vs_path=vs_path, chat_history=history, streaming=streaming):
|
||||||
source = "\n\n"
|
source = "\n\n"
|
||||||
@ -85,62 +88,16 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
|
|||||||
yield history + [[query,
|
yield history + [[query,
|
||||||
"请选择知识库后进行测试,当前未选择知识库。"]], ""
|
"请选择知识库后进行测试,当前未选择知识库。"]], ""
|
||||||
else:
|
else:
|
||||||
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
|
answer_result_stream_result = local_doc_qa.llm_model_chain(
|
||||||
streaming=streaming):
|
{"prompt": query, "history": history, "streaming": streaming})
|
||||||
|
|
||||||
|
for answer_result in answer_result_stream_result['answer_result_stream']:
|
||||||
resp = answer_result.llm_output["answer"]
|
resp = answer_result.llm_output["answer"]
|
||||||
history = answer_result.history
|
history = answer_result.history
|
||||||
history[-1][-1] = resp + (
|
history[-1][-1] = resp + (
|
||||||
"\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
|
"\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
|
||||||
yield history, ""
|
yield history, ""
|
||||||
logger.info(f"flagging: username={FLAG_USER_NAME},query={query},vs_path={vs_path},mode={mode},history={history}")
|
logger.info(f"flagging: username={FLAG_USER_NAME},query={query},vs_path={vs_path},mode={mode},history={history}")
|
||||||
# flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)
|
|
||||||
|
|
||||||
|
|
||||||
def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'):
|
|
||||||
local_doc_qa = LocalDocQA()
|
|
||||||
# 初始化消息
|
|
||||||
args = parser.parse_args()
|
|
||||||
args_dict = vars(args)
|
|
||||||
args_dict.update(model=llm_model)
|
|
||||||
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
|
||||||
llm_model_ins = shared.loaderLLM()
|
|
||||||
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
|
|
||||||
|
|
||||||
try:
|
|
||||||
local_doc_qa.init_cfg(llm_model=llm_model_ins,
|
|
||||||
embedding_model=embedding_model)
|
|
||||||
generator = local_doc_qa.llm.generatorAnswer("你好")
|
|
||||||
for answer_result in generator:
|
|
||||||
print(answer_result.llm_output)
|
|
||||||
reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
|
|
||||||
logger.info(reply)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
|
||||||
reply = """模型未成功加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
|
|
||||||
if str(e) == "Unknown platform: darwin":
|
|
||||||
logger.info("该报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI,具体方法请参考项目 README 中本地部署方法及常见问题:"
|
|
||||||
" https://github.com/imClumsyPanda/langchain-ChatGLM")
|
|
||||||
else:
|
|
||||||
logger.info(reply)
|
|
||||||
return local_doc_qa
|
|
||||||
|
|
||||||
|
|
||||||
# 暂未使用到,先保留
|
|
||||||
# def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history):
|
|
||||||
# try:
|
|
||||||
# llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
|
|
||||||
# llm_model_ins.history_len = llm_history_len
|
|
||||||
# local_doc_qa.init_cfg(llm_model=llm_model_ins,
|
|
||||||
# embedding_model=embedding_model,
|
|
||||||
# top_k=top_k)
|
|
||||||
# model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
|
|
||||||
# logger.info(model_status)
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(e)
|
|
||||||
# model_status = """模型未成功重新加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
|
|
||||||
# logger.info(model_status)
|
|
||||||
# return history + [[None, model_status]]
|
|
||||||
|
|
||||||
|
|
||||||
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
|
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
|
||||||
@ -148,7 +105,8 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte
|
|||||||
filelist = []
|
filelist = []
|
||||||
if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")):
|
if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")):
|
||||||
os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content"))
|
os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content"))
|
||||||
if local_doc_qa.llm and local_doc_qa.embeddings:
|
qa = st.session_state.local_doc_qa
|
||||||
|
if qa.llm_model_chain and qa.embeddings:
|
||||||
if isinstance(files, list):
|
if isinstance(files, list):
|
||||||
for file in files:
|
for file in files:
|
||||||
filename = os.path.split(file.name)[-1]
|
filename = os.path.split(file.name)[-1]
|
||||||
@ -156,10 +114,10 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte
|
|||||||
KB_ROOT_PATH, vs_id, "content", filename))
|
KB_ROOT_PATH, vs_id, "content", filename))
|
||||||
filelist.append(os.path.join(
|
filelist.append(os.path.join(
|
||||||
KB_ROOT_PATH, vs_id, "content", filename))
|
KB_ROOT_PATH, vs_id, "content", filename))
|
||||||
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(
|
vs_path, loaded_files = qa.init_knowledge_vector_store(
|
||||||
filelist, vs_path, sentence_size)
|
filelist, vs_path, sentence_size)
|
||||||
else:
|
else:
|
||||||
vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
|
vs_path, loaded_files = qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
|
||||||
sentence_size)
|
sentence_size)
|
||||||
if len(loaded_files):
|
if len(loaded_files):
|
||||||
file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
|
file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
|
||||||
@ -177,10 +135,7 @@ knowledge_base_test_mode_info = ("【注意】\n\n"
|
|||||||
"并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询。\n\n"
|
"并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询。\n\n"
|
||||||
"2. 知识相关度 Score 经测试,建议设置为 500 或更低,具体设置情况请结合实际使用调整。"
|
"2. 知识相关度 Score 经测试,建议设置为 500 或更低,具体设置情况请结合实际使用调整。"
|
||||||
"""3. 使用"添加单条数据"添加文本至知识库时,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。\n\n"""
|
"""3. 使用"添加单条数据"添加文本至知识库时,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。\n\n"""
|
||||||
"4. 单条内容长度建议设置在100-150左右。\n\n"
|
"4. 单条内容长度建议设置在100-150左右。")
|
||||||
"5. 本界面用于知识入库及知识匹配相关参数设定,但当前版本中,"
|
|
||||||
"本界面中修改的参数并不会直接修改对话界面中参数,仍需前往`configs/model_config.py`修改后生效。"
|
|
||||||
"相关参数将在后续版本中支持本界面直接修改。")
|
|
||||||
|
|
||||||
|
|
||||||
webui_title = """
|
webui_title = """
|
||||||
@ -192,7 +147,7 @@ webui_title = """
|
|||||||
|
|
||||||
###### todo #####
|
###### todo #####
|
||||||
# 1. streamlit运行方式与一般web服务器不同,使用模块是无法实现单例模式的,所以shared和local_doc_qa都需要进行全局化处理。
|
# 1. streamlit运行方式与一般web服务器不同,使用模块是无法实现单例模式的,所以shared和local_doc_qa都需要进行全局化处理。
|
||||||
# 目前已经实现了local_doc_qa的全局化,后面要考虑shared。
|
# 目前已经实现了local_doc_qa和shared.loaderCheckPoint的全局化。
|
||||||
# 2. 当前local_doc_qa是一个全局变量,一方面:任何一个session对其做出修改,都会影响所有session的对话;另一方面,如何处理所有session的请求竞争也是问题。
|
# 2. 当前local_doc_qa是一个全局变量,一方面:任何一个session对其做出修改,都会影响所有session的对话;另一方面,如何处理所有session的请求竞争也是问题。
|
||||||
# 这个暂时无法避免,在配置普通的机器上暂时也无需考虑。
|
# 这个暂时无法避免,在配置普通的机器上暂时也无需考虑。
|
||||||
# 3. 目前只包含了get_answer对应的参数,以后可以添加其他参数,如temperature。
|
# 3. 目前只包含了get_answer对应的参数,以后可以添加其他参数,如temperature。
|
||||||
@ -201,25 +156,11 @@ webui_title = """
|
|||||||
|
|
||||||
###### 配置项 #####
|
###### 配置项 #####
|
||||||
class ST_CONFIG:
|
class ST_CONFIG:
|
||||||
user_bg_color = '#77ff77'
|
default_mode = "知识库问答"
|
||||||
user_icon = 'https://tse2-mm.cn.bing.net/th/id/OIP-C.LTTKrxNWDr_k74wz6jKqBgHaHa?w=203&h=203&c=7&r=0&o=5&pid=1.7'
|
default_kb = ""
|
||||||
robot_bg_color = '#ccccee'
|
|
||||||
robot_icon = 'https://ts1.cn.mm.bing.net/th/id/R-C.5302e2cc6f5c7c4933ebb3394e0c41bc?rik=z4u%2b7efba5Mgxw&riu=http%3a%2f%2fcomic-cons.xyz%2fwp-content%2fuploads%2fStar-Wars-avatar-icon-C3PO.png&ehk=kBBvCvpJMHPVpdfpw1GaH%2brbOaIoHjY5Ua9PKcIs%2bAc%3d&risl=&pid=ImgRaw&r=0'
|
|
||||||
default_mode = '知识库问答'
|
|
||||||
defalut_kb = ''
|
|
||||||
###### #####
|
###### #####
|
||||||
|
|
||||||
|
|
||||||
class MsgType:
|
|
||||||
'''
|
|
||||||
目前仅支持文本类型的输入输出,为以后多模态模型预留图像、视频、音频支持。
|
|
||||||
'''
|
|
||||||
TEXT = 1
|
|
||||||
IMAGE = 2
|
|
||||||
VIDEO = 3
|
|
||||||
AUDIO = 4
|
|
||||||
|
|
||||||
|
|
||||||
class TempFile:
|
class TempFile:
|
||||||
'''
|
'''
|
||||||
为保持与get_vector_store的兼容性,需要将streamlit上传文件转化为其可以接受的方式
|
为保持与get_vector_store的兼容性,需要将streamlit上传文件转化为其可以接受的方式
|
||||||
@ -229,132 +170,54 @@ class TempFile:
|
|||||||
self.name = path
|
self.name = path
|
||||||
|
|
||||||
|
|
||||||
def init_session():
|
|
||||||
st.session_state.setdefault('history', [])
|
|
||||||
|
|
||||||
|
|
||||||
# def get_query_params():
|
|
||||||
# '''
|
|
||||||
# 可以用url参数传递配置参数:llm_model, embedding_model, kb, mode。
|
|
||||||
# 该参数将覆盖model_config中的配置。处于安全考虑,目前只支持kb和mode
|
|
||||||
# 方便将固定的配置分享给特定的人。
|
|
||||||
# '''
|
|
||||||
# params = st.experimental_get_query_params()
|
|
||||||
# return {k: v[0] for k, v in params.items() if v}
|
|
||||||
|
|
||||||
|
|
||||||
def robot_say(msg, kb=''):
|
|
||||||
st.session_state['history'].append(
|
|
||||||
{'is_user': False, 'type': MsgType.TEXT, 'content': msg, 'kb': kb})
|
|
||||||
|
|
||||||
|
|
||||||
def user_say(msg):
|
|
||||||
st.session_state['history'].append(
|
|
||||||
{'is_user': True, 'type': MsgType.TEXT, 'content': msg})
|
|
||||||
|
|
||||||
|
|
||||||
def format_md(msg, is_user=False, bg_color='', margin='10%'):
|
|
||||||
'''
|
|
||||||
将文本消息格式化为markdown文本
|
|
||||||
'''
|
|
||||||
if is_user:
|
|
||||||
bg_color = bg_color or ST_CONFIG.user_bg_color
|
|
||||||
text = f'''
|
|
||||||
<div style="background:{bg_color};
|
|
||||||
margin-left:{margin};
|
|
||||||
word-break:break-all;
|
|
||||||
float:right;
|
|
||||||
padding:2%;
|
|
||||||
border-radius:2%;">
|
|
||||||
{msg}
|
|
||||||
</div>
|
|
||||||
'''
|
|
||||||
else:
|
|
||||||
bg_color = bg_color or ST_CONFIG.robot_bg_color
|
|
||||||
text = f'''
|
|
||||||
<div style="background:{bg_color};
|
|
||||||
margin-right:{margin};
|
|
||||||
word-break:break-all;
|
|
||||||
padding:2%;
|
|
||||||
border-radius:2%;">
|
|
||||||
{msg}
|
|
||||||
</div>
|
|
||||||
'''
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def message(msg,
|
|
||||||
is_user=False,
|
|
||||||
msg_type=MsgType.TEXT,
|
|
||||||
icon='',
|
|
||||||
bg_color='',
|
|
||||||
margin='10%',
|
|
||||||
kb='',
|
|
||||||
):
|
|
||||||
'''
|
|
||||||
渲染单条消息。目前仅支持文本
|
|
||||||
'''
|
|
||||||
cols = st.columns([1, 10, 1])
|
|
||||||
empty = cols[1].empty()
|
|
||||||
if is_user:
|
|
||||||
icon = icon or ST_CONFIG.user_icon
|
|
||||||
bg_color = bg_color or ST_CONFIG.user_bg_color
|
|
||||||
cols[2].image(icon, width=40)
|
|
||||||
if msg_type == MsgType.TEXT:
|
|
||||||
text = format_md(msg, is_user, bg_color, margin)
|
|
||||||
empty.markdown(text, unsafe_allow_html=True)
|
|
||||||
else:
|
|
||||||
raise RuntimeError('only support text message now.')
|
|
||||||
else:
|
|
||||||
icon = icon or ST_CONFIG.robot_icon
|
|
||||||
bg_color = bg_color or ST_CONFIG.robot_bg_color
|
|
||||||
cols[0].image(icon, width=40)
|
|
||||||
if kb:
|
|
||||||
cols[0].write(f'({kb})')
|
|
||||||
if msg_type == MsgType.TEXT:
|
|
||||||
text = format_md(msg, is_user, bg_color, margin)
|
|
||||||
empty.markdown(text, unsafe_allow_html=True)
|
|
||||||
else:
|
|
||||||
raise RuntimeError('only support text message now.')
|
|
||||||
return empty
|
|
||||||
|
|
||||||
|
|
||||||
def output_messages(
|
|
||||||
user_bg_color='',
|
|
||||||
robot_bg_color='',
|
|
||||||
user_icon='',
|
|
||||||
robot_icon='',
|
|
||||||
):
|
|
||||||
with chat_box.container():
|
|
||||||
last_response = None
|
|
||||||
for msg in st.session_state['history']:
|
|
||||||
bg_color = user_bg_color if msg['is_user'] else robot_bg_color
|
|
||||||
icon = user_icon if msg['is_user'] else robot_icon
|
|
||||||
empty = message(msg['content'],
|
|
||||||
is_user=msg['is_user'],
|
|
||||||
icon=icon,
|
|
||||||
msg_type=msg['type'],
|
|
||||||
bg_color=bg_color,
|
|
||||||
kb=msg.get('kb', '')
|
|
||||||
)
|
|
||||||
if not msg['is_user']:
|
|
||||||
last_response = empty
|
|
||||||
return last_response
|
|
||||||
|
|
||||||
|
|
||||||
@st.cache_resource(show_spinner=False, max_entries=1)
|
@st.cache_resource(show_spinner=False, max_entries=1)
|
||||||
def load_model(llm_model: str, embedding_model: str):
|
def load_model(
|
||||||
|
llm_model: str = LLM_MODEL,
|
||||||
|
embedding_model: str = EMBEDDING_MODEL,
|
||||||
|
use_ptuning_v2: bool = USE_PTUNING_V2,
|
||||||
|
):
|
||||||
'''
|
'''
|
||||||
对应init_model,利用streamlit cache避免模型重复加载
|
对应init_model,利用streamlit cache避免模型重复加载
|
||||||
'''
|
'''
|
||||||
local_doc_qa = init_model(llm_model, embedding_model)
|
local_doc_qa = LocalDocQA()
|
||||||
robot_say('模型已成功加载,可以开始对话,或从左侧选择模式后开始对话。\n请尽量不要刷新页面,以免模型出错或重复加载。')
|
# 初始化消息
|
||||||
|
args = parser.parse_args()
|
||||||
|
args_dict = vars(args)
|
||||||
|
args_dict.update(model=llm_model)
|
||||||
|
if shared.loaderCheckPoint is None: # avoid checkpoint reloading when reinit model
|
||||||
|
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
||||||
|
# shared.loaderCheckPoint.model_name is different by no_remote_model.
|
||||||
|
# if it is not set properly error occurs when reinit llm model(issue#473).
|
||||||
|
# as no_remote_model is removed from model_config, need workaround to set it automaticlly.
|
||||||
|
local_model_path = llm_model_dict.get(llm_model, {}).get('local_model_path') or ''
|
||||||
|
no_remote_model = os.path.isdir(local_model_path)
|
||||||
|
llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
|
||||||
|
llm_model_ins.history_len = LLM_HISTORY_LEN
|
||||||
|
|
||||||
|
try:
|
||||||
|
local_doc_qa.init_cfg(llm_model=llm_model_ins,
|
||||||
|
embedding_model=embedding_model)
|
||||||
|
answer_result_stream_result = local_doc_qa.llm_model_chain(
|
||||||
|
{"prompt": "你好", "history": [], "streaming": False})
|
||||||
|
|
||||||
|
for answer_result in answer_result_stream_result['answer_result_stream']:
|
||||||
|
print(answer_result.llm_output)
|
||||||
|
reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
|
||||||
|
logger.info(reply)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
reply = """模型未成功加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
|
||||||
|
if str(e) == "Unknown platform: darwin":
|
||||||
|
logger.info("该报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI,具体方法请参考项目 README 中本地部署方法及常见问题:"
|
||||||
|
" https://github.com/imClumsyPanda/langchain-ChatGLM")
|
||||||
|
else:
|
||||||
|
logger.info(reply)
|
||||||
return local_doc_qa
|
return local_doc_qa
|
||||||
|
|
||||||
|
|
||||||
# @st.cache_data
|
# @st.cache_data
|
||||||
def answer(query, vs_path='', history=[], mode='', score_threshold=0,
|
def answer(query, vs_path='', history=[], mode='', score_threshold=0,
|
||||||
vector_search_top_k=5, chunk_conent=True, chunk_size=100, qa=None
|
vector_search_top_k=5, chunk_conent=True, chunk_size=100
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应get_answer,--利用streamlit cache缓存相同问题的答案--
|
对应get_answer,--利用streamlit cache缓存相同问题的答案--
|
||||||
@ -363,48 +226,24 @@ def answer(query, vs_path='', history=[], mode='', score_threshold=0,
|
|||||||
vector_search_top_k, chunk_conent, chunk_size)
|
vector_search_top_k, chunk_conent, chunk_size)
|
||||||
|
|
||||||
|
|
||||||
def load_vector_store(
|
def use_kb_mode(m):
|
||||||
vs_id,
|
return m in ["知识库问答", "知识库测试"]
|
||||||
files,
|
|
||||||
sentence_size=100,
|
|
||||||
history=[],
|
|
||||||
one_conent=None,
|
|
||||||
one_content_segmentation=None,
|
|
||||||
):
|
|
||||||
return get_vector_store(
|
|
||||||
local_doc_qa,
|
|
||||||
vs_id,
|
|
||||||
files,
|
|
||||||
sentence_size,
|
|
||||||
history,
|
|
||||||
one_conent,
|
|
||||||
one_content_segmentation,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# main ui
|
# main ui
|
||||||
st.set_page_config(webui_title, layout='wide')
|
st.set_page_config(webui_title, layout='wide')
|
||||||
init_session()
|
|
||||||
# params = get_query_params()
|
|
||||||
# llm_model = params.get('llm_model', LLM_MODEL)
|
|
||||||
# embedding_model = params.get('embedding_model', EMBEDDING_MODEL)
|
|
||||||
|
|
||||||
with st.spinner(f'正在加载模型({LLM_MODEL} + {EMBEDDING_MODEL}),请耐心等候...'):
|
|
||||||
local_doc_qa = load_model(LLM_MODEL, EMBEDDING_MODEL)
|
|
||||||
|
|
||||||
|
|
||||||
def use_kb_mode(m):
|
|
||||||
return m in ['知识库问答', '知识库测试']
|
|
||||||
|
|
||||||
|
chat_box = st_chatbox(greetings=["模型已成功加载,可以开始对话,或从左侧选择模式后开始对话。"])
|
||||||
|
# 使用 help(st_chatbox) 查看自定义参数
|
||||||
|
|
||||||
# sidebar
|
# sidebar
|
||||||
modes = ['LLM 对话', '知识库问答', 'Bing搜索问答', '知识库测试']
|
modes = ['LLM 对话', '知识库问答', 'Bing搜索问答', '知识库测试']
|
||||||
with st.sidebar:
|
with st.sidebar:
|
||||||
def on_mode_change():
|
def on_mode_change():
|
||||||
m = st.session_state.mode
|
m = st.session_state.mode
|
||||||
robot_say(f'已切换到"{m}"模式')
|
chat_box.robot_say(f'已切换到"{m}"模式')
|
||||||
if m == '知识库测试':
|
if m == '知识库测试':
|
||||||
robot_say(knowledge_base_test_mode_info)
|
chat_box.robot_say(knowledge_base_test_mode_info)
|
||||||
|
|
||||||
index = 0
|
index = 0
|
||||||
try:
|
try:
|
||||||
@ -414,7 +253,7 @@ with st.sidebar:
|
|||||||
mode = st.selectbox('对话模式', modes, index,
|
mode = st.selectbox('对话模式', modes, index,
|
||||||
on_change=on_mode_change, key='mode')
|
on_change=on_mode_change, key='mode')
|
||||||
|
|
||||||
with st.expander('模型配置', '知识' not in mode):
|
with st.expander('模型配置', not use_kb_mode(mode)):
|
||||||
with st.form('model_config'):
|
with st.form('model_config'):
|
||||||
index = 0
|
index = 0
|
||||||
try:
|
try:
|
||||||
@ -423,9 +262,8 @@ with st.sidebar:
|
|||||||
pass
|
pass
|
||||||
llm_model = st.selectbox('LLM模型', llm_model_dict_list, index)
|
llm_model = st.selectbox('LLM模型', llm_model_dict_list, index)
|
||||||
|
|
||||||
no_remote_model = st.checkbox('加载本地模型', False)
|
|
||||||
use_ptuning_v2 = st.checkbox('使用p-tuning-v2微调过的模型', False)
|
use_ptuning_v2 = st.checkbox('使用p-tuning-v2微调过的模型', False)
|
||||||
use_lora = st.checkbox('使用lora微调的权重', False)
|
|
||||||
try:
|
try:
|
||||||
index = embedding_model_dict_list.index(EMBEDDING_MODEL)
|
index = embedding_model_dict_list.index(EMBEDDING_MODEL)
|
||||||
except:
|
except:
|
||||||
@ -435,42 +273,52 @@ with st.sidebar:
|
|||||||
|
|
||||||
btn_load_model = st.form_submit_button('重新加载模型')
|
btn_load_model = st.form_submit_button('重新加载模型')
|
||||||
if btn_load_model:
|
if btn_load_model:
|
||||||
local_doc_qa = load_model(llm_model, embedding_model)
|
local_doc_qa = load_model(llm_model, embedding_model, use_ptuning_v2)
|
||||||
|
|
||||||
if mode in ['知识库问答', '知识库测试']:
|
history_len = st.slider(
|
||||||
|
"LLM对话轮数", 1, 50, LLM_HISTORY_LEN)
|
||||||
|
|
||||||
|
if use_kb_mode(mode):
|
||||||
vs_list = get_vs_list()
|
vs_list = get_vs_list()
|
||||||
vs_list.remove('新建知识库')
|
vs_list.remove('新建知识库')
|
||||||
|
|
||||||
def on_new_kb():
|
def on_new_kb():
|
||||||
name = st.session_state.kb_name
|
name = st.session_state.kb_name
|
||||||
if name in vs_list:
|
if not name:
|
||||||
st.error(f'名为“{name}”的知识库已存在。')
|
st.sidebar.error(f'新建知识库名称不能为空!')
|
||||||
|
elif name in vs_list:
|
||||||
|
st.sidebar.error(f'名为“{name}”的知识库已存在。')
|
||||||
else:
|
else:
|
||||||
vs_list.append(name)
|
|
||||||
st.session_state.vs_path = name
|
st.session_state.vs_path = name
|
||||||
|
st.session_state.kb_name = ''
|
||||||
|
new_kb_dir = os.path.join(KB_ROOT_PATH, name)
|
||||||
|
if not os.path.exists(new_kb_dir):
|
||||||
|
os.makedirs(new_kb_dir)
|
||||||
|
st.sidebar.success(f'名为“{name}”的知识库创建成功,您可以开始添加文件。')
|
||||||
|
|
||||||
def on_vs_change():
|
def on_vs_change():
|
||||||
robot_say(f'已加载知识库: {st.session_state.vs_path}')
|
chat_box.robot_say(f'已加载知识库: {st.session_state.vs_path}')
|
||||||
with st.expander('知识库配置', True):
|
with st.expander('知识库配置', True):
|
||||||
cols = st.columns([12, 10])
|
cols = st.columns([12, 10])
|
||||||
kb_name = cols[0].text_input(
|
kb_name = cols[0].text_input(
|
||||||
'新知识库名称', placeholder='新知识库名称', label_visibility='collapsed')
|
'新知识库名称', placeholder='新知识库名称', label_visibility='collapsed', key='kb_name')
|
||||||
cols[1].button('新建知识库', on_click=on_new_kb)
|
cols[1].button('新建知识库', on_click=on_new_kb)
|
||||||
|
index = 0
|
||||||
|
try:
|
||||||
|
index = vs_list.index(ST_CONFIG.default_kb)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
vs_path = st.selectbox(
|
vs_path = st.selectbox(
|
||||||
'选择知识库', vs_list, on_change=on_vs_change, key='vs_path')
|
'选择知识库', vs_list, index, on_change=on_vs_change, key='vs_path')
|
||||||
|
|
||||||
st.text('')
|
st.text('')
|
||||||
|
|
||||||
score_threshold = st.slider(
|
score_threshold = st.slider(
|
||||||
'知识相关度阈值', 0, 1000, VECTOR_SEARCH_SCORE_THRESHOLD)
|
'知识相关度阈值', 0, 1000, VECTOR_SEARCH_SCORE_THRESHOLD)
|
||||||
top_k = st.slider('向量匹配数量', 1, 20, VECTOR_SEARCH_TOP_K)
|
top_k = st.slider('向量匹配数量', 1, 20, VECTOR_SEARCH_TOP_K)
|
||||||
history_len = st.slider(
|
|
||||||
'LLM对话轮数', 1, 50, LLM_HISTORY_LEN) # 也许要跟知识库分开设置
|
|
||||||
local_doc_qa.llm.set_history_len(history_len)
|
|
||||||
chunk_conent = st.checkbox('启用上下文关联', False)
|
chunk_conent = st.checkbox('启用上下文关联', False)
|
||||||
st.text('')
|
|
||||||
# chunk_conent = st.checkbox('分割文本', True) # 知识库文本分割入库
|
|
||||||
chunk_size = st.slider('上下文关联长度', 1, 1000, CHUNK_SIZE)
|
chunk_size = st.slider('上下文关联长度', 1, 1000, CHUNK_SIZE)
|
||||||
|
st.text('')
|
||||||
sentence_size = st.slider('文本入库分句长度限制', 1, 1000, SENTENCE_SIZE)
|
sentence_size = st.slider('文本入库分句长度限制', 1, 1000, SENTENCE_SIZE)
|
||||||
files = st.file_uploader('上传知识文件',
|
files = st.file_uploader('上传知识文件',
|
||||||
['docx', 'txt', 'md', 'csv', 'xlsx', 'pdf'],
|
['docx', 'txt', 'md', 'csv', 'xlsx', 'pdf'],
|
||||||
@ -483,56 +331,61 @@ with st.sidebar:
|
|||||||
with open(file, 'wb') as fp:
|
with open(file, 'wb') as fp:
|
||||||
fp.write(f.getvalue())
|
fp.write(f.getvalue())
|
||||||
file_list.append(TempFile(file))
|
file_list.append(TempFile(file))
|
||||||
_, _, history = load_vector_store(
|
_, _, history = get_vector_store(
|
||||||
vs_path, file_list, sentence_size, [], None, None)
|
vs_path, file_list, sentence_size, [], None, None)
|
||||||
st.session_state.files = []
|
st.session_state.files = []
|
||||||
|
|
||||||
|
|
||||||
# main body
|
# load model after params rendered
|
||||||
chat_box = st.empty()
|
with st.spinner(f"正在加载模型({llm_model} + {embedding_model}),请耐心等候..."):
|
||||||
|
local_doc_qa = load_model(
|
||||||
|
llm_model,
|
||||||
|
embedding_model,
|
||||||
|
use_ptuning_v2,
|
||||||
|
)
|
||||||
|
local_doc_qa.llm_model_chain.history_len = history_len
|
||||||
|
if use_kb_mode(mode):
|
||||||
|
local_doc_qa.chunk_conent = chunk_conent
|
||||||
|
local_doc_qa.chunk_size = chunk_size
|
||||||
|
# local_doc_qa.llm_model_chain.temperature = temperature # 这样设置temperature似乎不起作用
|
||||||
|
st.session_state.local_doc_qa = local_doc_qa
|
||||||
|
|
||||||
with st.form('my_form', clear_on_submit=True):
|
# input form
|
||||||
|
with st.form("my_form", clear_on_submit=True):
|
||||||
cols = st.columns([8, 1])
|
cols = st.columns([8, 1])
|
||||||
question = cols[0].text_input(
|
question = cols[0].text_area(
|
||||||
'temp', key='input_question', label_visibility='collapsed')
|
'temp', key='input_question', label_visibility='collapsed')
|
||||||
|
|
||||||
def on_send():
|
if cols[1].form_submit_button("发送"):
|
||||||
q = st.session_state.input_question
|
chat_box.user_say(question)
|
||||||
if q:
|
history = []
|
||||||
user_say(q)
|
if mode == "LLM 对话":
|
||||||
|
chat_box.robot_say("正在思考...")
|
||||||
|
chat_box.output_messages()
|
||||||
|
for history, _ in answer(question,
|
||||||
|
history=[],
|
||||||
|
mode=mode):
|
||||||
|
chat_box.update_last_box_text(history[-1][-1])
|
||||||
|
elif use_kb_mode(mode):
|
||||||
|
chat_box.robot_say(f"正在查询 [{vs_path}] ...")
|
||||||
|
chat_box.output_messages()
|
||||||
|
for history, _ in answer(question,
|
||||||
|
vs_path=os.path.join(
|
||||||
|
KB_ROOT_PATH, vs_path, 'vector_store'),
|
||||||
|
history=[],
|
||||||
|
mode=mode,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
vector_search_top_k=top_k,
|
||||||
|
chunk_conent=chunk_conent,
|
||||||
|
chunk_size=chunk_size):
|
||||||
|
chat_box.update_last_box_text(history[-1][-1])
|
||||||
|
else:
|
||||||
|
chat_box.robot_say(f"正在执行Bing搜索...")
|
||||||
|
chat_box.output_messages()
|
||||||
|
for history, _ in answer(question,
|
||||||
|
history=[],
|
||||||
|
mode=mode):
|
||||||
|
chat_box.update_last_box_text(history[-1][-1])
|
||||||
|
|
||||||
if mode == 'LLM 对话':
|
# st.write(chat_box.history)
|
||||||
robot_say('正在思考...')
|
chat_box.output_messages()
|
||||||
last_response = output_messages()
|
|
||||||
for history, _ in answer(q,
|
|
||||||
history=[],
|
|
||||||
mode=mode):
|
|
||||||
last_response.markdown(
|
|
||||||
format_md(history[-1][-1], False),
|
|
||||||
unsafe_allow_html=True
|
|
||||||
)
|
|
||||||
elif use_kb_mode(mode):
|
|
||||||
robot_say('正在思考...', vs_path)
|
|
||||||
last_response = output_messages()
|
|
||||||
for history, _ in answer(q,
|
|
||||||
vs_path=os.path.join(
|
|
||||||
KB_ROOT_PATH, vs_path, "vector_store"),
|
|
||||||
history=[],
|
|
||||||
mode=mode,
|
|
||||||
score_threshold=score_threshold,
|
|
||||||
vector_search_top_k=top_k,
|
|
||||||
chunk_conent=chunk_conent,
|
|
||||||
chunk_size=chunk_size):
|
|
||||||
last_response.markdown(
|
|
||||||
format_md(history[-1][-1], False, 'ligreen'),
|
|
||||||
unsafe_allow_html=True
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
robot_say('正在思考...')
|
|
||||||
last_response = output_messages()
|
|
||||||
st.session_state['history'][-1]['content'] = history[-1][-1]
|
|
||||||
submit = cols[1].form_submit_button('发送', on_click=on_send)
|
|
||||||
|
|
||||||
output_messages()
|
|
||||||
|
|
||||||
# st.write(st.session_state['history'])
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user