merge PR 659: make chat api support streaming

liunux4odoo 2023-07-25 13:16:43 +08:00
commit 18d453cc18
33 changed files with 1160 additions and 868 deletions

.gitignore

@ -174,4 +174,7 @@ embedding/*
pyrightconfig.json
loader/tmp_files
flagged/*
flagged/*
ptuning-v2/*.json
ptuning-v2/*.bin

README.md

@ -23,13 +23,17 @@
🚩 This project does not involve fine-tuning or training itself, but fine-tuning or training can be used to improve its results.
🐳 Docker image: registry.cn-beijing.aliyuncs.com/isafetech/chatmydata:1.0 (thanks to @InkSong) 🌲
💻 Run with: docker run -d -p 80:7860 --gpus all registry.cn-beijing.aliyuncs.com/isafetech/chatmydata:1.0
🌐 [AutoDL image](https://www.codewithgpu.com/i/imClumsyPanda/langchain-ChatGLM/langchain-ChatGLM)
📓 [Run the project online on ModelWhale](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59)
## Changelog
See the [changelog](docs/CHANGELOG.md).
See the [release notes](https://github.com/imClumsyPanda/langchain-ChatGLM/releases).
## Hardware Requirements
@ -60,6 +64,23 @@
The default Embedding model in this project, [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main), uses about 3 GB of GPU memory and can also be configured to run on the CPU.
## Docker All-in-One Image
🐳 Docker image: `registry.cn-beijing.aliyuncs.com/isafetech/chatmydata:1.0` 🌲
💻 Run with a single command:
```shell
docker run -d -p 80:7860 --gpus all registry.cn-beijing.aliyuncs.com/isafetech/chatmydata:1.0
```
- The image is `25.2G` in size, built from [v0.1.16](https://github.com/imClumsyPanda/langchain-ChatGLM/releases/tag/v0.1.16) on top of the `nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04` base image
- It bundles two `embedding` models, `m3e-base` and `text2vec-large-chinese`, as well as `fastchat+chatglm-6b`
- The image aims at easy one-click deployment; make sure the NVIDIA driver is already installed on your Linux distribution
- Note that you do not need the CUDA toolkit on the host, but you do need the `NVIDIA Driver` and the `NVIDIA Container Toolkit`; see the [installation guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
- The first pull and first start both take some time; on first start, follow the screenshot below and use `docker logs -f <container id>` to check the logs
- If startup gets stuck at the `Waiting..` step, use `docker exec -it <container id> bash` to enter the container and inspect the stage logs under `/logs/`
![](img/docker_logs.png)
## Docker Deployment
To let the container use the host GPU, the [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-container-toolkit) must be installed on the host. Installation steps:
```shell
@ -198,12 +219,17 @@ Web UI 可以实现如下功能:
- [ ] Knowledge graph / graph database integration
- [ ] Agent implementation
- [x] Support for more LLM models
- [x] [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
- [x] [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
- [x] [THUDM/chatglm-6b-int8](https://huggingface.co/THUDM/chatglm-6b-int8)
- [x] [THUDM/chatglm-6b-int4](https://huggingface.co/THUDM/chatglm-6b-int4)
- [x] [THUDM/chatglm-6b-int4-qe](https://huggingface.co/THUDM/chatglm-6b-int4-qe)
- [x] [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2)
- [x] [fnlp/moss-moon-003-sft](https://huggingface.co/fnlp/moss-moon-003-sft)
- [x] [bigscience/bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1)
- [x] [bigscience/bloom-3b](https://huggingface.co/bigscience/bloom-3b)
- [x] [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
- [x] [lmsys/vicuna-13b-delta-v1.1](https://huggingface.co/lmsys/vicuna-13b-delta-v1.1)
- [x] Support calling LLMs through the [fastchat](https://github.com/lm-sys/FastChat) API
- [x] Support for more Embedding models
- [x] [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
@ -221,7 +247,7 @@ Web UI 可以实现如下功能:
- [x] Select a knowledge base and start Q&A
- [x] Upload files/folders to a knowledge base
- [x] Knowledge base testing
- [ ] Delete files from a knowledge base
- [x] Delete files from a knowledge base
- [x] Search-engine-based Q&A
- [ ] Add API support
- [x] API deployment with fastapi
@ -229,7 +255,7 @@ Web UI 可以实现如下功能:
- [x] VUE front end
## Project Communication Group
<img src="img/qr_code_33.jpg" alt="QR code" width="300" height="300" />
<img src="img/qr_code_45.jpg" alt="QR code" width="300" height="300" />
🎉 WeChat group for the langchain-ChatGLM project. If you are interested in the project, you are welcome to join the group chat and take part in the discussion.

api.py

@ -1,10 +1,11 @@
#encoding:utf-8
import argparse
import json
import os
import shutil
from typing import List, Optional
import urllib
import asyncio
import nltk
import pydantic
import uvicorn
@ -55,7 +56,7 @@ class ListDocsResponse(BaseResponse):
class ChatMessage(BaseModel):
question: str = pydantic.Field(..., description="Question text")
response: str = pydantic.Field(..., description="Response text")
history: List[List[str]] = pydantic.Field(..., description="History text")
history: List[List[Optional[str]]] = pydantic.Field(..., description="History text")
source_documents: List[str] = pydantic.Field(
..., description="List of source documents and their scores"
)
@ -80,23 +81,37 @@ class ChatMessage(BaseModel):
}
def get_folder_path(local_doc_id: str):
return os.path.join(KB_ROOT_PATH, local_doc_id, "content")
def get_kb_path(local_doc_id: str):
return os.path.join(KB_ROOT_PATH, local_doc_id)
def get_doc_path(local_doc_id: str):
return os.path.join(get_kb_path(local_doc_id), "content")
def get_vs_path(local_doc_id: str):
return os.path.join(KB_ROOT_PATH, local_doc_id, "vector_store")
return os.path.join(get_kb_path(local_doc_id), "vector_store")
def get_file_path(local_doc_id: str, doc_name: str):
return os.path.join(KB_ROOT_PATH, local_doc_id, "content", doc_name)
return os.path.join(get_doc_path(local_doc_id), doc_name)
def validate_kb_name(knowledge_base_id: str) -> bool:
# Check for unexpected characters or path-traversal patterns
if "../" in knowledge_base_id:
return False
return True
async def upload_file(
file: UploadFile = File(description="A single binary file"),
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
):
saved_path = get_folder_path(knowledge_base_id)
if not validate_kb_name(knowledge_base_id):
return BaseResponse(code=403, msg="Don't attack me", data=[])
saved_path = get_doc_path(knowledge_base_id)
if not os.path.exists(saved_path):
os.makedirs(saved_path)
@ -126,21 +141,25 @@ async def upload_files(
],
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
):
saved_path = get_folder_path(knowledge_base_id)
if not validate_kb_name(knowledge_base_id):
return BaseResponse(code=403, msg="Don't attack me", data=[])
saved_path = get_doc_path(knowledge_base_id)
if not os.path.exists(saved_path):
os.makedirs(saved_path)
filelist = []
for file in files:
file_content = ''
file_path = os.path.join(saved_path, file.filename)
file_content = file.file.read()
file_content = await file.read()
if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
continue
with open(file_path, "ab+") as f:
with open(file_path, "wb") as f:
f.write(file_content)
filelist.append(file_path)
if filelist:
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, get_vs_path(knowledge_base_id))
vs_path = get_vs_path(knowledge_base_id)
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path)
if len(loaded_files):
file_status = f"documents {', '.join([os.path.split(i)[-1] for i in loaded_files])} upload success"
return BaseResponse(code=200, msg=file_status)
@ -164,16 +183,24 @@ async def list_kbs():
async def list_docs(
knowledge_base_id: Optional[str] = Query(default=None, description="Knowledge Base Name", example="kb1")
knowledge_base_id: str = Query(..., description="Knowledge Base Name", example="kb1")
):
local_doc_folder = get_folder_path(knowledge_base_id)
if not validate_kb_name(knowledge_base_id):
return ListDocsResponse(code=403, msg="Don't attack me", data=[])
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
kb_path = get_kb_path(knowledge_base_id)
local_doc_folder = get_doc_path(knowledge_base_id)
if not os.path.exists(kb_path):
return ListDocsResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found", data=[])
if not os.path.exists(local_doc_folder):
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
all_doc_names = [
doc
for doc in os.listdir(local_doc_folder)
if os.path.isfile(os.path.join(local_doc_folder, doc))
]
all_doc_names = []
else:
all_doc_names = [
doc
for doc in os.listdir(local_doc_folder)
if os.path.isfile(os.path.join(local_doc_folder, doc))
]
return ListDocsResponse(data=all_doc_names)
@ -182,11 +209,15 @@ async def delete_kb(
description="Knowledge Base Name",
example="kb1"),
):
if not validate_kb_name(knowledge_base_id):
return BaseResponse(code=403, msg="Don't attack me")
# TODO: confirm whether batch deletion of knowledge bases should be supported
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
if not os.path.exists(get_folder_path(knowledge_base_id)):
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
shutil.rmtree(get_folder_path(knowledge_base_id))
kb_path = get_kb_path(knowledge_base_id)
if not os.path.exists(kb_path):
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
shutil.rmtree(kb_path)
return BaseResponse(code=200, msg=f"Knowledge Base {knowledge_base_id} delete success")
@ -195,27 +226,30 @@ async def delete_doc(
description="Knowledge Base Name",
example="kb1"),
doc_name: str = Query(
None, description="doc name", example="doc_name_1.pdf"
..., description="doc name", example="doc_name_1.pdf"
),
):
if not validate_kb_name(knowledge_base_id):
return BaseResponse(code=403, msg="Don't attack me")
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
if not os.path.exists(get_folder_path(knowledge_base_id)):
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
if not os.path.exists(get_kb_path(knowledge_base_id)):
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
doc_path = get_file_path(knowledge_base_id, doc_name)
if os.path.exists(doc_path):
os.remove(doc_path)
remain_docs = await list_docs(knowledge_base_id)
if len(remain_docs.data) == 0:
shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True)
shutil.rmtree(get_kb_path(knowledge_base_id), ignore_errors=True)
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
else:
status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id))
if "success" in status:
return BaseResponse(code=200, msg=f"document {doc_name} delete success")
else:
return BaseResponse(code=1, msg=f"document {doc_name} delete fail")
return BaseResponse(code=500, msg=f"document {doc_name} delete fail")
else:
return BaseResponse(code=1, msg=f"document {doc_name} not found")
return BaseResponse(code=404, msg=f"document {doc_name} not found")
async def update_doc(
@ -223,23 +257,26 @@ async def update_doc(
description="知识库名",
example="kb1"),
old_doc: str = Query(
None, description="待删除文件名,已存储在知识库中", example="doc_name_1.pdf"
..., description="待删除文件名,已存储在知识库中", example="doc_name_1.pdf"
),
new_doc: UploadFile = File(description="待上传文件"),
):
if not validate_kb_name(knowledge_base_id):
return BaseResponse(code=403, msg="Don't attack me")
knowledge_base_id = urllib.parse.unquote(knowledge_base_id)
if not os.path.exists(get_folder_path(knowledge_base_id)):
return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"}
if not os.path.exists(get_kb_path(knowledge_base_id)):
return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
doc_path = get_file_path(knowledge_base_id, old_doc)
if not os.path.exists(doc_path):
return BaseResponse(code=1, msg=f"document {old_doc} not found")
return BaseResponse(code=404, msg=f"document {old_doc} not found")
else:
os.remove(doc_path)
delete_status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id))
if "fail" in delete_status:
return BaseResponse(code=1, msg=f"document {old_doc} delete failed")
return BaseResponse(code=500, msg=f"document {old_doc} delete failed")
else:
saved_path = get_folder_path(knowledge_base_id)
saved_path = get_doc_path(knowledge_base_id)
if not os.path.exists(saved_path):
os.makedirs(saved_path)
@ -267,8 +304,8 @@ async def update_doc(
async def local_doc_chat(
knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
question: str = Body(..., description="Question", example="工伤保险是什么?"),
stream: bool = Body(False, description="是否开启流式输出默认false有些模型可能不支持。"),
history: List[List[str]] = Body(
streaming: bool = Body(False, description="是否开启流式输出默认false有些模型可能不支持。"),
history: List[List[Optional[str]]] = Body(
[],
description="History of previous questions and answers",
example=[
@ -281,7 +318,7 @@ async def local_doc_chat(
):
vs_path = get_vs_path(knowledge_base_id)
if not os.path.exists(vs_path):
# return BaseResponse(code=1, msg=f"Knowledge base {knowledge_base_id} not found")
# return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found")
return ChatMessage(
question=question,
response=f"Knowledge base {knowledge_base_id} not found",
@ -289,7 +326,7 @@ async def local_doc_chat(
source_documents=[],
)
else:
if (stream):
if (streaming):
def generate_answer ():
last_print_len = 0
for resp, next_history in local_doc_qa.get_knowledge_based_answer(
@ -300,7 +337,7 @@ async def local_doc_chat(
return StreamingResponse(generate_answer())
else:
for resp, next_history in local_doc_qa.get_knowledge_based_answer(
for resp, history in local_doc_qa.get_knowledge_based_answer(
query=question, vs_path=vs_path, chat_history=history, streaming=True
):
pass
@ -314,14 +351,14 @@ async def local_doc_chat(
return ChatMessage(
question=question,
response=resp["result"],
history=next_history,
history=history,
source_documents=source_documents,
)
async def bing_search_chat(
question: str = Body(..., description="Question", example="工伤保险是什么?"),
history: Optional[List[List[str]]] = Body(
history: Optional[List[List[Optional[str]]]] = Body(
[],
description="History of previous questions and answers",
example=[
@ -351,8 +388,8 @@ async def bing_search_chat(
async def chat(
question: str = Body(..., description="Question", example="工伤保险是什么?"),
stream: bool = Body(False, description="是否开启流式输出默认false有些模型可能不支持。"),
history: List[List[str]] = Body(
streaming: bool = Body(False, description="是否开启流式输出默认false有些模型可能不支持。"),
history: List[List[Optional[str]]] = Body(
[],
description="History of previous questions and answers",
example=[
@ -363,19 +400,20 @@ async def chat(
],
),
):
if (stream):
if (streaming):
def generate_answer ():
last_print_len = 0
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
streaming=True):
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": question, "history": history, "streaming": True})
for answer_result in answer_result_stream_result['answer_result_stream']:
yield answer_result.llm_output["answer"][last_print_len:]
last_print_len = len(answer_result.llm_output["answer"])
return StreamingResponse(generate_answer())
else:
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history,
streaming=True):
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": question, "history": history, "streaming": True})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
pass
@ -386,9 +424,22 @@ async def chat(
history=history,
source_documents=[],
)
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": question, "history": history, "streaming": True})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
pass
return ChatMessage(
question=question,
response=resp,
history=history,
source_documents=[],
)
async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
async def stream_chat(websocket: WebSocket):
await websocket.accept()
turn = 1
while True:
@ -408,6 +459,7 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
for resp, history in local_doc_qa.get_knowledge_based_answer(
query=question, vs_path=vs_path, chat_history=history, streaming=True
):
await asyncio.sleep(0)
await websocket.send_text(resp["result"][last_print_len:])
last_print_len = len(resp["result"])
@ -430,17 +482,51 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
)
turn += 1
async def stream_chat_bing(websocket: WebSocket):
"""
Streaming Q&A based on Bing search
"""
await websocket.accept()
turn = 1
while True:
input_json = await websocket.receive_json()
question, history = input_json["question"], input_json["history"]
await websocket.send_json({"question": question, "turn": turn, "flag": "start"})
last_print_len = 0
for resp, history in local_doc_qa.get_search_result_based_answer(question, chat_history=history, streaming=True):
await websocket.send_text(resp["result"][last_print_len:])
last_print_len = len(resp["result"])
source_documents = [
f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}\n\n{doc.page_content}\n\n"""
f"""相关度:{doc.metadata['score']}\n\n"""
for inum, doc in enumerate(resp["source_documents"])
]
await websocket.send_text(
json.dumps(
{
"question": question,
"turn": turn,
"flag": "end",
"sources_documents": source_documents,
},
ensure_ascii=False,
)
)
turn += 1
async def document():
return RedirectResponse(url="/docs")
def api_start(host, port):
def api_start(host, port, **kwargs):
global app
global local_doc_qa
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
app = FastAPI()
# Add CORS middleware to allow all origins
@ -454,21 +540,28 @@ def api_start(host, port):
allow_methods=["*"],
allow_headers=["*"],
)
app.websocket("/local_doc_qa/stream-chat/{knowledge_base_id}")(stream_chat)
# The stream_chat endpoint was changed: connect directly to ws://localhost:7861/local_doc_qa/stream_chat and pass knowledge_base_id in the request body
app.websocket("/local_doc_qa/stream_chat")(stream_chat)
app.get("/", response_model=BaseResponse)(document)
app.get("/", response_model=BaseResponse, summary="swagger 文档")(document)
app.post("/chat", response_model=ChatMessage)(chat)
# Add streaming Q&A based on Bing search.
# Note: to test streaming Q&A over websocket you need a client that supports websockets, such as Postman or Insomnia.
# The open-source Insomnia is strongly recommended.
# When testing, create a new websocket request and change the URL scheme to ws, e.g. ws://localhost:7861/local_doc_qa/stream_chat_bing
app.websocket("/local_doc_qa/stream_chat_bing")(stream_chat_bing)
app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file)
app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage)(bing_search_chat)
app.get("/local_doc_qa/list_knowledge_base", response_model=ListDocsResponse)(list_kbs)
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse)(delete_kb)
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_doc)
app.post("/local_doc_qa/update_file", response_model=BaseResponse)(update_doc)
app.post("/chat", response_model=ChatMessage, summary="与模型对话")(chat)
app.post("/local_doc_qa/upload_file", response_model=BaseResponse, summary="上传文件到知识库")(upload_file)
app.post("/local_doc_qa/upload_files", response_model=BaseResponse, summary="批量上传文件到知识库")(upload_files)
app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage, summary="与知识库对话")(local_doc_chat)
app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage, summary="与必应搜索对话")(bing_search_chat)
app.get("/local_doc_qa/list_knowledge_base", response_model=ListDocsResponse, summary="获取知识库列表")(list_kbs)
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse, summary="获取知识库内的文件列表")(list_docs)
app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse, summary="删除知识库")(delete_kb)
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse, summary="删除知识库内的文件")(delete_doc)
app.post("/local_doc_qa/update_file", response_model=BaseResponse, summary="上传文件到知识库,并删除另一个文件")(update_doc)
local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(
@ -477,15 +570,21 @@ def api_start(host, port):
embedding_device=EMBEDDING_DEVICE,
top_k=VECTOR_SEARCH_TOP_K,
)
uvicorn.run(app, host=host, port=port)
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
uvicorn.run(app, host=host, port=port, ssl_keyfile=kwargs.get("ssl_keyfile"),
ssl_certfile=kwargs.get("ssl_certfile"))
else:
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7861)
parser.add_argument("--ssl_keyfile", type=str)
parser.add_argument("--ssl_certfile", type=str)
# Initialization
args = None
args = parser.parse_args()
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
api_start(args.host, args.port)
api_start(args.host, args.port, ssl_keyfile=args.ssl_keyfile, ssl_certfile=args.ssl_certfile)
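With this change, `/chat` and `/local_doc_qa/local_doc_chat` return a `StreamingResponse` of incremental answer text when `streaming` is true, and `/local_doc_qa/stream_chat` takes its parameters in the websocket request body. A minimal client sketch, assuming the server runs on localhost:7861; the `websockets` package and the `kb1` knowledge base name are illustrative assumptions, not part of this PR:
```python
import asyncio
import json

import requests
import websockets  # assumed client-side dependency


def stream_chat_http():
    # POST /chat with streaming=True; the response body streams incremental answer text
    resp = requests.post(
        "http://localhost:7861/chat",
        json={"question": "你好", "history": [], "streaming": True},
        stream=True,
    )
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)


async def stream_chat_ws():
    # ws://.../local_doc_qa/stream_chat: send question, history and knowledge_base_id
    # in one JSON frame, then read text deltas until the JSON frame with flag == "end"
    async with websockets.connect("ws://localhost:7861/local_doc_qa/stream_chat") as ws:
        await ws.send(json.dumps({
            "question": "工伤保险是什么?",
            "history": [],
            "knowledge_base_id": "kb1",  # hypothetical knowledge base name
        }))
        while True:
            frame = await ws.recv()
            try:
                meta = json.loads(frame)
                if isinstance(meta, dict) and meta.get("flag") == "end":
                    break
                if isinstance(meta, dict) and meta.get("flag") == "start":
                    continue
            except json.JSONDecodeError:
                pass
            print(frame, end="", flush=True)  # incremental answer text


if __name__ == "__main__":
    stream_chat_http()
    asyncio.run(stream_chat_ws())
```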

chains/local_doc_qa.py

@ -8,7 +8,6 @@ from typing import List
from utils import torch_gc
from tqdm import tqdm
from pypinyin import lazy_pinyin
from loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader
from models.base import (BaseAnswer,
AnswerResult)
from models.loader.args import parser
@ -18,6 +17,7 @@ from agent import bing_search
from langchain.docstore.document import Document
from functools import lru_cache
from textsplitter.zh_title_enhance import zh_title_enhance
from langchain.chains.base import Chain
# patch HuggingFaceEmbeddings to make it hashable
@ -58,6 +58,7 @@ def tree(filepath, ignore_dir_names=None, ignore_file_names=None):
def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE):
if filepath.lower().endswith(".md"):
loader = UnstructuredFileLoader(filepath, mode="elements")
docs = loader.load()
@ -66,10 +67,14 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_T
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
docs = loader.load_and_split(textsplitter)
elif filepath.lower().endswith(".pdf"):
# Load the paddle-based loaders lazily for now, so protobuf 4.x can be used as long as no pdf/image knowledge files are uploaded
from loader import UnstructuredPaddlePDFLoader
loader = UnstructuredPaddlePDFLoader(filepath)
textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
docs = loader.load_and_split(textsplitter)
elif filepath.lower().endswith(".jpg") or filepath.lower().endswith(".png"):
# Load the paddle-based loaders lazily for now, so protobuf 4.x can be used as long as no pdf/image knowledge files are uploaded
from loader import UnstructuredPaddleImageLoader
loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
docs = loader.load_and_split(text_splitter=textsplitter)
@ -119,7 +124,7 @@ def search_result2docs(search_results):
class LocalDocQA:
llm: BaseAnswer = None
llm_model_chain: Chain = None
embeddings: object = None
top_k: int = VECTOR_SEARCH_TOP_K
chunk_size: int = CHUNK_SIZE
@ -129,10 +134,10 @@ class LocalDocQA:
def init_cfg(self,
embedding_model: str = EMBEDDING_MODEL,
embedding_device=EMBEDDING_DEVICE,
llm_model: BaseAnswer = None,
llm_model: Chain = None,
top_k=VECTOR_SEARCH_TOP_K,
):
self.llm = llm_model
self.llm_model_chain = llm_model
self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
model_kwargs={'device': embedding_device})
self.top_k = top_k
@ -200,6 +205,7 @@ class LocalDocQA:
return vs_path, loaded_files
else:
logger.info("文件均未成功加载,请检查依赖包或替换为其他文件再次上传。")
return None, loaded_files
def one_knowledge_add(self, vs_path, one_title, one_conent, one_content_segmentation, sentence_size):
@ -235,8 +241,10 @@ class LocalDocQA:
else:
prompt = query
for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
streaming=streaming):
answer_result_stream_result = self.llm_model_chain(
{"prompt": prompt, "history": chat_history, "streaming": streaming})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][0] = query
@ -275,8 +283,10 @@ class LocalDocQA:
result_docs = search_result2docs(results)
prompt = generate_prompt(result_docs, query)
for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
streaming=streaming):
answer_result_stream_result = self.llm_model_chain(
{"prompt": prompt, "history": chat_history, "streaming": streaming})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][0] = query
@ -295,7 +305,7 @@ class LocalDocQA:
def update_file_from_vector_store(self,
filepath: str or List[str],
vs_path,
docs: List[Document],):
docs: List[Document], ):
vector_store = load_vector_store(vs_path, self.embeddings)
status = vector_store.update_doc(filepath, docs)
return status
@ -319,7 +329,6 @@ if __name__ == "__main__":
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(llm_model=llm_model_ins)
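After this refactor `LocalDocQA` holds an `llm_model_chain` (a LangChain `Chain`) instead of a `BaseAnswer` LLM, while its own interface stays generator-based. A minimal usage sketch, assuming the module paths of this repository and an existing vector store directory (the `vs_path` value is illustrative):
```python
from chains.local_doc_qa import LocalDocQA
from models.loader import LoaderCheckPoint
from models.loader.args import parser
import models.shared as shared

args = parser.parse_args()
shared.loaderCheckPoint = LoaderCheckPoint(vars(args))
llm_chain = shared.loaderLLM()  # now returns a Chain, not an LLM wrapper

local_doc_qa = LocalDocQA()
local_doc_qa.init_cfg(llm_model=llm_chain)

vs_path = "knowledge_base/samples/vector_store"  # assumed existing vector store
last_len = 0
for resp, history in local_doc_qa.get_knowledge_based_answer(
        query="工伤保险是什么?", vs_path=vs_path, chat_history=[], streaming=True):
    print(resp["result"][last_len:], end="", flush=True)
    last_len = len(resp["result"])
```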


@ -1,34 +0,0 @@
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from typing import Any, List
class MyEmbeddings(HuggingFaceEmbeddings):
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
texts = list(map(lambda x: x.replace("\n", " "), texts))
embeddings = self.client.encode(texts, normalize_embeddings=True)
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a HuggingFace transformer model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
text = text.replace("\n", " ")
embedding = self.client.encode(text, normalize_embeddings=True)
return embedding.tolist()


@ -1,121 +0,0 @@
from langchain.vectorstores import FAISS
from typing import Any, Callable, List, Optional, Tuple, Dict
from langchain.docstore.document import Document
from langchain.docstore.base import Docstore
from langchain.vectorstores.utils import maximal_marginal_relevance
from langchain.embeddings.base import Embeddings
import uuid
from langchain.docstore.in_memory import InMemoryDocstore
import numpy as np
def dependable_faiss_import() -> Any:
"""Import faiss if available, otherwise raise error."""
try:
import faiss
except ImportError:
raise ValueError(
"Could not import faiss python package. "
"Please install it with `pip install faiss` "
"or `pip install faiss-cpu` (depending on Python version)."
)
return faiss
class FAISSVS(FAISS):
def __init__(self,
embedding_function: Callable[..., Any],
index: Any,
docstore: Docstore,
index_to_docstore_id: Dict[int, str]):
super().__init__(embedding_function, index, docstore, index_to_docstore_id)
def max_marginal_relevance_search_by_vector(
self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
Returns:
List of Documents with scores selected by maximal marginal relevance.
"""
scores, indices = self.index.search(np.array([embedding], dtype=np.float32), fetch_k)
# -1 happens when not enough docs are returned.
embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
mmr_selected = maximal_marginal_relevance(
np.array([embedding], dtype=np.float32), embeddings, k=k
)
selected_indices = [indices[0][i] for i in mmr_selected]
selected_scores = [scores[0][i] for i in mmr_selected]
docs = []
for i, score in zip(selected_indices, selected_scores):
if i == -1:
# This happens when not enough docs are returned.
continue
_id = self.index_to_docstore_id[i]
doc = self.docstore.search(_id)
if not isinstance(doc, Document):
raise ValueError(f"Could not find document for id {_id}, got {doc}")
docs.append((doc, score))
return docs
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
Returns:
List of Documents with scores selected by maximal marginal relevance.
"""
embedding = self.embedding_function(query)
docs = self.max_marginal_relevance_search_by_vector(embedding, k, fetch_k)
return docs
@classmethod
def __from(
cls,
texts: List[str],
embeddings: List[List[float]],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> FAISS:
faiss = dependable_faiss_import()
index = faiss.IndexFlatIP(len(embeddings[0]))
index.add(np.array(embeddings, dtype=np.float32))
# # my code, for speeding up search
# quantizer = faiss.IndexFlatL2(len(embeddings[0]))
# index = faiss.IndexIVFFlat(quantizer, len(embeddings[0]), 100)
# index.train(np.array(embeddings, dtype=np.float32))
# index.add(np.array(embeddings, dtype=np.float32))
documents = []
for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
documents.append(Document(page_content=text, metadata=metadata))
index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
docstore = InMemoryDocstore(
{index_to_id[i]: doc for i, doc in enumerate(documents)}
)
return cls(embedding.embed_query, index, docstore, index_to_id)

cli.py

@ -42,7 +42,9 @@ def start():
@start.command(name="api", context_settings=dict(help_option_names=['-h', '--help']))
@click.option('-i', '--ip', default='0.0.0.0', show_default=True, type=str, help='api_server listen address.')
@click.option('-p', '--port', default=7861, show_default=True, type=int, help='api_server listen port.')
def start_api(ip, port):
@click.option('-k', '--ssl_keyfile', type=str, help='enable api https/wss service, specify the ssl keyfile path.')
@click.option('-c', '--ssl_certfile', type=str, help='enable api https/wss service, specify the ssl certificate file path.')
def start_api(ip, port, **kwargs):
# LoaderCheckPoint must be created (loadCheckPoint) before api_start is called, with the checkpoint-loading arguments passed in.
# In theory this could be wrapped with the click package, but that would be cumbersome and require larger changes,
# so the parser package is still used here, with models.loader.args.DEFAULT_ARGS as the default arguments.
@ -51,7 +53,7 @@ def start_api(ip, port):
from models.loader import LoaderCheckPoint
from models.loader.args import DEFAULT_ARGS
shared.loaderCheckPoint = LoaderCheckPoint(DEFAULT_ARGS)
api_start(host=ip, port=port)
api_start(host=ip, port=port, **kwargs)
# # When cli_demo is invoked through cli.py, the model must be initialized inside cli.py, otherwise it fails with:
# langchain-ChatGLM: error: unrecognized arguments: start cli

cli_demo.py

@ -23,11 +23,33 @@ def main():
top_k=VECTOR_SEARCH_TOP_K)
vs_path = None
while not vs_path:
print("注意输入的路径是完整的文件路径例如knowledge_base/`knowledge_base_id`/content/file.md多个路径用英文逗号分割")
filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
# If filepath is empty, prompt the user again, so an accidental Enter key press does not break the flow
if not filepath:
continue
vs_path, _ = local_doc_qa.init_knowledge_vector_store(filepath)
# Support loading multiple files
filepath = filepath.split(",")
# When filepath is wrong, init_knowledge_vector_store returns None. Unpacking that directly, as the old
# vs_path, _ = local_doc_qa.init_knowledge_vector_store(filepath) did, raises
# TypeError: cannot unpack non-iterable NoneType object and exits the program, so check the result first.
temp,loaded_files = local_doc_qa.init_knowledge_vector_store(filepath)
if temp is not None:
vs_path = temp
# If loaded_files does not match len(filepath), some files failed to load;
# if that was caused by a wrong path, reloading should be supported
if len(loaded_files) != len(filepath):
reload_flag = eval(input("部分文件加载失败若提示路径不存在可重新加载是否重新加载输入True或False: "))
if reload_flag:
vs_path = None
continue
print(f"the loaded vs_path is 加载的vs_path为: {vs_path}")
else:
print("load file failed, re-input your local knowledge file path 请重新输入本地知识文件路径")
history = []
while True:
query = input("Input your question 请输入问题:")

docs/FAQ.md

@ -177,3 +177,22 @@ download_with_progressbar(url, tmp_path)
Q14: Calling the `bing_search_chat` API endpoint raises `Failed to establish a new connection: [Errno 110] Connection timed out`
This happens because the server sits behind a firewall; ask the administrator to add the address to the whitelist (if it is a corporate server, don't get your hopes up).
---
Q15: Loading chatglm-6b-int8 or chatglm-6b-int4 throws `RuntimeError: Only Tensors of floating point and complex dtype can require gradients`
This looks like an issue in chatglm's quantization code or a torch version mismatch: Parameter() is applied again to a torch.zeros matrix that is already a Parameter, which raises the error above. The workaround is to change line 374 of quantization.py in the original chatglm files to:
```python
try:
    self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
except Exception as e:
    pass
```
If that does not help, apply the same change to quantization.py under the .cache/huggingface/modules/ directory for the chatglm files (if there is more than one symlink, pick the correct path based on the error message).
Even when the model then loads, inference on the CPU may still fail, with the model endlessly outputting gibberish ("gugugugu") for every question.
So it is best not to load quantized models on the CPU at all; the likely cause is that mainstream Python quantization packages run their quantization on the GPU, so a gap naturally exists.

docs/INSTALL.md

@ -44,4 +44,12 @@ $ pip install -r requirements.txt
$ python loader/image_loader.py
```
Note: when using `langchain.document_loaders.UnstructuredFileLoader` to ingest unstructured files, extra dependencies may need to be installed depending on the document type; see the [langchain documentation](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html).
## Notes on running llama-cpp models
1. First download the model from the Hugging Face hub, e.g. [ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin) from [https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/); downloading with `snapshot_download` from the huggingface_hub library is recommended (see the sketch after this list).
2. Rename the downloaded model. Models downloaded through huggingface_hub end up with a random file name, so rename the file back to its original name, e.g. [ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin).
3. From the ggml vintage of the downloaded model, work out the matching llama-cpp version and download the corresponding llama-cpp-python wheel; [ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin) was verified to be compatible with the llama-cpp-python library. Then install the wheel manually.
4. Add the downloaded model's information to `llm_model_dict` in configs/model_config.py, taking care that the parameters stay compatible; some parameter combinations may raise errors.
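A possible sketch of steps 1–2 with huggingface_hub (the `models/` target directory and the `allow_patterns` filter are assumptions; adjust them to your own layout):
```python
import os
import shutil

from huggingface_hub import snapshot_download

# Step 1: download the repository snapshot (cached under ~/.cache/huggingface by default);
# allow_patterns restricts the download to the single ggml file we need
snapshot_dir = snapshot_download(
    repo_id="vicuna/ggml-vicuna-13b-1.1",
    allow_patterns=["ggml-vic13b-q5_1.bin"],
)

# Step 2: copy the file back out under its original name into a hypothetical models directory
os.makedirs("models", exist_ok=True)
shutil.copy(os.path.join(snapshot_dir, "ggml-vic13b-q5_1.bin"),
            os.path.join("models", "ggml-vic13b-q5_1.bin"))
```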

docs/启动API服务.md

@ -0,0 +1,37 @@
# Starting the API Service
## Starting from the Python file
The API service can be started by running `api.py` directly; by default it serves HTTP and WS on ip 0.0.0.0 and port 7861.
```shell
python api.py
```
Startup also accepts the model-loading parameters listed in StartOption, as well as IP and port settings.
```shell
python api.py --model-name chatglm-6b-int8 --port 7862
```
## Starting via cli.bat / cli.sh
The service can also be started through the command-line control scripts.
```shell
cli.sh api --help
```
The other configurable parameters are the same as when starting from the Python file above.
# Starting the API Service over HTTPS / WSS
## Creating local SSL certificate files
If you do not have an officially issued CA certificate, you can [install mkcert](https://github.com/FiloSottile/mkcert#installation) and generate a local CA certificate with:
```shell
mkcert -install
mkcert api.example.com 47.123.123.123 localhost 127.0.0.1 ::1
```
Accept the defaults with Enter; two .pem files, prefixed with the first domain name from the command, are saved in the current directory.
Then start the service with both files passed as parameters:
```shell
python api.py --port 7862 --ssl_keyfile api.example.com+4-key.pem --ssl_certfile api.example.com+4.pem
./cli.sh api --port 7862 --ssl_keyfile api.example.com+4-key.pem --ssl_certfile api.example.com+4.pem
```
A similar effect can also be achieved by placing Nginx in front as a reverse proxy; see the relevant documentation for details.
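A quick smoke test of the HTTPS service might be a curl call to the streaming `/chat` endpoint added in this PR (the host, port, and the `-k` flag for the locally trusted mkcert certificate are assumptions; `-N` disables curl's output buffering so the streamed answer appears as it arrives):
```shell
curl -k -N -X POST "https://api.example.com:7862/chat" \
  -H "Content-Type: application/json" \
  -d '{"question": "你好", "history": [], "streaming": true}'
```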

img/docker_logs.png (new binary file, 69 KiB)
(two binary image files removed, 143 KiB and 154 KiB)
img/qr_code_45.jpg (new binary file, 192 KiB)

loader/image_loader.py

@ -5,9 +5,6 @@ from langchain.document_loaders.unstructured import UnstructuredFileLoader
from paddleocr import PaddleOCR
import os
import nltk
from configs.model_config import NLTK_DATA_PATH
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
class UnstructuredPaddleImageLoader(UnstructuredFileLoader):
"""Loader that uses unstructured to load image files, such as PNGs and JPGs."""
@ -35,6 +32,10 @@ class UnstructuredPaddleImageLoader(UnstructuredFileLoader):
if __name__ == "__main__":
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import NLTK_DATA_PATH
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base", "samples", "content", "test.jpg")
loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
docs = loader.load()

models/__init__.py

@ -1,4 +1,4 @@
from .chatglm_llm import ChatGLM
from .llama_llm import LLamaLLM
from .moss_llm import MOSSLLM
from .fastchat_openai_llm import FastChatOpenAILLM
from .chatglm_llm import ChatGLMLLMChain
from .llama_llm import LLamaLLMChain
from .fastchat_openai_llm import FastChatOpenAILLMChain
from .moss_llm import MOSSLLMChain

models/base/__init__.py

@ -1,13 +1,15 @@
from models.base.base import (
AnswerResult,
BaseAnswer
)
BaseAnswer,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
from models.base.remote_rpc_model import (
RemoteRpcModel
)
__all__ = [
"AnswerResult",
"BaseAnswer",
"RemoteRpcModel",
"AnswerResultStream",
"AnswerResultQueueSentinelTokenListenerQueue"
]

models/base/base.py

@ -1,16 +1,30 @@
from abc import ABC, abstractmethod
from typing import Optional, List
from typing import Any, Dict, List, Optional, Generator
import traceback
from collections import deque
from queue import Queue
from threading import Thread
from langchain.callbacks.manager import CallbackManagerForChainRun
from models.loader import LoaderCheckPoint
from pydantic import BaseModel
import torch
import transformers
from models.loader import LoaderCheckPoint
class AnswerResult:
class ListenerToken:
"""
Observed result
"""
input_ids: torch.LongTensor
_scores: torch.FloatTensor
def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor):
self.input_ids = input_ids
self._scores = _scores
class AnswerResult(BaseModel):
"""
Message entity
"""
@ -18,6 +32,122 @@ class AnswerResult:
llm_output: Optional[dict] = None
class AnswerResultStream:
def __init__(self, callback_func=None):
self.callback_func = callback_func
def __call__(self, answerResult: AnswerResult):
if self.callback_func is not None:
self.callback_func(answerResult)
class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria):
"""
A stopping_criteria listener for the model; on every generation step it syncs the queued data into AnswerResult.
The point of this listener is that different models may not expose their prediction output as vectors; the HF framework lets a custom transformers.StoppingCriteria argument receive the Tensor and scores of every prediction step.
A StoppingCriteriaList specifies the conditions under which the model stops generating an answer; each StoppingCriteria object represents one stopping condition.
At the start of every prediction step each StoppingCriteria receives the same prediction result, and the concrete subclass ultimately decides whether to stop.
The captured values can be observed through custom parameters of generatorAnswer / generate_with_streaming to allow finer-grained control.
"""
listenerQueue: deque = deque(maxlen=1)
def __init__(self):
transformers.StoppingCriteria.__init__(self)
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool:
"""
Append the data to the listener queue on every generation step
:param input_ids:
:param _scores:
:param kwargs:
:return:
"""
self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores))
return False
class Iteratorize:
"""
Transforms a function that takes a callback
into a lazy iterator (generator).
"""
def __init__(self, func, kwargs={}):
self.mfunc = func
self.q = Queue()
self.sentinel = object()
self.kwargs = kwargs
self.stop_now = False
def _callback(val):
"""
Collects the model's prediction output.
A generate_with_callback collector (AnswerResultStream) gathers the AnswerResult responses produced by the model; the concrete subclass ultimately decides when generation ends.
Generation ends when one of the following happens:
1. The model finishes predicting and the collector queue self.q receives the self.sentinel marker.
2. A break occurs while consuming the iterator queue, triggering a StopIteration event.
3. The model prediction raises an error.
Because this class is an iterator, once a break is executed inside the for loop the __exit__ method is called, the stop_now attribute is updated, and an exception is raised to end prediction.
The collection flow is as follows:
Create the Iteratorize iterator object.
Define the generate_with_callback collector (AnswerResultStream).
Start a thread that asynchronously calls the upstream checkpoint implementation _generate_answer.
_generate_answer uses the collector defined by generate_with_callback to gather the AnswerResult messages wrapped by the upstream checkpoint.
Because self.q blocks, each prediction must be consumed before the next one runs,
at which point generate_with_callback is blocked.
The main thread calls the Iteratorize object's __next__ method to fetch and consume the blocking message:
1. If the message is an AnswerResult wrapped by the upstream checkpoint, it is passed downstream for processing.
2. If the message is the self.sentinel marker, StopIteration is raised.
When the main thread's Iteratorize __exit__ is reached, the stop_now attribute is updated;
the asynchronous thread detects the updated stop_now, raises an exception, and ends prediction.
The iteration is then finished.
:param val:
:return:
"""
if self.stop_now:
raise ValueError
self.q.put(val)
def gen():
try:
ret = self.mfunc(callback=_callback, **self.kwargs)
except ValueError:
pass
except:
traceback.print_exc()
pass
self.q.put(self.sentinel)
self.thread = Thread(target=gen)
self.thread.start()
def __iter__(self):
return self
def __next__(self):
obj = self.q.get(True, None)
if obj is self.sentinel:
raise StopIteration
else:
return obj
def __del__(self):
"""
Not implemented yet
:return:
"""
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
""" break 后会执行 """
self.stop_now = True
class BaseAnswer(ABC):
"""上层业务包装器.用于结果生成统一api调用"""
@ -25,17 +155,23 @@ class BaseAnswer(ABC):
@abstractmethod
def _check_point(self) -> LoaderCheckPoint:
"""Return _check_point of llm."""
def generatorAnswer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,) -> Generator[Any, str, bool]:
def generate_with_callback(callback=None, **kwargs):
kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback)
self._generate_answer(**kwargs)
@property
@abstractmethod
def _history_len(self) -> int:
"""Return _history_len of llm."""
def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs)
with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator:
for answerResult in generator:
yield answerResult
@abstractmethod
def set_history_len(self, history_len: int) -> None:
"""Return _history_len of llm."""
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
pass
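A minimal, self-contained sketch of the callback-to-iterator pattern that `Iteratorize` implements; the `fake_generate` producer below is hypothetical, standing in for the `_generate_answer` implementations of the concrete chains:
```python
from models.base.base import Iteratorize

# Hypothetical producer: pushes partial results through a callback, like _generate_answer does
def fake_generate(callback=None, prompt=""):
    for partial in [prompt, prompt + "!", prompt + "!!"]:
        callback(partial)

# Iteratorize runs the producer in a background thread and exposes a blocking iterator
with Iteratorize(fake_generate, kwargs={"prompt": "hello"}) as gen:
    for partial in gen:
        print(partial)  # hello / hello! / hello!!
        # a `break` here would set stop_now via __exit__ and stop the producer thread
```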

models/chatglm_llm.py

@ -1,83 +1,117 @@
from abc import ABC
from langchain.llms.base import LLM
from typing import Optional, List
from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Generator
from langchain.callbacks.manager import CallbackManagerForChainRun
# from transformers.generation.logits_process import LogitsProcessor
# from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
AnswerResult)
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
# import torch
import transformers
class ChatGLM(BaseAnswer, LLM, ABC):
class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
max_token: int = 10000
temperature: float = 0.01
top_p = 0.9
# relevance
top_p = 0.4
# number of candidate tokens
top_k = 10
checkPoint: LoaderCheckPoint = None
# history = []
history_len: int = 10
streaming_key: str = "streaming" #: :meta private:
history_key: str = "history" #: :meta private:
prompt_key: str = "prompt" #: :meta private:
output_key: str = "answer_result_stream" #: :meta private:
def __init__(self, checkPoint: LoaderCheckPoint = None):
super().__init__()
self.checkPoint = checkPoint
@property
def _llm_type(self) -> str:
return "ChatGLM"
def _chain_type(self) -> str:
return "ChatGLMLLMChain"
@property
def _check_point(self) -> LoaderCheckPoint:
return self.checkPoint
@property
def _history_len(self) -> int:
return self.history_len
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
def set_history_len(self, history_len: int = 10) -> None:
self.history_len = history_len
:meta private:
"""
return [self.prompt_key]
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Generator]:
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
return {self.output_key: generator}
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
history = inputs[self.history_key]
streaming = inputs[self.streaming_key]
prompt = inputs[self.prompt_key]
print(f"__call:{prompt}")
response, _ = self.checkPoint.model.chat(
self.checkPoint.tokenizer,
prompt,
history=[],
max_length=self.max_token,
temperature=self.temperature
)
print(f"response:{response}")
print(f"+++++++++++++++++++++++++++++++++++")
return response
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
# Create the StoppingCriteriaList with the stopping strings
stopping_criteria_list = transformers.StoppingCriteriaList()
# Define the model's stopping_criteria listener queue; every step syncs torch.LongTensor / torch.FloatTensor into AnswerResult
listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
stopping_criteria_list.append(listenerQueue)
if streaming:
history += [[]]
for inum, (stream_resp, _) in enumerate(self.checkPoint.model.stream_chat(
self.checkPoint.tokenizer,
prompt,
history=history[-self.history_len:-1] if self.history_len > 1 else [],
history=history[-self.history_len:-1] if self.history_len > 0 else [],
max_length=self.max_token,
temperature=self.temperature
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
stopping_criteria=stopping_criteria_list
)):
# self.checkPoint.clear_torch_cache()
history[-1] = [prompt, stream_resp]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": stream_resp}
yield answer_result
generate_with_callback(answer_result)
self.checkPoint.clear_torch_cache()
else:
response, _ = self.checkPoint.model.chat(
self.checkPoint.tokenizer,
prompt,
history=history[-self.history_len:] if self.history_len > 0 else [],
max_length=self.max_token,
temperature=self.temperature
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
stopping_criteria=stopping_criteria_list
)
self.checkPoint.clear_torch_cache()
history += [[prompt, response]]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": response}
yield answer_result
generate_with_callback(answer_result)

models/fastchat_openai_llm.py

@ -1,15 +1,37 @@
from abc import ABC
import requests
from typing import Optional, List
from langchain.llms.base import LLM
from langchain.chains.base import Chain
from typing import (
Any, Dict, List, Optional, Generator, Collection, Set,
Callable,
Tuple,
Union)
from models.loader import LoaderCheckPoint
from models.base import (RemoteRpcModel,
AnswerResult)
from typing import (
Collection,
Dict
from langchain.callbacks.manager import CallbackManagerForChainRun
from models.base import (BaseAnswer,
RemoteRpcModel,
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from pydantic import Extra, Field, root_validator
from openai import (
ChatCompletion
)
import openai
import logging
import torch
import transformers
logger = logging.getLogger(__name__)
def _build_message_template() -> Dict[str, str]:
@ -22,34 +44,88 @@ def _build_message_template() -> Dict[str, str]:
}
class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
# Convert the conversation history list into the message format
def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
build_messages: Collection[Dict[str, str]] = []
system_build_message = _build_message_template()
system_build_message['role'] = 'system'
system_build_message['content'] = "You are a helpful assistant."
build_messages.append(system_build_message)
if history:
for i, (user, assistant) in enumerate(history):
if user:
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = user
build_messages.append(user_build_message)
if not assistant:
raise RuntimeError("历史数据结构不正确")
system_build_message = _build_message_template()
system_build_message['role'] = 'assistant'
system_build_message['content'] = assistant
build_messages.append(system_build_message)
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = query
build_messages.append(user_build_message)
return build_messages
class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
client: Any
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
max_retries: int = 6
api_base_url: str = "http://localhost:8000/v1"
model_name: str = "chatglm-6b"
max_token: int = 10000
temperature: float = 0.01
top_p = 0.9
checkPoint: LoaderCheckPoint = None
history = []
# history = []
history_len: int = 10
api_key: str = ""
def __init__(self, checkPoint: LoaderCheckPoint = None):
streaming_key: str = "streaming" #: :meta private:
history_key: str = "history" #: :meta private:
prompt_key: str = "prompt" #: :meta private:
output_key: str = "answer_result_stream" #: :meta private:
def __init__(self,
checkPoint: LoaderCheckPoint = None,
# api_base_url:str="http://localhost:8000/v1",
# model_name:str="chatglm-6b",
# api_key:str=""
):
super().__init__()
self.checkPoint = checkPoint
@property
def _llm_type(self) -> str:
return "FastChat"
def _chain_type(self) -> str:
return "LLamaLLMChain"
@property
def _check_point(self) -> LoaderCheckPoint:
return self.checkPoint
@property
def _history_len(self) -> int:
return self.history_len
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
def set_history_len(self, history_len: int = 10) -> None:
self.history_len = history_len
:meta private:
"""
return [self.prompt_key]
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
@property
def _api_key(self) -> str:
@ -60,7 +136,7 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
return self.api_base_url
def set_api_key(self, api_key: str):
pass
self.api_key = api_key
def set_api_base_url(self, api_base_url: str):
self.api_base_url = api_base_url
@ -68,70 +144,116 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC):
def call_model_name(self, model_name):
self.model_name = model_name
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _create_retry_decorator(self) -> Callable[[Any], Any]:
min_seconds = 1
max_seconds = 60
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def completion_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = self._create_retry_decorator()
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return self.client.create(**kwargs)
return _completion_with_retry(**kwargs)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Generator]:
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
return {self.output_key: generator}
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
history = inputs.get(self.history_key, [])
streaming = inputs.get(self.streaming_key, False)
prompt = inputs[self.prompt_key]
stop = inputs.get("stop", "stop")
print(f"__call:{prompt}")
try:
import openai
# Not support yet
openai.api_key = "EMPTY"
# openai.api_key = "EMPTY"
openai.api_key = self.api_key
openai.api_base = self.api_base_url
except ImportError:
self.client = openai.ChatCompletion
except AttributeError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
# create a chat completion
completion = openai.ChatCompletion.create(
model=self.model_name,
messages=self.build_message_list(prompt)
)
print(f"response:{completion.choices[0].message.content}")
print(f"+++++++++++++++++++++++++++++++++++")
return completion.choices[0].message.content
msg = build_message_list(prompt, history=history)
# Convert the conversation history list into the message format
def build_message_list(self, query) -> Collection[Dict[str, str]]:
build_message_list: Collection[Dict[str, str]] = []
history = self.history[-self.history_len:] if self.history_len > 0 else []
for i, (old_query, response) in enumerate(history):
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = old_query
system_build_message = _build_message_template()
system_build_message['role'] = 'system'
system_build_message['content'] = response
build_message_list.append(user_build_message)
build_message_list.append(system_build_message)
if streaming:
params = {"stream": streaming,
"model": self.model_name,
"stop": stop}
out_str = ""
for stream_resp in self.completion_with_retry(
messages=msg,
**params
):
role = stream_resp["choices"][0]["delta"].get("role", "")
token = stream_resp["choices"][0]["delta"].get("content", "")
out_str += token
history[-1] = [prompt, out_str]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": out_str}
generate_with_callback(answer_result)
else:
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = query
build_message_list.append(user_build_message)
return build_message_list
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
try:
import openai
# Not support yet
openai.api_key = "EMPTY"
openai.api_base = self.api_base_url
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
params = {"stream": streaming,
"model": self.model_name,
"stop": stop}
response = self.completion_with_retry(
messages=msg,
**params
)
# create a chat completion
completion = openai.ChatCompletion.create(
model=self.model_name,
messages=self.build_message_list(prompt)
)
role = response["choices"][0]["message"].get("role", "")
content = response["choices"][0]["message"].get("content", "")
history += [[prompt, content]]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": content}
generate_with_callback(answer_result)
history += [[prompt, completion.choices[0].message.content]]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": completion.choices[0].message.content}
yield answer_result
if __name__ == "__main__":
chain = FastChatOpenAILLMChain()
chain.set_api_key("EMPTY")
# chain.set_api_base_url("https://api.openai.com/v1")
# chain.call_model_name("gpt-3.5-turbo")
answer_result_stream_result = chain({"streaming": True,
"prompt": "你好",
"history": []
})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
print(resp)

models/llama_llm.py

@ -1,26 +1,32 @@
from abc import ABC
from langchain.llms.base import LLM
import random
import torch
import transformers
from abc import ABC
from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Generator, Union
from langchain.callbacks.manager import CallbackManagerForChainRun
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from typing import Optional, List, Dict, Any
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
AnswerResult)
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
import torch
import transformers
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
def __call__(self, input_ids: Union[torch.LongTensor, list],
scores: Union[torch.FloatTensor, list]) -> torch.FloatTensor:
# llama-cpp models return lists; for compatibility, check the types of input_ids and scores and convert lists to torch.Tensor
input_ids = torch.tensor(input_ids) if isinstance(input_ids, list) else input_ids
scores = torch.tensor(scores) if isinstance(scores, list) else scores
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 5] = 5e4
return scores
class LLamaLLM(BaseAnswer, LLM, ABC):
class LLamaLLMChain(BaseAnswer, Chain, ABC):
checkPoint: LoaderCheckPoint = None
# history = []
history_len: int = 3
@ -34,32 +40,34 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
min_length: int = 0
logits_processor: LogitsProcessorList = None
stopping_criteria: Optional[StoppingCriteriaList] = None
eos_token_id: Optional[int] = [2]
state: object = {'max_new_tokens': 50,
'seed': 1,
'temperature': 0, 'top_p': 0.1,
'top_k': 40, 'typical_p': 1,
'repetition_penalty': 1.2,
'encoder_repetition_penalty': 1,
'no_repeat_ngram_size': 0,
'min_length': 0,
'penalty_alpha': 0,
'num_beams': 1,
'length_penalty': 1,
'early_stopping': False, 'add_bos_token': True, 'ban_eos_token': False,
'truncation_length': 2048, 'custom_stopping_strings': '',
'cpu_memory': 0, 'auto_devices': False, 'disk': False, 'cpu': False, 'bf16': False,
'load_in_8bit': False, 'wbits': 'None', 'groupsize': 'None', 'model_type': 'None',
'pre_layer': 0, 'gpu_memory_0': 0}
streaming_key: str = "streaming" #: :meta private:
history_key: str = "history" #: :meta private:
prompt_key: str = "prompt" #: :meta private:
output_key: str = "answer_result_stream" #: :meta private:
def __init__(self, checkPoint: LoaderCheckPoint = None):
super().__init__()
self.checkPoint = checkPoint
@property
def _llm_type(self) -> str:
return "LLamaLLM"
def _chain_type(self) -> str:
return "LLamaLLMChain"
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return [self.prompt_key]
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
@property
def _check_point(self) -> LoaderCheckPoint:
@ -104,35 +112,31 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
formatted_history += "### Human{}\n### Assistant".format(query)
return formatted_history
def prepare_inputs_for_generation(self,
input_ids: torch.LongTensor):
"""
预生成注意力掩码和 输入序列中每个位置的索引的张量
# TODO 没有思路
:return:
"""
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Generator]:
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
return {self.output_key: generator}
mask_positions = torch.zeros((1, input_ids.shape[1]), dtype=input_ids.dtype).to(self.checkPoint.model.device)
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
attention_mask = self.get_masks(input_ids, input_ids.device)
position_ids = self.get_position_ids(
input_ids,
device=input_ids.device,
mask_positions=mask_positions
)
return input_ids, position_ids, attention_mask
@property
def _history_len(self) -> int:
return self.history_len
def set_history_len(self, history_len: int = 10) -> None:
self.history_len = history_len
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
history = inputs[self.history_key]
streaming = inputs[self.streaming_key]
prompt = inputs[self.prompt_key]
print(f"__call:{prompt}")
# Create the StoppingCriteriaList with the stopping strings
self.stopping_criteria = transformers.StoppingCriteriaList()
# 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult
listenerQueue = AnswerResultQueueSentinelTokenListenerQueue()
self.stopping_criteria.append(listenerQueue)
# TODO 需要实现chat对话模块和注意力模型目前_call为langchain的LLM拓展的api默认为无提示词模式如果需要操作注意力模型可以参考chat_glm的实现
soft_prompt = self.history_to_text(query=prompt, history=history)
if self.logits_processor is None:
self.logits_processor = LogitsProcessorList()
self.logits_processor.append(InvalidScoreLogitsProcessor())
@ -151,35 +155,36 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
"logits_processor": self.logits_processor}
# 向量转换
input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'], truncation_length=self.max_new_tokens)
# input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids)
input_ids = self.encode(soft_prompt, add_bos_token=self.checkPoint.tokenizer.add_bos_token,
truncation_length=self.max_new_tokens)
gen_kwargs.update({'inputs': input_ids})
# 注意力掩码
# gen_kwargs.update({'attention_mask': attention_mask})
# gen_kwargs.update({'position_ids': position_ids})
if self.stopping_criteria is None:
self.stopping_criteria = transformers.StoppingCriteriaList()
# 观测输出
gen_kwargs.update({'stopping_criteria': self.stopping_criteria})
# llama-cpp模型的参数与transformers的参数字段有较大差异直接调用会返回不支持的字段错误
# 因此需要先判断模型是否是llama-cpp模型然后取gen_kwargs与模型generate方法字段的交集
# 仅将交集字段传给模型以保证兼容性
# todo llama-cpp模型在本框架下兼容性较差后续可以考虑重写一个llama_cpp_llm.py模块
if "llama_cpp" in self.checkPoint.model.__str__():
import inspect
output_ids = self.checkPoint.model.generate(**gen_kwargs)
common_kwargs_keys = set(inspect.getfullargspec(self.checkPoint.model.generate).args) & set(
gen_kwargs.keys())
common_kwargs = {key: gen_kwargs[key] for key in common_kwargs_keys}
# ? llama-cpp模型的generate方法似乎只接受.cpu类型的输入响应很慢慢到哭泣
# ?为什么会不支持GPU呢不应该啊
output_ids = torch.tensor(
[list(self.checkPoint.model.generate(input_id_i.cpu(), **common_kwargs)) for input_id_i in input_ids])
else:
output_ids = self.checkPoint.model.generate(**gen_kwargs)
new_tokens = len(output_ids[0]) - len(input_ids[0])
reply = self.decode(output_ids[0][-new_tokens:])
print(f"response:{reply}")
print(f"+++++++++++++++++++++++++++++++++++")
return reply
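The comments above describe the compatibility trick for llama-cpp models: only pass `generate()` the keyword arguments its signature actually accepts. Isolated as a helper, the idea is roughly this sketch:

```python
# Sketch of the kwargs-intersection trick: drop transformers-style arguments that
# the target generate() signature does not accept (e.g. for llama-cpp models).
import inspect

def filter_generate_kwargs(model, gen_kwargs: dict) -> dict:
    accepted = set(inspect.getfullargspec(model.generate).args)
    return {key: value for key, value in gen_kwargs.items() if key in accepted}

# usage (names follow the code above):
#   common_kwargs = filter_generate_kwargs(self.checkPoint.model, gen_kwargs)
#   output_ids = self.checkPoint.model.generate(input_id_i.cpu(), **common_kwargs)
```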
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
# TODO 需要实现chat对话模块和注意力模型目前_call为langchain的LLM拓展的api默认为无提示词模式如果需要操作注意力模型可以参考chat_glm的实现
softprompt = self.history_to_text(prompt,history=history)
response = self._call(prompt=softprompt, stop=['\n###'])
answer_result = AnswerResult()
answer_result.history = history + [[prompt, response]]
answer_result.llm_output = {"answer": response}
yield answer_result
history += [[prompt, reply]]
answer_result.history = history
answer_result.llm_output = {"answer": reply}
generate_with_callback(answer_result)

View File

@ -1,3 +1,4 @@
import argparse
import os
from configs.model_config import *
@ -43,7 +44,8 @@ parser.add_argument('--no-remote-model', action='store_true', help='remote in th
parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")
parser.add_argument('--use-ptuning-v2', action='store_true', help="whether to use a ptuning-v2 checkpoint")
parser.add_argument("--ptuning-dir", type=str, default=PTUNING_DIR, help="path to the ptuning-v2 checkpoint directory")
# Accelerate/transformers
parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
help='Load the model with 8-bit precision.')

View File

@ -20,6 +20,7 @@ class LoaderCheckPoint:
no_remote_model: bool = False
# 模型名称
model_name: str = None
pretrained_model_name: str = None
tokenizer: object = None
# 模型全路径
model_path: str = None
@ -35,11 +36,11 @@ class LoaderCheckPoint:
# 因此主要的解决思路是清理环境变量里PATH下的不匹配的cuda版本一劳永逸的方法是
# 0. 在终端执行`pip uninstall bitsandbytes`
# 1. 删除.bashrc文件下关于PATH的条目
# 2. 在终端执行 `echo $PATH >> .bashrc`
# 3. 删除.bashrc文件下PATH中关于不匹配的cuda版本路径
# 4. 在终端执行`source .bashrc`
# 5. 再执行`pip install bitsandbytes`
load_in_8bit: bool = False
is_llamacpp: bool = False
bf16: bool = False
@ -67,43 +68,49 @@ class LoaderCheckPoint:
self.load_in_8bit = params.get('load_in_8bit', False)
self.bf16 = params.get('bf16', False)
def _load_model_config(self, model_name):
def _load_model_config(self):
if self.model_path:
self.model_path = re.sub("\s", "", self.model_path)
checkpoint = Path(f'{self.model_path}')
else:
if not self.no_remote_model:
checkpoint = model_name
else:
if self.no_remote_model:
raise ValueError(
"本地模型local_model_path未配置路径"
)
else:
checkpoint = self.pretrained_model_name
model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
print(f"load_model_config {checkpoint}...")
try:
return model_config
model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
return model_config
except Exception as e:
print(e)
return checkpoint
def _load_model(self, model_name):
def _load_model(self):
"""
加载自定义位置的model
:param model_name:
:return:
"""
print(f"Loading {model_name}...")
t0 = time.time()
if self.model_path:
self.model_path = re.sub("\s", "", self.model_path)
checkpoint = Path(f'{self.model_path}')
else:
if not self.no_remote_model:
checkpoint = model_name
else:
if self.no_remote_model:
raise ValueError(
"本地模型local_model_path未配置路径"
)
else:
checkpoint = self.pretrained_model_name
print(f"Loading {checkpoint}...")
self.is_llamacpp = len(list(Path(f'{checkpoint}').glob('ggml*.bin'))) > 0
if 'chatglm' in model_name.lower():
if 'chatglm' in self.model_name.lower() or "chatyuan" in self.model_name.lower():
LoaderClass = AutoModel
else:
LoaderClass = AutoModelForCausalLM
@ -126,8 +133,14 @@ class LoaderCheckPoint:
.half()
.cuda()
)
# 支持自定义cuda设备
elif ":" in self.llm_device:
model = LoaderClass.from_pretrained(checkpoint,
config=self.model_config,
torch_dtype=torch.bfloat16 if self.bf16 else torch.float16,
trust_remote_code=True).half().to(self.llm_device)
else:
from accelerate import dispatch_model
from accelerate import dispatch_model, infer_auto_device_map
model = LoaderClass.from_pretrained(checkpoint,
config=self.model_config,
@ -135,12 +148,22 @@ class LoaderCheckPoint:
trust_remote_code=True).half()
# 可传入device_map自定义每张卡的部署情况
if self.device_map is None:
if 'chatglm' in model_name.lower():
if 'chatglm' in self.model_name.lower() and not "chatglm2" in self.model_name.lower():
self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
elif 'moss' in model_name.lower():
self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name)
elif 'moss' in self.model_name.lower():
self.device_map = self.moss_auto_configure_device_map(num_gpus, checkpoint)
else:
self.device_map = self.chatglm_auto_configure_device_map(num_gpus)
# 基于如下方式作为默认的多卡加载方案针对新模型基本不会失败
# 在chatglm2-6b,bloom-3b,bloomz-7b1上进行了测试,GPU负载也相对均衡
from accelerate.utils import get_balanced_memory
max_memory = get_balanced_memory(model,
dtype=torch.int8 if self.load_in_8bit else None,
low_zero=False,
no_split_module_classes=model._no_split_modules)
self.device_map = infer_auto_device_map(model,
dtype=torch.float16 if not self.load_in_8bit else torch.int8,
max_memory=max_memory,
no_split_module_classes=model._no_split_modules)
model = dispatch_model(model, device_map=self.device_map)
else:
@ -156,7 +179,7 @@ class LoaderCheckPoint:
elif self.is_llamacpp:
try:
from models.extensions.llamacpp_model_alternative import LlamaCppModel
from llama_cpp import Llama
except ImportError as exc:
raise ValueError(
@ -167,7 +190,16 @@ class LoaderCheckPoint:
model_file = list(checkpoint.glob('ggml*.bin'))[0]
print(f"llama.cpp weights detected: {model_file}\n")
model, tokenizer = LlamaCppModel.from_pretrained(model_file)
model = Llama(model_path=model_file._str)
# 实测llama-cpp-vicuna13b-q5_1的AutoTokenizer加载tokenizer的速度极慢应存在优化空间
# 但需要对huggingface的AutoTokenizer进行优化
# tokenizer = model.tokenizer
# todo 此处调用AutoTokenizer的tokenizer但后续可以测试自带tokenizer是不是兼容
# * -> 自带的tokenizer不与transformers的tokenizer兼容,无法使用
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
return model, tokenizer
elif self.load_in_8bit:
@ -194,7 +226,7 @@ class LoaderCheckPoint:
llm_int8_enable_fp32_cpu_offload=False)
with init_empty_weights():
model = LoaderClass.from_config(self.model_config,trust_remote_code = True)
model = LoaderClass.from_config(self.model_config, trust_remote_code=True)
model.tie_weights()
if self.device_map is not None:
params['device_map'] = self.device_map
@ -257,10 +289,21 @@ class LoaderCheckPoint:
# 在调用chat或者stream_chat时,input_ids会被放到model.device上
# 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError
# 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上
device_map = {f'{layer_prefix}.word_embeddings': 0,
f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0,
f'base_model.model.lm_head': 0, }
encode = ""
if 'chatglm2' in self.model_name:
device_map = {
f"{layer_prefix}.embedding.word_embeddings": 0,
f"{layer_prefix}.rotary_pos_emb": 0,
f"{layer_prefix}.output_layer": 0,
f"{layer_prefix}.encoder.final_layernorm": 0,
f"base_model.model.output_layer": 0
}
encode = ".encoder"
else:
device_map = {f'{layer_prefix}.word_embeddings': 0,
f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0,
f'base_model.model.lm_head': 0, }
used = 2
gpu_target = 0
for i in range(num_trans_layers):
@ -268,12 +311,12 @@ class LoaderCheckPoint:
gpu_target += 1
used = 0
assert gpu_target < num_gpus
device_map[f'{layer_prefix}.layers.{i}'] = gpu_target
device_map[f'{layer_prefix}{encode}.layers.{i}'] = gpu_target
used += 1
return device_map
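The new default multi-GPU path above relies on accelerate's balanced auto device map instead of the hand-written ChatGLM map. Pulled out into a standalone helper, the pattern is roughly the following sketch (not project code):

```python
# Balanced multi-GPU placement with accelerate, mirroring the fallback added above.
import torch
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory

def dispatch_across_gpus(model, load_in_8bit: bool = False):
    max_memory = get_balanced_memory(
        model,
        dtype=torch.int8 if load_in_8bit else None,
        low_zero=False,
        no_split_module_classes=model._no_split_modules)
    device_map = infer_auto_device_map(
        model,
        dtype=torch.int8 if load_in_8bit else torch.float16,
        max_memory=max_memory,
        no_split_module_classes=model._no_split_modules)
    # dispatch_model moves each module to the device chosen in device_map
    return dispatch_model(model, device_map=device_map)
```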
def moss_auto_configure_device_map(self, num_gpus: int, model_name) -> Dict[str, int]:
def moss_auto_configure_device_map(self, num_gpus: int, checkpoint) -> Dict[str, int]:
try:
from accelerate import init_empty_weights
@ -288,16 +331,6 @@ class LoaderCheckPoint:
"`pip install bitsandbytes``pip install accelerate`."
) from exc
if self.model_path:
checkpoint = Path(f'{self.model_path}')
else:
if not self.no_remote_model:
checkpoint = model_name
else:
raise ValueError(
"本地模型local_model_path未配置路径"
)
cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM",
pretrained_model_name_or_path=checkpoint)
@ -385,7 +418,7 @@ class LoaderCheckPoint:
print(
"如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")
elif torch.has_cuda:
device_id = "0" if torch.cuda.is_available() else None
device_id = "0" if torch.cuda.is_available() and (":" not in self.llm_device) else None
CUDA_DEVICE = f"{self.llm_device}:{device_id}" if device_id else self.llm_device
with torch.cuda.device(CUDA_DEVICE):
torch.cuda.empty_cache()
@ -404,33 +437,37 @@ class LoaderCheckPoint:
def reload_model(self):
self.unload_model()
self.model_config = self._load_model_config(self.model_name)
self.model_config = self._load_model_config()
if self.use_ptuning_v2:
try:
prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r')
prefix_encoder_file = open(Path(f'{os.path.abspath(self.ptuning_dir)}/config.json'), 'r')
prefix_encoder_config = json.loads(prefix_encoder_file.read())
prefix_encoder_file.close()
self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
self.model_config.prefix_projection = prefix_encoder_config['prefix_projection']
except Exception as e:
print(e)
print("加载PrefixEncoder config.json失败")
self.model, self.tokenizer = self._load_model(self.model_name)
self.model, self.tokenizer = self._load_model()
if self.lora:
self._add_lora_to_model([self.lora])
if self.use_ptuning_v2:
try:
prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin'))
prefix_state_dict = torch.load(Path(f'{os.path.abspath(self.ptuning_dir)}/pytorch_model.bin'))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
self.model.transformer.prefix_encoder.float()
print("加载ptuning检查点成功")
except Exception as e:
print(e)
print("加载PrefixEncoder模型参数失败")
self.model = self.model.eval()
# llama-cpp模型至少vicuna-13b的eval方法就是自身其没有eval方法
if not self.is_llamacpp:
self.model = self.model.eval()
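Putting the pieces together, loading a model with a P-Tuning v2 checkpoint might look like the sketch below. The attribute assignments mirror what `shared.loaderLLM` does elsewhere in this commit; the model names and directory are placeholders.

```python
# Usage sketch (assumption): load a base model, then apply a P-Tuning v2 prefix
# encoder from --ptuning-dir (expects config.json and pytorch_model.bin there).
from models.loader.args import parser
from models.loader import LoaderCheckPoint

args = parser.parse_args(['--model-name', 'chatglm-6b',
                          '--use-ptuning-v2', '--ptuning-dir', 'ptuning-v2'])
checkpoint = LoaderCheckPoint(vars(args))
checkpoint.model_name = "chatglm-6b"                   # set by shared.loaderLLM in practice
checkpoint.pretrained_model_name = "THUDM/chatglm-6b"  # set by shared.loaderLLM in practice
checkpoint.use_ptuning_v2 = True
checkpoint.reload_model()   # loads the model, then the prefix-encoder weights
```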

View File

@ -1,12 +1,20 @@
from abc import ABC
from langchain.llms.base import LLM
from typing import Optional, List
from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Generator, Union
from langchain.callbacks.manager import CallbackManagerForChainRun
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
AnswerResult)
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
import torch
import transformers
import torch
# todo 建议重写instruction,在该instruction下各模型的表现比较差
META_INSTRUCTION = \
"""You are an AI assistant whose name is MOSS.
- MOSS is a conversational language model that is developed by Fudan University. It is designed to be helpful, honest, and harmless.
@ -21,49 +29,76 @@ META_INSTRUCTION = \
"""
class MOSSLLM(BaseAnswer, LLM, ABC):
# todo 在MOSSLLM类下各模型的响应速度很慢后续要检查一下原因
class MOSSLLMChain(BaseAnswer, Chain, ABC):
max_token: int = 2048
temperature: float = 0.7
top_p = 0.8
# history = []
checkPoint: LoaderCheckPoint = None
history_len: int = 10
streaming_key: str = "streaming" #: :meta private:
history_key: str = "history" #: :meta private:
prompt_key: str = "prompt" #: :meta private:
output_key: str = "answer_result_stream" #: :meta private:
def __init__(self, checkPoint: LoaderCheckPoint = None):
super().__init__()
self.checkPoint = checkPoint
@property
def _llm_type(self) -> str:
return "MOSS"
def _chain_type(self) -> str:
return "MOSSLLMChain"
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return [self.prompt_key]
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
@property
def _check_point(self) -> LoaderCheckPoint:
return self.checkPoint
@property
def set_history_len(self) -> int:
return self.history_len
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Generator]:
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
return {self.output_key: generator}
def _set_history_len(self, history_len: int) -> None:
self.history_len = history_len
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
pass
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
history = inputs[self.history_key]
streaming = inputs[self.streaming_key]
prompt = inputs[self.prompt_key]
print(f"__call:{prompt}")
if len(history) > 0:
history = history[-self.history_len:] if self.history_len > 0 else []
prompt_w_history = str(history)
prompt_w_history += '<|Human|>: ' + prompt + '<eoh>'
else:
prompt_w_history = META_INSTRUCTION
prompt_w_history = META_INSTRUCTION.replace("MOSS", self.checkPoint.model_name.split("/")[-1])
prompt_w_history += '<|Human|>: ' + prompt + '<eoh>'
inputs = self.checkPoint.tokenizer(prompt_w_history, return_tensors="pt")
with torch.no_grad():
# max_length似乎可以设得小一些,而repetition_penalty应大一些,否则chatyuan,bloom等模型为满足max会重复输出
#
outputs = self.checkPoint.model.generate(
inputs.input_ids.cuda(),
attention_mask=inputs.attention_mask.cuda(),
@ -76,13 +111,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC):
num_return_sequences=1,
eos_token_id=106068,
pad_token_id=self.checkPoint.tokenizer.pad_token_id)
response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:],
skip_special_tokens=True)
self.checkPoint.clear_torch_cache()
history += [[prompt, response]]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": response}
yield answer_result
generate_with_callback(answer_result)

View File

@ -24,13 +24,12 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
if use_ptuning_v2:
loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2
# 如果指定了参数,则使用参数的配置
if llm_model:
llm_model_info = llm_model_dict[llm_model]
if loaderCheckPoint.no_remote_model:
loaderCheckPoint.model_name = llm_model_info['name']
else:
loaderCheckPoint.model_name = llm_model_info['pretrained_model_name']
loaderCheckPoint.model_name = llm_model_info['name']
loaderCheckPoint.pretrained_model_name = llm_model_info['pretrained_model_name']
loaderCheckPoint.model_path = llm_model_info["local_model_path"]
@ -44,4 +43,5 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_
if 'FastChatOpenAILLM' in llm_model_info["provides"]:
modelInsLLM.set_api_base_url(llm_model_info['api_base_url'])
modelInsLLM.call_model_name(llm_model_info['name'])
modelInsLLM.set_api_key(llm_model_info['api_key'])
return modelInsLLM
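For orientation, a hypothetical `llm_model_dict` entry consumed by `loaderLLM` for a fastchat-served model; only the keys referenced above are assumed, and the values are placeholders.

```python
# Hypothetical configs/model_config.py entry; keys follow the fields read in loaderLLM.
llm_model_dict = {
    "fastchat-chatglm-6b": {
        "name": "chatglm-6b",
        "pretrained_model_name": "chatglm-6b",
        "local_model_path": None,
        "provides": "FastChatOpenAILLMChain",
        "api_base_url": "http://localhost:8000/v1",
        "api_key": "EMPTY",
    },
}
```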

View File

@ -11,7 +11,7 @@ beautifulsoup4
icetk
cpm_kernels
faiss-cpu
gradio==3.28.3
gradio==3.37.0
fastapi~=0.95.0
uvicorn~=0.21.1
pypinyin~=0.48.0
@ -23,9 +23,13 @@ openai
#accelerate~=0.18.0
#peft~=0.3.0
#bitsandbytes; platform_system != "Windows"
#llama-cpp-python==0.1.34; platform_system != "Windows"
#https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"
# 要调用llama-cpp模型,如vicuna-13b量化模型,需要安装llama-cpp-python库
# but!!! 实测pip install 不好使,需要手动从https://github.com/abetlen/llama-cpp-python/releases/下载
# 而且注意不同时期的ggml格式并不兼容!!!因此需要安装的llama-cpp-python版本也不一致,需要手动测试才能确定
# 实测ggml-vicuna-13b-1.1在llama-cpp-python 0.1.63上可正常兼容
# 不过本项目模型加载的方式控制的比较严格与llama-cpp-python的兼容性较差很多参数设定不能使用
# 建议如非必要还是不要使用llama-cpp
torch~=2.0.0
pydantic~=1.10.7
starlette~=0.26.1

View File

@ -1,39 +0,0 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../../')
import asyncio
from argparse import Namespace
from models.loader.args import parser
from models.loader import LoaderCheckPoint
import models.shared as shared
async def dispatch(args: Namespace):
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
history = [
("which city is this?", "tokyo"),
("why?", "she's japanese"),
]
for answer_result in llm_model_ins.generatorAnswer(prompt="你好? ", history=history,
streaming=False):
resp = answer_result.llm_output["answer"]
print(resp)
if __name__ == '__main__':
args = None
args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'fastchat-chatglm-6b', '--no-remote-model'])
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(dispatch(args))

View File

@ -16,6 +16,24 @@ export const chatfile = (params: any) => {
})
}
export const getKbsList = () => {
return api({
url: '/local_doc_qa/list_knowledge_base',
method: 'get',
})
}
export const deleteKb = (knowledge_base_id: any) => {
return api({
url: '/local_doc_qa/delete_knowledge_base',
method: 'delete',
params: {
knowledge_base_id,
},
})
}
export const getfilelist = (knowledge_base_id: any) => {
return api({
url: '/local_doc_qa/list_files',
@ -35,8 +53,8 @@ export const bing_search = (params: any) => {
export const deletefile = (params: any) => {
return api({
url: '/local_doc_qa/delete_file',
method: 'post',
data: JSON.stringify(params),
method: 'delete',
params,
})
}
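Both endpoints now expect a DELETE request with query parameters rather than a POSTed JSON body. A rough sketch of calling them directly from Python follows; the base URL is a placeholder, and the file-deletion parameter names are assumptions mirroring the frontend's `params` object.

```python
# Direct calls to the updated endpoints; base URL and doc_name parameter are assumptions.
import requests

BASE = "http://localhost:7861"  # placeholder API address

# delete an entire knowledge base
requests.delete(f"{BASE}/local_doc_qa/delete_knowledge_base",
                params={"knowledge_base_id": "samples"})

# delete a single file from a knowledge base
requests.delete(f"{BASE}/local_doc_qa/delete_file",
                params={"knowledge_base_id": "samples", "doc_name": "example.txt"})
```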
export const web_url = () => {

View File

@ -555,7 +555,7 @@ const options = computed(() => {
return common
})
function handleSelect(key: 'copyText' | 'delete' | 'toggleRenderType') {
function handleSelect(key: string) {
if (key == '清除会话') {
handleClear()
}
@ -658,7 +658,6 @@ function searchfun() {
<NDropdown
v-if="isMobile"
:trigger="isMobile ? 'click' : 'hover'"
:placement="!inversion ? 'right' : 'left'"
:options="options"
@select="handleSelect"
>

View File

@ -3,15 +3,16 @@ import { NButton, NForm, NFormItem, NInput, NPopconfirm } from 'naive-ui'
import { onMounted, ref } from 'vue'
import filelist from './filelist.vue'
import { SvgIcon } from '@/components/common'
import { deletekb, getkblist } from '@/api/chat'
import { deleteKb, getKbsList } from '@/api/chat'
import { idStore } from '@/store/modules/knowledgebaseid/id'
const items = ref<any>([])
const choice = ref('')
const store = idStore()
onMounted(async () => {
choice.value = store.knowledgeid
const res = await getkblist({})
const res = await getKbsList()
res.data.data.forEach((item: any) => {
items.value.push({
value: item,
@ -52,8 +53,8 @@ const handleClick = () => {
}
}
async function handleDelete(item: any) {
await deletekb(item.value)
const res = await getkblist({})
await deleteKb(item.value)
const res = await getKbsList()
items.value = []
res.data.data.forEach((item: any) => {
items.value.push({

View File

@ -85,8 +85,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
yield history + [[query,
"请选择知识库后进行测试,当前未选择知识库。"]], ""
else:
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
streaming=streaming):
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": query, "history": history, "streaming": streaming})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][-1] = resp
@ -101,11 +104,13 @@ def init_model():
args_dict = vars(args)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
llm_model_ins.history_len = LLM_HISTORY_LEN
try:
local_doc_qa.init_cfg(llm_model=llm_model_ins)
generator = local_doc_qa.llm.generatorAnswer("你好")
for answer_result in generator:
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": "你好", "history": [], "streaming": False})
for answer_result in answer_result_stream_result['answer_result_stream']:
print(answer_result.llm_output)
reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
logger.info(reply)
@ -141,7 +146,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
filelist = []
if local_doc_qa.llm and local_doc_qa.embeddings:
if local_doc_qa.llm_model_chain and local_doc_qa.embeddings:
if isinstance(files, list):
for file in files:
filename = os.path.split(file.name)[-1]
@ -165,8 +170,8 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte
def change_vs_name_input(vs_id, history):
if vs_id == "新建知识库":
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history,\
gr.update(choices=[]), gr.update(visible=False)
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history, \
gr.update(choices=[]), gr.update(visible=False)
else:
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
if "index.faiss" in os.listdir(vs_path):
@ -218,7 +223,12 @@ def change_chunk_conent(mode, label_conent, history):
def add_vs_name(vs_name, chatbot):
if vs_name in get_vs_list():
if vs_name is None or vs_name.strip() == "":
vs_status = "知识库名称不能为空,请重新填写知识库名称"
chatbot = chatbot + [[None, vs_status]]
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
visible=False), chatbot, gr.update(visible=False)
elif vs_name in get_vs_list():
vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交"
chatbot = chatbot + [[None, vs_status]]
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
@ -257,6 +267,7 @@ def reinit_vector_store(vs_id, history):
def refresh_vs_list():
return gr.update(choices=get_vs_list()), gr.update(choices=get_vs_list())
def delete_file(vs_id, files_to_delete, chatbot):
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
content_path = os.path.join(KB_ROOT_PATH, vs_id, "content")
@ -270,11 +281,11 @@ def delete_file(vs_id, files_to_delete, chatbot):
rested_files = local_doc_qa.list_file_from_vector_store(vs_path)
if "fail" in status:
vs_status = "文件删除失败。"
elif len(rested_files)>0:
elif len(rested_files) > 0:
vs_status = "文件删除成功。"
else:
vs_status = f"文件删除成功,知识库{vs_id}中无已上传文件,请先上传文件后,再开始提问。"
logger.info(",".join(files_to_delete)+vs_status)
logger.info(",".join(files_to_delete) + vs_status)
chatbot = chatbot + [[None, vs_status]]
return gr.update(choices=local_doc_qa.list_file_from_vector_store(vs_path), value=[]), chatbot
@ -285,7 +296,8 @@ def delete_vs(vs_id, chatbot):
status = f"成功删除知识库{vs_id}"
logger.info(status)
chatbot = chatbot + [[None, status]]
return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(visible=True), \
return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(
visible=True), \
gr.update(visible=False), chatbot, gr.update(visible=False)
except Exception as e:
logger.error(e)
@ -328,7 +340,8 @@ default_theme_args = dict(
with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo:
vs_path, file_status, model_status = gr.State(
os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State(
os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(
""), gr.State(
model_status)
gr.Markdown(webui_title)
with gr.Tab("对话"):
@ -381,8 +394,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
load_folder_button = gr.Button("上传文件夹并加载知识库")
with gr.Tab("删除文件"):
files_to_delete = gr.CheckboxGroup(choices=[],
label="请从知识库已有文件中选择要删除的文件",
interactive=True)
label="请从知识库已有文件中选择要删除的文件",
interactive=True)
delete_file_button = gr.Button("从知识库中删除选中文件")
vs_refresh.click(fn=refresh_vs_list,
inputs=[],
@ -450,9 +463,9 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
with vs_setting:
vs_refresh = gr.Button("更新已有知识库选项")
select_vs_test = gr.Dropdown(get_vs_list(),
label="请选择要加载的知识库",
interactive=True,
value=get_vs_list()[0] if len(get_vs_list()) > 0 else None)
label="请选择要加载的知识库",
interactive=True,
value=get_vs_list()[0] if len(get_vs_list()) > 0 else None)
vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文",
lines=1,
interactive=True,
@ -492,8 +505,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
inputs=[vs_name, chatbot],
outputs=[select_vs_test, vs_name, vs_add, file2vs, chatbot])
select_vs_test.change(fn=change_vs_name_input,
inputs=[select_vs_test, chatbot],
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
inputs=[select_vs_test, chatbot],
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
load_file_button.click(get_vector_store,
show_progress=True,
inputs=[select_vs_test, files, sentence_size, chatbot, vs_add, vs_add],

View File

@ -1,5 +1,5 @@
import streamlit as st
# from st_btn_select import st_btn_select
from streamlit_chatbox import st_chatbox
import tempfile
###### 从webui借用的代码 #####
###### 做了少量修改 #####
@ -23,6 +23,7 @@ def get_vs_list():
if not os.path.exists(KB_ROOT_PATH):
return lst_default
lst = os.listdir(KB_ROOT_PATH)
lst = [x for x in lst if os.path.isdir(os.path.join(KB_ROOT_PATH, x))]
if not lst:
return lst_default
lst.sort()
@ -31,7 +32,6 @@ def get_vs_list():
embedding_model_dict_list = list(embedding_model_dict.keys())
llm_model_dict_list = list(llm_model_dict.keys())
# flag_csv_logger = gr.CSVLogger()
def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
@ -50,6 +50,9 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
history[-1][-1] += source
yield history, ""
elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path):
local_doc_qa.top_k = vector_search_top_k
local_doc_qa.chunk_conent = chunk_conent
local_doc_qa.chunk_size = chunk_size
for resp, history in local_doc_qa.get_knowledge_based_answer(
query=query, vs_path=vs_path, chat_history=history, streaming=streaming):
source = "\n\n"
@ -85,62 +88,16 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
yield history + [[query,
"请选择知识库后进行测试,当前未选择知识库。"]], ""
else:
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
streaming=streaming):
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": query, "history": history, "streaming": streaming})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][-1] = resp + (
"\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
yield history, ""
logger.info(f"flagging: username={FLAG_USER_NAME},query={query},vs_path={vs_path},mode={mode},history={history}")
# flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)
def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'):
local_doc_qa = LocalDocQA()
# 初始化消息
args = parser.parse_args()
args_dict = vars(args)
args_dict.update(model=llm_model)
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
llm_model_ins = shared.loaderLLM()
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
try:
local_doc_qa.init_cfg(llm_model=llm_model_ins,
embedding_model=embedding_model)
generator = local_doc_qa.llm.generatorAnswer("你好")
for answer_result in generator:
print(answer_result.llm_output)
reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
logger.info(reply)
except Exception as e:
logger.error(e)
reply = """模型未成功加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
if str(e) == "Unknown platform: darwin":
logger.info("该报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI具体方法请参考项目 README 中本地部署方法及常见问题:"
" https://github.com/imClumsyPanda/langchain-ChatGLM")
else:
logger.info(reply)
return local_doc_qa
# 暂未使用到,先保留
# def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history):
# try:
# llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
# llm_model_ins.history_len = llm_history_len
# local_doc_qa.init_cfg(llm_model=llm_model_ins,
# embedding_model=embedding_model,
# top_k=top_k)
# model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
# logger.info(model_status)
# except Exception as e:
# logger.error(e)
# model_status = """模型未成功重新加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
# logger.info(model_status)
# return history + [[None, model_status]]
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
@ -148,7 +105,8 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte
filelist = []
if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")):
os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content"))
if local_doc_qa.llm and local_doc_qa.embeddings:
qa = st.session_state.local_doc_qa
if qa.llm_model_chain and qa.embeddings:
if isinstance(files, list):
for file in files:
filename = os.path.split(file.name)[-1]
@ -156,10 +114,10 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte
KB_ROOT_PATH, vs_id, "content", filename))
filelist.append(os.path.join(
KB_ROOT_PATH, vs_id, "content", filename))
vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(
vs_path, loaded_files = qa.init_knowledge_vector_store(
filelist, vs_path, sentence_size)
else:
vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
vs_path, loaded_files = qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
sentence_size)
if len(loaded_files):
file_status = f"已添加 {''.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
@ -177,10 +135,7 @@ knowledge_base_test_mode_info = ("【注意】\n\n"
"并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询。\n\n"
"2. 知识相关度 Score 经测试,建议设置为 500 或更低,具体设置情况请结合实际使用调整。"
"""3. 使用"添加单条数据"添加文本至知识库时内容如未分段则内容越多越会稀释各查询内容与之关联的score阈值。\n\n"""
"4. 单条内容长度建议设置在100-150左右。\n\n"
"5. 本界面用于知识入库及知识匹配相关参数设定,但当前版本中,"
"本界面中修改的参数并不会直接修改对话界面中参数,仍需前往`configs/model_config.py`修改后生效。"
"相关参数将在后续版本中支持本界面直接修改。")
"4. 单条内容长度建议设置在100-150左右。")
webui_title = """
@ -192,7 +147,7 @@ webui_title = """
###### todo #####
# 1. streamlit运行方式与一般web服务器不同使用模块是无法实现单例模式的所以shared和local_doc_qa都需要进行全局化处理。
# 目前已经实现了local_doc_qa的全局化后面要考虑shared
# 目前已经实现了local_doc_qa和shared.loaderCheckPoint的全局化
# 2. 当前local_doc_qa是一个全局变量一方面任何一个session对其做出修改都会影响所有session的对话;另一方面如何处理所有session的请求竞争也是问题。
# 这个暂时无法避免,在配置普通的机器上暂时也无需考虑。
# 3. 目前只包含了get_answer对应的参数以后可以添加其他参数如temperature。
@ -201,25 +156,11 @@ webui_title = """
###### 配置项 #####
class ST_CONFIG:
user_bg_color = '#77ff77'
user_icon = 'https://tse2-mm.cn.bing.net/th/id/OIP-C.LTTKrxNWDr_k74wz6jKqBgHaHa?w=203&h=203&c=7&r=0&o=5&pid=1.7'
robot_bg_color = '#ccccee'
robot_icon = 'https://ts1.cn.mm.bing.net/th/id/R-C.5302e2cc6f5c7c4933ebb3394e0c41bc?rik=z4u%2b7efba5Mgxw&riu=http%3a%2f%2fcomic-cons.xyz%2fwp-content%2fuploads%2fStar-Wars-avatar-icon-C3PO.png&ehk=kBBvCvpJMHPVpdfpw1GaH%2brbOaIoHjY5Ua9PKcIs%2bAc%3d&risl=&pid=ImgRaw&r=0'
default_mode = '知识库问答'
defalut_kb = ''
default_mode = "知识库问答"
default_kb = ""
###### #####
class MsgType:
'''
目前仅支持文本类型的输入输出为以后多模态模型预留图像视频音频支持
'''
TEXT = 1
IMAGE = 2
VIDEO = 3
AUDIO = 4
class TempFile:
'''
为保持与get_vector_store的兼容性需要将streamlit上传文件转化为其可以接受的方式
@ -229,132 +170,54 @@ class TempFile:
self.name = path
def init_session():
st.session_state.setdefault('history', [])
# def get_query_params():
# '''
# 可以用url参数传递配置参数llm_model, embedding_model, kb, mode。
# 该参数将覆盖model_config中的配置。处于安全考虑目前只支持kb和mode
# 方便将固定的配置分享给特定的人。
# '''
# params = st.experimental_get_query_params()
# return {k: v[0] for k, v in params.items() if v}
def robot_say(msg, kb=''):
st.session_state['history'].append(
{'is_user': False, 'type': MsgType.TEXT, 'content': msg, 'kb': kb})
def user_say(msg):
st.session_state['history'].append(
{'is_user': True, 'type': MsgType.TEXT, 'content': msg})
def format_md(msg, is_user=False, bg_color='', margin='10%'):
'''
将文本消息格式化为markdown文本
'''
if is_user:
bg_color = bg_color or ST_CONFIG.user_bg_color
text = f'''
<div style="background:{bg_color};
margin-left:{margin};
word-break:break-all;
float:right;
padding:2%;
border-radius:2%;">
{msg}
</div>
'''
else:
bg_color = bg_color or ST_CONFIG.robot_bg_color
text = f'''
<div style="background:{bg_color};
margin-right:{margin};
word-break:break-all;
padding:2%;
border-radius:2%;">
{msg}
</div>
'''
return text
def message(msg,
is_user=False,
msg_type=MsgType.TEXT,
icon='',
bg_color='',
margin='10%',
kb='',
):
'''
渲染单条消息目前仅支持文本
'''
cols = st.columns([1, 10, 1])
empty = cols[1].empty()
if is_user:
icon = icon or ST_CONFIG.user_icon
bg_color = bg_color or ST_CONFIG.user_bg_color
cols[2].image(icon, width=40)
if msg_type == MsgType.TEXT:
text = format_md(msg, is_user, bg_color, margin)
empty.markdown(text, unsafe_allow_html=True)
else:
raise RuntimeError('only support text message now.')
else:
icon = icon or ST_CONFIG.robot_icon
bg_color = bg_color or ST_CONFIG.robot_bg_color
cols[0].image(icon, width=40)
if kb:
cols[0].write(f'({kb})')
if msg_type == MsgType.TEXT:
text = format_md(msg, is_user, bg_color, margin)
empty.markdown(text, unsafe_allow_html=True)
else:
raise RuntimeError('only support text message now.')
return empty
def output_messages(
user_bg_color='',
robot_bg_color='',
user_icon='',
robot_icon='',
):
with chat_box.container():
last_response = None
for msg in st.session_state['history']:
bg_color = user_bg_color if msg['is_user'] else robot_bg_color
icon = user_icon if msg['is_user'] else robot_icon
empty = message(msg['content'],
is_user=msg['is_user'],
icon=icon,
msg_type=msg['type'],
bg_color=bg_color,
kb=msg.get('kb', '')
)
if not msg['is_user']:
last_response = empty
return last_response
@st.cache_resource(show_spinner=False, max_entries=1)
def load_model(llm_model: str, embedding_model: str):
def load_model(
llm_model: str = LLM_MODEL,
embedding_model: str = EMBEDDING_MODEL,
use_ptuning_v2: bool = USE_PTUNING_V2,
):
'''
对应init_model利用streamlit cache避免模型重复加载
'''
local_doc_qa = init_model(llm_model, embedding_model)
robot_say('模型已成功加载,可以开始对话,或从左侧选择模式后开始对话。\n请尽量不要刷新页面,以免模型出错或重复加载。')
local_doc_qa = LocalDocQA()
# 初始化消息
args = parser.parse_args()
args_dict = vars(args)
args_dict.update(model=llm_model)
if shared.loaderCheckPoint is None: # avoid checkpoint reloading when reinit model
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
# shared.loaderCheckPoint.model_name differs depending on no_remote_model.
# if it is not set properly, an error occurs when reinitializing the llm model (issue#473).
# as no_remote_model was removed from model_config, a workaround is needed to set it automatically.
local_model_path = llm_model_dict.get(llm_model, {}).get('local_model_path') or ''
no_remote_model = os.path.isdir(local_model_path)
llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
llm_model_ins.history_len = LLM_HISTORY_LEN
try:
local_doc_qa.init_cfg(llm_model=llm_model_ins,
embedding_model=embedding_model)
answer_result_stream_result = local_doc_qa.llm_model_chain(
{"prompt": "你好", "history": [], "streaming": False})
for answer_result in answer_result_stream_result['answer_result_stream']:
print(answer_result.llm_output)
reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话"""
logger.info(reply)
except Exception as e:
logger.error(e)
reply = """模型未成功加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
if str(e) == "Unknown platform: darwin":
logger.info("该报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI具体方法请参考项目 README 中本地部署方法及常见问题:"
" https://github.com/imClumsyPanda/langchain-ChatGLM")
else:
logger.info(reply)
return local_doc_qa
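Because `load_model` is wrapped in `st.cache_resource(max_entries=1)`, the heavy initialisation only reruns when its argument tuple changes; re-submitting the sidebar form with identical settings reuses the cached `LocalDocQA`. Roughly, as a sketch:

```python
# Caching behaviour sketch: same arguments hit the cache, new ones rebuild the model,
# and max_entries=1 evicts the previously cached instance.
qa_a = load_model("chatglm-6b", "text2vec", False)    # first call: loads and caches
qa_b = load_model("chatglm-6b", "text2vec", False)    # cache hit: same object as qa_a
qa_c = load_model("chatglm2-6b", "text2vec", False)   # cache miss: reload, qa_a evicted
```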
# @st.cache_data
def answer(query, vs_path='', history=[], mode='', score_threshold=0,
vector_search_top_k=5, chunk_conent=True, chunk_size=100, qa=None
vector_search_top_k=5, chunk_conent=True, chunk_size=100
):
'''
对应get_answer--利用streamlit cache缓存相同问题的答案--
@ -363,48 +226,24 @@ def answer(query, vs_path='', history=[], mode='', score_threshold=0,
vector_search_top_k, chunk_conent, chunk_size)
def load_vector_store(
vs_id,
files,
sentence_size=100,
history=[],
one_conent=None,
one_content_segmentation=None,
):
return get_vector_store(
local_doc_qa,
vs_id,
files,
sentence_size,
history,
one_conent,
one_content_segmentation,
)
def use_kb_mode(m):
return m in ["知识库问答", "知识库测试"]
# main ui
st.set_page_config(webui_title, layout='wide')
init_session()
# params = get_query_params()
# llm_model = params.get('llm_model', LLM_MODEL)
# embedding_model = params.get('embedding_model', EMBEDDING_MODEL)
with st.spinner(f'正在加载模型({LLM_MODEL} + {EMBEDDING_MODEL}),请耐心等候...'):
local_doc_qa = load_model(LLM_MODEL, EMBEDDING_MODEL)
def use_kb_mode(m):
return m in ['知识库问答', '知识库测试']
chat_box = st_chatbox(greetings=["模型已成功加载,可以开始对话,或从左侧选择模式后开始对话。"])
# 使用 help(st_chatbox) 查看自定义参数
# sidebar
modes = ['LLM 对话', '知识库问答', 'Bing搜索问答', '知识库测试']
with st.sidebar:
def on_mode_change():
m = st.session_state.mode
robot_say(f'已切换到"{m}"模式')
chat_box.robot_say(f'已切换到"{m}"模式')
if m == '知识库测试':
robot_say(knowledge_base_test_mode_info)
chat_box.robot_say(knowledge_base_test_mode_info)
index = 0
try:
@ -414,7 +253,7 @@ with st.sidebar:
mode = st.selectbox('对话模式', modes, index,
on_change=on_mode_change, key='mode')
with st.expander('模型配置', '知识' not in mode):
with st.expander('模型配置', not use_kb_mode(mode)):
with st.form('model_config'):
index = 0
try:
@ -423,9 +262,8 @@ with st.sidebar:
pass
llm_model = st.selectbox('LLM模型', llm_model_dict_list, index)
no_remote_model = st.checkbox('加载本地模型', False)
use_ptuning_v2 = st.checkbox('使用p-tuning-v2微调过的模型', False)
use_lora = st.checkbox('使用lora微调的权重', False)
try:
index = embedding_model_dict_list.index(EMBEDDING_MODEL)
except:
@ -435,42 +273,52 @@ with st.sidebar:
btn_load_model = st.form_submit_button('重新加载模型')
if btn_load_model:
local_doc_qa = load_model(llm_model, embedding_model)
local_doc_qa = load_model(llm_model, embedding_model, use_ptuning_v2)
if mode in ['知识库问答', '知识库测试']:
history_len = st.slider(
"LLM对话轮数", 1, 50, LLM_HISTORY_LEN)
if use_kb_mode(mode):
vs_list = get_vs_list()
vs_list.remove('新建知识库')
def on_new_kb():
name = st.session_state.kb_name
if name in vs_list:
st.error(f'名为“{name}”的知识库已存在。')
if not name:
st.sidebar.error(f'新建知识库名称不能为空!')
elif name in vs_list:
st.sidebar.error(f'名为“{name}”的知识库已存在。')
else:
vs_list.append(name)
st.session_state.vs_path = name
st.session_state.kb_name = ''
new_kb_dir = os.path.join(KB_ROOT_PATH, name)
if not os.path.exists(new_kb_dir):
os.makedirs(new_kb_dir)
st.sidebar.success(f'名为“{name}”的知识库创建成功,您可以开始添加文件。')
def on_vs_change():
robot_say(f'已加载知识库: {st.session_state.vs_path}')
chat_box.robot_say(f'已加载知识库: {st.session_state.vs_path}')
with st.expander('知识库配置', True):
cols = st.columns([12, 10])
kb_name = cols[0].text_input(
'新知识库名称', placeholder='新知识库名称', label_visibility='collapsed')
'新知识库名称', placeholder='新知识库名称', label_visibility='collapsed', key='kb_name')
cols[1].button('新建知识库', on_click=on_new_kb)
index = 0
try:
index = vs_list.index(ST_CONFIG.default_kb)
except:
pass
vs_path = st.selectbox(
'选择知识库', vs_list, on_change=on_vs_change, key='vs_path')
'选择知识库', vs_list, index, on_change=on_vs_change, key='vs_path')
st.text('')
score_threshold = st.slider(
'知识相关度阈值', 0, 1000, VECTOR_SEARCH_SCORE_THRESHOLD)
top_k = st.slider('向量匹配数量', 1, 20, VECTOR_SEARCH_TOP_K)
history_len = st.slider(
'LLM对话轮数', 1, 50, LLM_HISTORY_LEN) # 也许要跟知识库分开设置
local_doc_qa.llm.set_history_len(history_len)
chunk_conent = st.checkbox('启用上下文关联', False)
st.text('')
# chunk_conent = st.checkbox('分割文本', True) # 知识库文本分割入库
chunk_size = st.slider('上下文关联长度', 1, 1000, CHUNK_SIZE)
st.text('')
sentence_size = st.slider('文本入库分句长度限制', 1, 1000, SENTENCE_SIZE)
files = st.file_uploader('上传知识文件',
['docx', 'txt', 'md', 'csv', 'xlsx', 'pdf'],
@ -483,56 +331,61 @@ with st.sidebar:
with open(file, 'wb') as fp:
fp.write(f.getvalue())
file_list.append(TempFile(file))
_, _, history = load_vector_store(
_, _, history = get_vector_store(
vs_path, file_list, sentence_size, [], None, None)
st.session_state.files = []
# main body
chat_box = st.empty()
# load model after params rendered
with st.spinner(f"正在加载模型({llm_model} + {embedding_model}),请耐心等候..."):
local_doc_qa = load_model(
llm_model,
embedding_model,
use_ptuning_v2,
)
local_doc_qa.llm_model_chain.history_len = history_len
if use_kb_mode(mode):
local_doc_qa.chunk_conent = chunk_conent
local_doc_qa.chunk_size = chunk_size
# local_doc_qa.llm_model_chain.temperature = temperature # 这样设置temperature似乎不起作用
st.session_state.local_doc_qa = local_doc_qa
with st.form('my_form', clear_on_submit=True):
# input form
with st.form("my_form", clear_on_submit=True):
cols = st.columns([8, 1])
question = cols[0].text_input(
question = cols[0].text_area(
'temp', key='input_question', label_visibility='collapsed')
def on_send():
q = st.session_state.input_question
if q:
user_say(q)
if cols[1].form_submit_button("发送"):
chat_box.user_say(question)
history = []
if mode == "LLM 对话":
chat_box.robot_say("正在思考...")
chat_box.output_messages()
for history, _ in answer(question,
history=[],
mode=mode):
chat_box.update_last_box_text(history[-1][-1])
elif use_kb_mode(mode):
chat_box.robot_say(f"正在查询 [{vs_path}] ...")
chat_box.output_messages()
for history, _ in answer(question,
vs_path=os.path.join(
KB_ROOT_PATH, vs_path, 'vector_store'),
history=[],
mode=mode,
score_threshold=score_threshold,
vector_search_top_k=top_k,
chunk_conent=chunk_conent,
chunk_size=chunk_size):
chat_box.update_last_box_text(history[-1][-1])
else:
chat_box.robot_say(f"正在执行Bing搜索...")
chat_box.output_messages()
for history, _ in answer(question,
history=[],
mode=mode):
chat_box.update_last_box_text(history[-1][-1])
if mode == 'LLM 对话':
robot_say('正在思考...')
last_response = output_messages()
for history, _ in answer(q,
history=[],
mode=mode):
last_response.markdown(
format_md(history[-1][-1], False),
unsafe_allow_html=True
)
elif use_kb_mode(mode):
robot_say('正在思考...', vs_path)
last_response = output_messages()
for history, _ in answer(q,
vs_path=os.path.join(
KB_ROOT_PATH, vs_path, "vector_store"),
history=[],
mode=mode,
score_threshold=score_threshold,
vector_search_top_k=top_k,
chunk_conent=chunk_conent,
chunk_size=chunk_size):
last_response.markdown(
format_md(history[-1][-1], False, 'ligreen'),
unsafe_allow_html=True
)
else:
robot_say('正在思考...')
last_response = output_messages()
st.session_state['history'][-1]['content'] = history[-1][-1]
submit = cols[1].form_submit_button('发送', on_click=on_send)
output_messages()
# st.write(st.session_state['history'])
# st.write(chat_box.history)
chat_box.output_messages()
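Condensed, one dialogue turn in the rewritten Streamlit UI follows a fixed `streamlit_chatbox` pattern, using only the methods already seen above:

```python
# One-turn flow with st_chatbox: echo the question, show a placeholder bubble,
# then stream partial answers into that bubble as the chain yields them.
chat_box.user_say(question)
chat_box.robot_say("正在思考...")
chat_box.output_messages()
for history, _ in answer(question, history=[], mode="LLM 对话"):
    chat_box.update_last_box_text(history[-1][-1])
```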