diff --git a/.gitignore b/.gitignore index 3646195d..73ebd212 100644 --- a/.gitignore +++ b/.gitignore @@ -174,4 +174,7 @@ embedding/* pyrightconfig.json loader/tmp_files -flagged/* \ No newline at end of file +flagged/* +ptuning-v2/*.json +ptuning-v2/*.bin + diff --git a/README.md b/README.md index 791bb55b..5ee0adea 100644 --- a/README.md +++ b/README.md @@ -23,13 +23,17 @@ 🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。 +🐳 Docker镜像:registry.cn-beijing.aliyuncs.com/isafetech/chatmydata:1.0 (感谢 @InkSong🌲 ) + +💻 运行方式:docker run -d -p 80:7860 --gpus all registry.cn-beijing.aliyuncs.com/isafetech/chatmydata:1.0  + 🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/imClumsyPanda/langchain-ChatGLM/langchain-ChatGLM) 📓 [ModelWhale 在线运行项目](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59) ## 变更日志 -参见 [变更日志](docs/CHANGELOG.md)。 +参见 [版本更新日志](https://github.com/imClumsyPanda/langchain-ChatGLM/releases)。 ## 硬件需求 @@ -60,6 +64,23 @@ 本项目中默认选用的 Embedding 模型 [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese/tree/main) 约占用显存 3GB,也可修改为在 CPU 中运行。 +## Docker 整合包 +🐳 Docker镜像地址:`registry.cn-beijing.aliyuncs.com/isafetech/chatmydata:1.0 `🌲 + +💻 一行命令运行: +```shell +docker run -d -p 80:7860 --gpus all registry.cn-beijing.aliyuncs.com/isafetech/chatmydata:1.0 +``` + +- 该版本镜像大小`25.2G`,使用[v0.1.16](https://github.com/imClumsyPanda/langchain-ChatGLM/releases/tag/v0.1.16),以`nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04`为基础镜像 +- 该版本内置两个`embedding`模型:`m3e-base`,`text2vec-large-chinese`,内置`fastchat+chatglm-6b` +- 该版本目标为方便一键部署使用,请确保您已经在Linux发行版上安装了NVIDIA驱动程序 +- 请注意,您不需要在主机系统上安装CUDA工具包,但需要安装`NVIDIA Driver`以及`NVIDIA Container Toolkit`,请参考[安装指南](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) +- 首次拉取和启动均需要一定时间,首次启动时请参照下图使用`docker logs -f `查看日志 +- 如遇到启动过程卡在`Waiting..`步骤,建议使用`docker exec -it bash`进入`/logs/`目录查看对应阶段日志 +![](img/docker_logs.png) + + ## Docker 部署 为了能让容器使用主机GPU资源,需要在主机上安装 [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-container-toolkit)。具体安装步骤如下: ```shell @@ -198,12 +219,17 @@ Web UI 可以实现如下功能: - [ ] 知识图谱/图数据库接入 - [ ] Agent 实现 - [x] 增加更多 LLM 模型支持 + - [x] [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) - [x] [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b) - [x] [THUDM/chatglm-6b-int8](https://huggingface.co/THUDM/chatglm-6b-int8) - [x] [THUDM/chatglm-6b-int4](https://huggingface.co/THUDM/chatglm-6b-int4) - [x] [THUDM/chatglm-6b-int4-qe](https://huggingface.co/THUDM/chatglm-6b-int4-qe) - [x] [ClueAI/ChatYuan-large-v2](https://huggingface.co/ClueAI/ChatYuan-large-v2) - [x] [fnlp/moss-moon-003-sft](https://huggingface.co/fnlp/moss-moon-003-sft) + - [x] [bigscience/bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1) + - [x] [bigscience/bloom-3b](https://huggingface.co/bigscience/bloom-3b) + - [x] [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B) + - [x] [lmsys/vicuna-13b-delta-v1.1](https://huggingface.co/lmsys/vicuna-13b-delta-v1.1) - [x] 支持通过调用 [fastchat](https://github.com/lm-sys/FastChat) api 调用 llm - [x] 增加更多 Embedding 模型支持 - [x] [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh) @@ -221,7 +247,7 @@ Web UI 可以实现如下功能: - [x] 选择知识库开始问答 - [x] 上传文件/文件夹至知识库 - [x] 知识库测试 - - [ ] 删除知识库中文件 + - [x] 删除知识库中文件 - [x] 支持搜索引擎问答 - [ ] 增加 API 支持 - [x] 利用 fastapi 实现 API 部署方式 @@ -229,7 +255,7 @@ Web UI 可以实现如下功能: - [x] VUE 前端 ## 项目交流群 -二维码 +二维码 🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。 diff --git a/api.py b/api.py index 768a3f73..bc2fe633 100644 --- a/api.py +++ b/api.py @@ -1,10 +1,11 @@ +#encoding:utf-8 import argparse import json import os import shutil from typing import List, Optional import urllib - +import asyncio import nltk import pydantic import uvicorn @@ -55,7 +56,7 @@ class ListDocsResponse(BaseResponse): class ChatMessage(BaseModel): question: str = pydantic.Field(..., description="Question text") response: str = pydantic.Field(..., description="Response text") - history: List[List[str]] = pydantic.Field(..., description="History text") + history: List[List[Optional[str]]] = pydantic.Field(..., description="History text") source_documents: List[str] = pydantic.Field( ..., description="List of source documents and their scores" ) @@ -80,23 +81,37 @@ class ChatMessage(BaseModel): } -def get_folder_path(local_doc_id: str): - return os.path.join(KB_ROOT_PATH, local_doc_id, "content") +def get_kb_path(local_doc_id: str): + return os.path.join(KB_ROOT_PATH, local_doc_id) + + +def get_doc_path(local_doc_id: str): + return os.path.join(get_kb_path(local_doc_id), "content") def get_vs_path(local_doc_id: str): - return os.path.join(KB_ROOT_PATH, local_doc_id, "vector_store") + return os.path.join(get_kb_path(local_doc_id), "vector_store") def get_file_path(local_doc_id: str, doc_name: str): - return os.path.join(KB_ROOT_PATH, local_doc_id, "content", doc_name) + return os.path.join(get_doc_path(local_doc_id), doc_name) + + +def validate_kb_name(knowledge_base_id: str) -> bool: + # 检查是否包含预期外的字符或路径攻击关键字 + if "../" in knowledge_base_id: + return False + return True async def upload_file( file: UploadFile = File(description="A single binary file"), knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"), ): - saved_path = get_folder_path(knowledge_base_id) + if not validate_kb_name(knowledge_base_id): + return BaseResponse(code=403, msg="Don't attack me", data=[]) + + saved_path = get_doc_path(knowledge_base_id) if not os.path.exists(saved_path): os.makedirs(saved_path) @@ -126,21 +141,25 @@ async def upload_files( ], knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"), ): - saved_path = get_folder_path(knowledge_base_id) + if not validate_kb_name(knowledge_base_id): + return BaseResponse(code=403, msg="Don't attack me", data=[]) + + saved_path = get_doc_path(knowledge_base_id) if not os.path.exists(saved_path): os.makedirs(saved_path) filelist = [] for file in files: file_content = '' file_path = os.path.join(saved_path, file.filename) - file_content = file.file.read() + file_content = await file.read() if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content): continue - with open(file_path, "ab+") as f: + with open(file_path, "wb") as f: f.write(file_content) filelist.append(file_path) if filelist: - vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, get_vs_path(knowledge_base_id)) + vs_path = get_vs_path(knowledge_base_id) + vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path) if len(loaded_files): file_status = f"documents {', '.join([os.path.split(i)[-1] for i in loaded_files])} upload success" return BaseResponse(code=200, msg=file_status) @@ -164,16 +183,24 @@ async def list_kbs(): async def list_docs( - knowledge_base_id: Optional[str] = Query(default=None, description="Knowledge Base Name", example="kb1") + knowledge_base_id: str = Query(..., description="Knowledge Base Name", example="kb1") ): - local_doc_folder = get_folder_path(knowledge_base_id) + if not validate_kb_name(knowledge_base_id): + return ListDocsResponse(code=403, msg="Don't attack me", data=[]) + + knowledge_base_id = urllib.parse.unquote(knowledge_base_id) + kb_path = get_kb_path(knowledge_base_id) + local_doc_folder = get_doc_path(knowledge_base_id) + if not os.path.exists(kb_path): + return ListDocsResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found", data=[]) if not os.path.exists(local_doc_folder): - return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} - all_doc_names = [ - doc - for doc in os.listdir(local_doc_folder) - if os.path.isfile(os.path.join(local_doc_folder, doc)) - ] + all_doc_names = [] + else: + all_doc_names = [ + doc + for doc in os.listdir(local_doc_folder) + if os.path.isfile(os.path.join(local_doc_folder, doc)) + ] return ListDocsResponse(data=all_doc_names) @@ -182,11 +209,15 @@ async def delete_kb( description="Knowledge Base Name", example="kb1"), ): + if not validate_kb_name(knowledge_base_id): + return BaseResponse(code=403, msg="Don't attack me") + # TODO: 确认是否支持批量删除知识库 knowledge_base_id = urllib.parse.unquote(knowledge_base_id) - if not os.path.exists(get_folder_path(knowledge_base_id)): - return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} - shutil.rmtree(get_folder_path(knowledge_base_id)) + kb_path = get_kb_path(knowledge_base_id) + if not os.path.exists(kb_path): + return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found") + shutil.rmtree(kb_path) return BaseResponse(code=200, msg=f"Knowledge Base {knowledge_base_id} delete success") @@ -195,27 +226,30 @@ async def delete_doc( description="Knowledge Base Name", example="kb1"), doc_name: str = Query( - None, description="doc name", example="doc_name_1.pdf" + ..., description="doc name", example="doc_name_1.pdf" ), ): + if not validate_kb_name(knowledge_base_id): + return BaseResponse(code=403, msg="Don't attack me") + knowledge_base_id = urllib.parse.unquote(knowledge_base_id) - if not os.path.exists(get_folder_path(knowledge_base_id)): - return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} + if not os.path.exists(get_kb_path(knowledge_base_id)): + return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found") doc_path = get_file_path(knowledge_base_id, doc_name) if os.path.exists(doc_path): os.remove(doc_path) remain_docs = await list_docs(knowledge_base_id) if len(remain_docs.data) == 0: - shutil.rmtree(get_folder_path(knowledge_base_id), ignore_errors=True) + shutil.rmtree(get_kb_path(knowledge_base_id), ignore_errors=True) return BaseResponse(code=200, msg=f"document {doc_name} delete success") else: status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id)) if "success" in status: return BaseResponse(code=200, msg=f"document {doc_name} delete success") else: - return BaseResponse(code=1, msg=f"document {doc_name} delete fail") + return BaseResponse(code=500, msg=f"document {doc_name} delete fail") else: - return BaseResponse(code=1, msg=f"document {doc_name} not found") + return BaseResponse(code=404, msg=f"document {doc_name} not found") async def update_doc( @@ -223,23 +257,26 @@ async def update_doc( description="知识库名", example="kb1"), old_doc: str = Query( - None, description="待删除文件名,已存储在知识库中", example="doc_name_1.pdf" + ..., description="待删除文件名,已存储在知识库中", example="doc_name_1.pdf" ), new_doc: UploadFile = File(description="待上传文件"), ): + if not validate_kb_name(knowledge_base_id): + return BaseResponse(code=403, msg="Don't attack me") + knowledge_base_id = urllib.parse.unquote(knowledge_base_id) - if not os.path.exists(get_folder_path(knowledge_base_id)): - return {"code": 1, "msg": f"Knowledge base {knowledge_base_id} not found"} + if not os.path.exists(get_kb_path(knowledge_base_id)): + return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found") doc_path = get_file_path(knowledge_base_id, old_doc) if not os.path.exists(doc_path): - return BaseResponse(code=1, msg=f"document {old_doc} not found") + return BaseResponse(code=404, msg=f"document {old_doc} not found") else: os.remove(doc_path) delete_status = local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_id)) if "fail" in delete_status: - return BaseResponse(code=1, msg=f"document {old_doc} delete failed") + return BaseResponse(code=500, msg=f"document {old_doc} delete failed") else: - saved_path = get_folder_path(knowledge_base_id) + saved_path = get_doc_path(knowledge_base_id) if not os.path.exists(saved_path): os.makedirs(saved_path) @@ -267,8 +304,8 @@ async def update_doc( async def local_doc_chat( knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"), question: str = Body(..., description="Question", example="工伤保险是什么?"), - stream: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"), - history: List[List[str]] = Body( + streaming: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"), + history: List[List[Optional[str]]] = Body( [], description="History of previous questions and answers", example=[ @@ -281,7 +318,7 @@ async def local_doc_chat( ): vs_path = get_vs_path(knowledge_base_id) if not os.path.exists(vs_path): - # return BaseResponse(code=1, msg=f"Knowledge base {knowledge_base_id} not found") + # return BaseResponse(code=404, msg=f"Knowledge base {knowledge_base_id} not found") return ChatMessage( question=question, response=f"Knowledge base {knowledge_base_id} not found", @@ -289,7 +326,7 @@ async def local_doc_chat( source_documents=[], ) else: - if (stream): + if (streaming): def generate_answer (): last_print_len = 0 for resp, next_history in local_doc_qa.get_knowledge_based_answer( @@ -300,7 +337,7 @@ async def local_doc_chat( return StreamingResponse(generate_answer()) else: - for resp, next_history in local_doc_qa.get_knowledge_based_answer( + for resp, history in local_doc_qa.get_knowledge_based_answer( query=question, vs_path=vs_path, chat_history=history, streaming=True ): pass @@ -314,14 +351,14 @@ async def local_doc_chat( return ChatMessage( question=question, response=resp["result"], - history=next_history, + history=history, source_documents=source_documents, ) async def bing_search_chat( question: str = Body(..., description="Question", example="工伤保险是什么?"), - history: Optional[List[List[str]]] = Body( + history: Optional[List[List[Optional[str]]]] = Body( [], description="History of previous questions and answers", example=[ @@ -351,8 +388,8 @@ async def bing_search_chat( async def chat( question: str = Body(..., description="Question", example="工伤保险是什么?"), - stream: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"), - history: List[List[str]] = Body( + streaming: bool = Body(False, description="是否开启流式输出,默认false,有些模型可能不支持。"), + history: List[List[Optional[str]]] = Body( [], description="History of previous questions and answers", example=[ @@ -363,19 +400,20 @@ async def chat( ], ), ): - - if (stream): + if (streaming): def generate_answer (): last_print_len = 0 - for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history, - streaming=True): + answer_result_stream_result = local_doc_qa.llm_model_chain( + {"prompt": question, "history": history, "streaming": True}) + for answer_result in answer_result_stream_result['answer_result_stream']: yield answer_result.llm_output["answer"][last_print_len:] last_print_len = len(answer_result.llm_output["answer"]) return StreamingResponse(generate_answer()) else: - for answer_result in local_doc_qa.llm.generatorAnswer(prompt=question, history=history, - streaming=True): + answer_result_stream_result = local_doc_qa.llm_model_chain( + {"prompt": question, "history": history, "streaming": True}) + for answer_result in answer_result_stream_result['answer_result_stream']: resp = answer_result.llm_output["answer"] history = answer_result.history pass @@ -386,9 +424,22 @@ async def chat( history=history, source_documents=[], ) + answer_result_stream_result = local_doc_qa.llm_model_chain( + {"prompt": question, "history": history, "streaming": True}) + + for answer_result in answer_result_stream_result['answer_result_stream']: + resp = answer_result.llm_output["answer"] + history = answer_result.history + pass + return ChatMessage( + question=question, + response=resp, + history=history, + source_documents=[], + ) -async def stream_chat(websocket: WebSocket, knowledge_base_id: str): +async def stream_chat(websocket: WebSocket): await websocket.accept() turn = 1 while True: @@ -408,6 +459,7 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str): for resp, history in local_doc_qa.get_knowledge_based_answer( query=question, vs_path=vs_path, chat_history=history, streaming=True ): + await asyncio.sleep(0) await websocket.send_text(resp["result"][last_print_len:]) last_print_len = len(resp["result"]) @@ -430,17 +482,51 @@ async def stream_chat(websocket: WebSocket, knowledge_base_id: str): ) turn += 1 +async def stream_chat_bing(websocket: WebSocket): + """ + 基于bing搜索的流式问答 + """ + await websocket.accept() + turn = 1 + while True: + input_json = await websocket.receive_json() + question, history = input_json["question"], input_json["history"] + + await websocket.send_json({"question": question, "turn": turn, "flag": "start"}) + + last_print_len = 0 + for resp, history in local_doc_qa.get_search_result_based_answer(question, chat_history=history, streaming=True): + await websocket.send_text(resp["result"][last_print_len:]) + last_print_len = len(resp["result"]) + + source_documents = [ + f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n""" + f"""相关度:{doc.metadata['score']}\n\n""" + for inum, doc in enumerate(resp["source_documents"]) + ] + + await websocket.send_text( + json.dumps( + { + "question": question, + "turn": turn, + "flag": "end", + "sources_documents": source_documents, + }, + ensure_ascii=False, + ) + ) + turn += 1 async def document(): return RedirectResponse(url="/docs") -def api_start(host, port): +def api_start(host, port, **kwargs): global app global local_doc_qa llm_model_ins = shared.loaderLLM() - llm_model_ins.set_history_len(LLM_HISTORY_LEN) app = FastAPI() # Add CORS middleware to allow all origins @@ -454,21 +540,28 @@ def api_start(host, port): allow_methods=["*"], allow_headers=["*"], ) - app.websocket("/local_doc_qa/stream-chat/{knowledge_base_id}")(stream_chat) + # 修改了stream_chat的接口,直接通过ws://localhost:7861/local_doc_qa/stream_chat建立连接,在请求体中选择knowledge_base_id + app.websocket("/local_doc_qa/stream_chat")(stream_chat) - app.get("/", response_model=BaseResponse)(document) + app.get("/", response_model=BaseResponse, summary="swagger 文档")(document) - app.post("/chat", response_model=ChatMessage)(chat) + # 增加基于bing搜索的流式问答 + # 需要说明的是,如果想测试websocket的流式问答,需要使用支持websocket的测试工具,如postman,insomnia + # 强烈推荐开源的insomnia + # 在测试时选择new websocket request,并将url的协议改为ws,如ws://localhost:7861/local_doc_qa/stream_chat_bing + app.websocket("/local_doc_qa/stream_chat_bing")(stream_chat_bing) - app.post("/local_doc_qa/upload_file", response_model=BaseResponse)(upload_file) - app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files) - app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat) - app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage)(bing_search_chat) - app.get("/local_doc_qa/list_knowledge_base", response_model=ListDocsResponse)(list_kbs) - app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs) - app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse)(delete_kb) - app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_doc) - app.post("/local_doc_qa/update_file", response_model=BaseResponse)(update_doc) + app.post("/chat", response_model=ChatMessage, summary="与模型对话")(chat) + + app.post("/local_doc_qa/upload_file", response_model=BaseResponse, summary="上传文件到知识库")(upload_file) + app.post("/local_doc_qa/upload_files", response_model=BaseResponse, summary="批量上传文件到知识库")(upload_files) + app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage, summary="与知识库对话")(local_doc_chat) + app.post("/local_doc_qa/bing_search_chat", response_model=ChatMessage, summary="与必应搜索对话")(bing_search_chat) + app.get("/local_doc_qa/list_knowledge_base", response_model=ListDocsResponse, summary="获取知识库列表")(list_kbs) + app.get("/local_doc_qa/list_files", response_model=ListDocsResponse, summary="获取知识库内的文件列表")(list_docs) + app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse, summary="删除知识库")(delete_kb) + app.delete("/local_doc_qa/delete_file", response_model=BaseResponse, summary="删除知识库内的文件")(delete_doc) + app.post("/local_doc_qa/update_file", response_model=BaseResponse, summary="上传文件到知识库,并删除另一个文件")(update_doc) local_doc_qa = LocalDocQA() local_doc_qa.init_cfg( @@ -477,15 +570,21 @@ def api_start(host, port): embedding_device=EMBEDDING_DEVICE, top_k=VECTOR_SEARCH_TOP_K, ) - uvicorn.run(app, host=host, port=port) + if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"): + uvicorn.run(app, host=host, port=port, ssl_keyfile=kwargs.get("ssl_keyfile"), + ssl_certfile=kwargs.get("ssl_certfile")) + else: + uvicorn.run(app, host=host, port=port) if __name__ == "__main__": parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--port", type=int, default=7861) + parser.add_argument("--ssl_keyfile", type=str) + parser.add_argument("--ssl_certfile", type=str) # 初始化消息 args = None args = parser.parse_args() args_dict = vars(args) shared.loaderCheckPoint = LoaderCheckPoint(args_dict) - api_start(args.host, args.port) + api_start(args.host, args.port, ssl_keyfile=args.ssl_keyfile, ssl_certfile=args.ssl_certfile) diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index fe700662..6085dfc0 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -8,7 +8,6 @@ from typing import List from utils import torch_gc from tqdm import tqdm from pypinyin import lazy_pinyin -from loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader from models.base import (BaseAnswer, AnswerResult) from models.loader.args import parser @@ -18,6 +17,7 @@ from agent import bing_search from langchain.docstore.document import Document from functools import lru_cache from textsplitter.zh_title_enhance import zh_title_enhance +from langchain.chains.base import Chain # patch HuggingFaceEmbeddings to make it hashable @@ -58,6 +58,7 @@ def tree(filepath, ignore_dir_names=None, ignore_file_names=None): def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE): + if filepath.lower().endswith(".md"): loader = UnstructuredFileLoader(filepath, mode="elements") docs = loader.load() @@ -66,10 +67,14 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_T textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size) docs = loader.load_and_split(textsplitter) elif filepath.lower().endswith(".pdf"): + # 暂且将paddle相关的loader改为动态加载,可以在不上传pdf/image知识文件的前提下使用protobuf=4.x + from loader import UnstructuredPaddlePDFLoader loader = UnstructuredPaddlePDFLoader(filepath) textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size) docs = loader.load_and_split(textsplitter) elif filepath.lower().endswith(".jpg") or filepath.lower().endswith(".png"): + # 暂且将paddle相关的loader改为动态加载,可以在不上传pdf/image知识文件的前提下使用protobuf=4.x + from loader import UnstructuredPaddleImageLoader loader = UnstructuredPaddleImageLoader(filepath, mode="elements") textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size) docs = loader.load_and_split(text_splitter=textsplitter) @@ -119,7 +124,7 @@ def search_result2docs(search_results): class LocalDocQA: - llm: BaseAnswer = None + llm_model_chain: Chain = None embeddings: object = None top_k: int = VECTOR_SEARCH_TOP_K chunk_size: int = CHUNK_SIZE @@ -129,10 +134,10 @@ class LocalDocQA: def init_cfg(self, embedding_model: str = EMBEDDING_MODEL, embedding_device=EMBEDDING_DEVICE, - llm_model: BaseAnswer = None, + llm_model: Chain = None, top_k=VECTOR_SEARCH_TOP_K, ): - self.llm = llm_model + self.llm_model_chain = llm_model self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model], model_kwargs={'device': embedding_device}) self.top_k = top_k @@ -200,6 +205,7 @@ class LocalDocQA: return vs_path, loaded_files else: logger.info("文件均未成功加载,请检查依赖包或替换为其他文件再次上传。") + return None, loaded_files def one_knowledge_add(self, vs_path, one_title, one_conent, one_content_segmentation, sentence_size): @@ -235,8 +241,10 @@ class LocalDocQA: else: prompt = query - for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history, - streaming=streaming): + answer_result_stream_result = self.llm_model_chain( + {"prompt": prompt, "history": chat_history, "streaming": streaming}) + + for answer_result in answer_result_stream_result['answer_result_stream']: resp = answer_result.llm_output["answer"] history = answer_result.history history[-1][0] = query @@ -275,8 +283,10 @@ class LocalDocQA: result_docs = search_result2docs(results) prompt = generate_prompt(result_docs, query) - for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history, - streaming=streaming): + answer_result_stream_result = self.llm_model_chain( + {"prompt": prompt, "history": chat_history, "streaming": streaming}) + + for answer_result in answer_result_stream_result['answer_result_stream']: resp = answer_result.llm_output["answer"] history = answer_result.history history[-1][0] = query @@ -295,7 +305,7 @@ class LocalDocQA: def update_file_from_vector_store(self, filepath: str or List[str], vs_path, - docs: List[Document],): + docs: List[Document], ): vector_store = load_vector_store(vs_path, self.embeddings) status = vector_store.update_doc(filepath, docs) return status @@ -319,7 +329,6 @@ if __name__ == "__main__": args_dict = vars(args) shared.loaderCheckPoint = LoaderCheckPoint(args_dict) llm_model_ins = shared.loaderLLM() - llm_model_ins.set_history_len(LLM_HISTORY_LEN) local_doc_qa = LocalDocQA() local_doc_qa.init_cfg(llm_model=llm_model_ins) diff --git a/chains/modules/embeddings.py b/chains/modules/embeddings.py deleted file mode 100644 index 3abeddff..00000000 --- a/chains/modules/embeddings.py +++ /dev/null @@ -1,34 +0,0 @@ -from langchain.embeddings.huggingface import HuggingFaceEmbeddings - -from typing import Any, List - - -class MyEmbeddings(HuggingFaceEmbeddings): - def __init__(self, **kwargs: Any): - super().__init__(**kwargs) - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Compute doc embeddings using a HuggingFace transformer model. - - Args: - texts: The list of texts to embed. - - Returns: - List of embeddings, one for each text. - """ - texts = list(map(lambda x: x.replace("\n", " "), texts)) - embeddings = self.client.encode(texts, normalize_embeddings=True) - return embeddings.tolist() - - def embed_query(self, text: str) -> List[float]: - """Compute query embeddings using a HuggingFace transformer model. - - Args: - text: The text to embed. - - Returns: - Embeddings for the text. - """ - text = text.replace("\n", " ") - embedding = self.client.encode(text, normalize_embeddings=True) - return embedding.tolist() diff --git a/chains/modules/vectorstores.py b/chains/modules/vectorstores.py deleted file mode 100644 index da89775a..00000000 --- a/chains/modules/vectorstores.py +++ /dev/null @@ -1,121 +0,0 @@ -from langchain.vectorstores import FAISS -from typing import Any, Callable, List, Optional, Tuple, Dict -from langchain.docstore.document import Document -from langchain.docstore.base import Docstore - -from langchain.vectorstores.utils import maximal_marginal_relevance -from langchain.embeddings.base import Embeddings -import uuid -from langchain.docstore.in_memory import InMemoryDocstore - -import numpy as np - -def dependable_faiss_import() -> Any: - """Import faiss if available, otherwise raise error.""" - try: - import faiss - except ImportError: - raise ValueError( - "Could not import faiss python package. " - "Please install it with `pip install faiss` " - "or `pip install faiss-cpu` (depending on Python version)." - ) - return faiss - -class FAISSVS(FAISS): - def __init__(self, - embedding_function: Callable[..., Any], - index: Any, - docstore: Docstore, - index_to_docstore_id: Dict[int, str]): - super().__init__(embedding_function, index, docstore, index_to_docstore_id) - - def max_marginal_relevance_search_by_vector( - self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any - ) -> List[Tuple[Document, float]]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - - Returns: - List of Documents with scores selected by maximal marginal relevance. - """ - scores, indices = self.index.search(np.array([embedding], dtype=np.float32), fetch_k) - # -1 happens when not enough docs are returned. - embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1] - mmr_selected = maximal_marginal_relevance( - np.array([embedding], dtype=np.float32), embeddings, k=k - ) - selected_indices = [indices[0][i] for i in mmr_selected] - selected_scores = [scores[0][i] for i in mmr_selected] - docs = [] - for i, score in zip(selected_indices, selected_scores): - if i == -1: - # This happens when not enough docs are returned. - continue - _id = self.index_to_docstore_id[i] - doc = self.docstore.search(_id) - if not isinstance(doc, Document): - raise ValueError(f"Could not find document for id {_id}, got {doc}") - docs.append((doc, score)) - return docs - - def max_marginal_relevance_search( - self, - query: str, - k: int = 4, - fetch_k: int = 20, - **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Return docs selected using the maximal marginal relevance. - - Maximal marginal relevance optimizes for similarity to query AND diversity - among selected documents. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - fetch_k: Number of Documents to fetch to pass to MMR algorithm. - - Returns: - List of Documents with scores selected by maximal marginal relevance. - """ - embedding = self.embedding_function(query) - docs = self.max_marginal_relevance_search_by_vector(embedding, k, fetch_k) - return docs - - @classmethod - def __from( - cls, - texts: List[str], - embeddings: List[List[float]], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - **kwargs: Any, - ) -> FAISS: - faiss = dependable_faiss_import() - index = faiss.IndexFlatIP(len(embeddings[0])) - index.add(np.array(embeddings, dtype=np.float32)) - - # # my code, for speeding up search - # quantizer = faiss.IndexFlatL2(len(embeddings[0])) - # index = faiss.IndexIVFFlat(quantizer, len(embeddings[0]), 100) - # index.train(np.array(embeddings, dtype=np.float32)) - # index.add(np.array(embeddings, dtype=np.float32)) - - documents = [] - for i, text in enumerate(texts): - metadata = metadatas[i] if metadatas else {} - documents.append(Document(page_content=text, metadata=metadata)) - index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))} - docstore = InMemoryDocstore( - {index_to_id[i]: doc for i, doc in enumerate(documents)} - ) - return cls(embedding.embed_query, index, docstore, index_to_id) - diff --git a/cli.py b/cli.py index 3d9c2518..bb201333 100644 --- a/cli.py +++ b/cli.py @@ -42,7 +42,9 @@ def start(): @start.command(name="api", context_settings=dict(help_option_names=['-h', '--help'])) @click.option('-i', '--ip', default='0.0.0.0', show_default=True, type=str, help='api_server listen address.') @click.option('-p', '--port', default=7861, show_default=True, type=int, help='api_server listen port.') -def start_api(ip, port): +@click.option('-k', '--ssl_keyfile', type=int, help='enable api https/wss service, specify the ssl keyfile path.') +@click.option('-c', '--ssl_certfile', type=int, help='enable api https/wss service, specify the ssl certificate file path.') +def start_api(ip, port, **kwargs): # 调用api_start之前需要先loadCheckPoint,并传入加载检查点的参数, # 理论上可以用click包进行包装,但过于繁琐,改动较大, # 此处仍用parser包,并以models.loader.args.DEFAULT_ARGS的参数为默认参数 @@ -51,7 +53,7 @@ def start_api(ip, port): from models.loader import LoaderCheckPoint from models.loader.args import DEFAULT_ARGS shared.loaderCheckPoint = LoaderCheckPoint(DEFAULT_ARGS) - api_start(host=ip, port=port) + api_start(host=ip, port=port, **kwargs) # # 通过cli.py调用cli_demo时需要在cli.py里初始化模型,否则会报错: # langchain-ChatGLM: error: unrecognized arguments: start cli diff --git a/cli_demo.py b/cli_demo.py index 938ebb33..a445e144 100644 --- a/cli_demo.py +++ b/cli_demo.py @@ -23,11 +23,33 @@ def main(): top_k=VECTOR_SEARCH_TOP_K) vs_path = None while not vs_path: + print("注意输入的路径是完整的文件路径,例如knowledge_base/`knowledge_base_id`/content/file.md,多个路径用英文逗号分割") filepath = input("Input your local knowledge file path 请输入本地知识文件路径:") + # 判断 filepath 是否为空,如果为空的话,重新让用户输入,防止用户误触回车 if not filepath: continue - vs_path, _ = local_doc_qa.init_knowledge_vector_store(filepath) + + # 支持加载多个文件 + filepath = filepath.split(",") + # filepath错误的返回为None, 如果直接用原先的vs_path,_ = local_doc_qa.init_knowledge_vector_store(filepath) + # 会直接导致TypeError: cannot unpack non-iterable NoneType object而使得程序直接退出 + # 因此需要先加一层判断,保证程序能继续运行 + temp,loaded_files = local_doc_qa.init_knowledge_vector_store(filepath) + if temp is not None: + vs_path = temp + # 如果loaded_files和len(filepath)不一致,则说明部分文件没有加载成功 + # 如果是路径错误,则应该支持重新加载 + if len(loaded_files) != len(filepath): + reload_flag = eval(input("部分文件加载失败,若提示路径不存在,可重新加载,是否重新加载,输入True或False: ")) + if reload_flag: + vs_path = None + continue + + print(f"the loaded vs_path is 加载的vs_path为: {vs_path}") + else: + print("load file failed, re-input your local knowledge file path 请重新输入本地知识文件路径") + history = [] while True: query = input("Input your question 请输入问题:") diff --git a/docs/FAQ.md b/docs/FAQ.md index f7124770..ccc0f254 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -177,3 +177,22 @@ download_with_progressbar(url, tmp_path) Q14 调用api中的 `bing_search_chat`接口时,报出 `Failed to establish a new connection: [Errno 110] Connection timed out` 这是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG--! + +--- + +Q15 加载chatglm-6b-int8或chatglm-6b-int4抛出 `RuntimeError: Only Tensors of floating point andcomplex dtype can require gradients` + +疑为chatglm的quantization的问题或torch版本差异问题,针对已经变为Parameter的torch.zeros矩阵也执行Parameter操作,从而抛出 `RuntimeError: Only Tensors of floating point andcomplex dtype can require gradients`。解决办法是在chatglm-项目的原始文件中的quantization.py文件374行改为: + +``` + try: + self.weight =Parameter(self.weight.to(kwargs["device"]), requires_grad=False) + except Exception as e: + pass +``` + + 如果上述方式不起作用,则在.cache/hugggingface/modules/目录下针对chatglm项目的原始文件中的quantization.py文件执行上述操作,若软链接不止一个,按照错误提示选择正确的路径。 + +注:虽然模型可以顺利加载但在cpu上仍存在推理失败的可能:即针对每个问题,模型一直输出gugugugu。 + + 因此,最好不要试图用cpu加载量化模型,原因可能是目前python主流量化包的量化操作是在gpu上执行的,会天然地存在gap。 diff --git a/docs/INSTALL.md b/docs/INSTALL.md index 83e52ab1..2682c7b7 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -44,4 +44,12 @@ $ pip install -r requirements.txt $ python loader/image_loader.py ``` + 注:使用 `langchain.document_loaders.UnstructuredFileLoader` 进行非结构化文件接入时,可能需要依据文档进行其他依赖包的安装,请参考 [langchain 文档](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/unstructured_file.html)。 + +## llama-cpp模型调用的说明 + +1. 首先从huggingface hub中下载对应的模型,如 [https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/) 的 [ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin),建议使用huggingface_hub库的snapshot_download下载。 +2. 将下载的模型重命名。通过huggingface_hub下载的模型会被重命名为随机序列,因此需要重命名为原始文件名,如[ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin)。 +3. 基于下载模型的ggml的加载时间,推测对应的llama-cpp版本,下载对应的llama-cpp-python库的wheel文件,实测[ggml-vic13b-q5_1.bin](https://huggingface.co/vicuna/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin)与llama-cpp-python库兼容,然后手动安装wheel文件。 +4. 将下载的模型信息写入configs/model_config.py文件里 `llm_model_dict`中,注意保证参数的兼容性,一些参数组合可能会报错. diff --git a/docs/启动API服务.md b/docs/启动API服务.md new file mode 100644 index 00000000..aa816dd6 --- /dev/null +++ b/docs/启动API服务.md @@ -0,0 +1,37 @@ +# 启动API服务 + +## 通过py文件启动 +可以通过直接执行`api.py`文件启动API服务,默认以ip:0.0.0.0和port:7861启动http和ws服务。 +```shell +python api.py +``` +同时,启动时支持StartOption所列的模型加载参数,同时还支持IP和端口设置。 +```shell +python api.py --model-name chatglm-6b-int8 --port 7862 +``` + +## 通过cli.bat/cli.sh启动 +也可以通过命令行控制文件继续启动。 +```shell +cli.sh api --help +``` +其他可设置参数和上述py文件启动方式相同。 + + +# 以https、wss启动API服务 +## 本地创建ssl相关证书文件 +如果没有正式签发的CA证书,可以[安装mkcert](https://github.com/FiloSottile/mkcert#installation)工具, 然后用如下指令生成本地CA证书: +```shell +mkcert -install +mkcert api.example.com 47.123.123.123 localhost 127.0.0.1 ::1 +``` +默认回车保存在当前目录下,会有以生成指令第一个域名命名为前缀命名的两个pem文件。 + +附带两个文件参数启动即可。 +````shell +python api --port 7862 --ssl_keyfile api.example.com+4-key.pem --ssl_certfile api.example.com+4.pem + +./cli.sh api --port 7862 --ssl_keyfile api.example.com+4-key.pem --ssl_certfile api.example.com+4.pem +```` + +此外可以通过前置Nginx转发实现类似效果,可另行查阅相关资料。 \ No newline at end of file diff --git a/img/docker_logs.png b/img/docker_logs.png new file mode 100644 index 00000000..03829582 Binary files /dev/null and b/img/docker_logs.png differ diff --git a/img/qr_code_32.jpg b/img/qr_code_32.jpg deleted file mode 100644 index 7f90e407..00000000 Binary files a/img/qr_code_32.jpg and /dev/null differ diff --git a/img/qr_code_33.jpg b/img/qr_code_33.jpg deleted file mode 100644 index c9d9bde2..00000000 Binary files a/img/qr_code_33.jpg and /dev/null differ diff --git a/img/qr_code_45.jpg b/img/qr_code_45.jpg new file mode 100644 index 00000000..ad253c87 Binary files /dev/null and b/img/qr_code_45.jpg differ diff --git a/loader/image_loader.py b/loader/image_loader.py index ec32459c..4ac4c51c 100644 --- a/loader/image_loader.py +++ b/loader/image_loader.py @@ -5,9 +5,6 @@ from langchain.document_loaders.unstructured import UnstructuredFileLoader from paddleocr import PaddleOCR import os import nltk -from configs.model_config import NLTK_DATA_PATH - -nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path class UnstructuredPaddleImageLoader(UnstructuredFileLoader): """Loader that uses unstructured to load image files, such as PNGs and JPGs.""" @@ -35,6 +32,10 @@ class UnstructuredPaddleImageLoader(UnstructuredFileLoader): if __name__ == "__main__": import sys sys.path.append(os.path.dirname(os.path.dirname(__file__))) + + from configs.model_config import NLTK_DATA_PATH + nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path + filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base", "samples", "content", "test.jpg") loader = UnstructuredPaddleImageLoader(filepath, mode="elements") docs = loader.load() diff --git a/models/__init__.py b/models/__init__.py index 4d75c87a..533eeaf9 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,4 +1,4 @@ -from .chatglm_llm import ChatGLM -from .llama_llm import LLamaLLM -from .moss_llm import MOSSLLM -from .fastchat_openai_llm import FastChatOpenAILLM +from .chatglm_llm import ChatGLMLLMChain +from .llama_llm import LLamaLLMChain +from .fastchat_openai_llm import FastChatOpenAILLMChain +from .moss_llm import MOSSLLMChain diff --git a/models/base/__init__.py b/models/base/__init__.py index 0b240351..cf7fc678 100644 --- a/models/base/__init__.py +++ b/models/base/__init__.py @@ -1,13 +1,15 @@ from models.base.base import ( AnswerResult, - BaseAnswer -) + BaseAnswer, + AnswerResultStream, + AnswerResultQueueSentinelTokenListenerQueue) from models.base.remote_rpc_model import ( RemoteRpcModel ) - __all__ = [ "AnswerResult", "BaseAnswer", "RemoteRpcModel", + "AnswerResultStream", + "AnswerResultQueueSentinelTokenListenerQueue" ] diff --git a/models/base/base.py b/models/base/base.py index b0fb4981..c6674c9c 100644 --- a/models/base/base.py +++ b/models/base/base.py @@ -1,16 +1,30 @@ from abc import ABC, abstractmethod -from typing import Optional, List +from typing import Any, Dict, List, Optional, Generator import traceback from collections import deque from queue import Queue from threading import Thread - +from langchain.callbacks.manager import CallbackManagerForChainRun +from models.loader import LoaderCheckPoint +from pydantic import BaseModel import torch import transformers -from models.loader import LoaderCheckPoint -class AnswerResult: +class ListenerToken: + """ + 观测结果 + """ + + input_ids: torch.LongTensor + _scores: torch.FloatTensor + + def __init__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor): + self.input_ids = input_ids + self._scores = _scores + + +class AnswerResult(BaseModel): """ 消息实体 """ @@ -18,6 +32,122 @@ class AnswerResult: llm_output: Optional[dict] = None +class AnswerResultStream: + def __init__(self, callback_func=None): + self.callback_func = callback_func + + def __call__(self, answerResult: AnswerResult): + if self.callback_func is not None: + self.callback_func(answerResult) + + +class AnswerResultQueueSentinelTokenListenerQueue(transformers.StoppingCriteria): + """ + 定义模型stopping_criteria 监听者,在每次响应时将队列数据同步到AnswerResult + 实现此监听器的目的是,不同模型的预测输出可能不是矢量信息,hf框架可以自定义transformers.StoppingCriteria入参来接收每次预测的Tensor和损失函数, + 通过给 StoppingCriteriaList指定模型生成答案时停止的条件。每个 StoppingCriteria 对象表示一个停止条件 + 当每轮预测任务开始时,StoppingCriteria都会收到相同的预测结果,最终由下层实现类确认是否结束 + 输出值可用于 generatorAnswer generate_with_streaming的自定义参数观测,以实现更加精细的控制 + """ + + listenerQueue: deque = deque(maxlen=1) + + def __init__(self): + transformers.StoppingCriteria.__init__(self) + + def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor, **kwargs) -> bool: + """ + 每次响应时将数据添加到响应队列 + :param input_ids: + :param _scores: + :param kwargs: + :return: + """ + self.listenerQueue.append(ListenerToken(input_ids=input_ids, _scores=_scores)) + return False + + +class Iteratorize: + """ + Transforms a function that takes a callback + into a lazy iterator (generator). + """ + + def __init__(self, func, kwargs={}): + self.mfunc = func + self.q = Queue() + self.sentinel = object() + self.kwargs = kwargs + self.stop_now = False + + def _callback(val): + """ + 模型输出预测结果收集 + 通过定义generate_with_callback收集器AnswerResultStream,收集模型预测的AnswerResult响应结果,最终由下层实现类确认是否结束 + 结束条件包含如下 + 1、模型预测结束、收集器self.q队列收到 self.sentinel标识 + 2、在处理迭代器队列消息时返回了break跳出迭代器,触发了StopIteration事件 + 3、模型预测出错 + 因为当前类是迭代器,所以在for in 中执行了break后 __exit__ 方法会被调用,最终stop_now属性会被更新,然后抛出异常结束预测行为 + 迭代器收集的行为如下 + 创建Iteratorize迭代对象, + 定义generate_with_callback收集器AnswerResultStream + 启动一个线程异步预测结果来调用上游checkpoint的实现方法_generate_answer + _generate_answer通过generate_with_callback定义的收集器,收集上游checkpoint包装的AnswerResult消息体 + 由于self.q是阻塞模式,每次预测后会被消费后才会执行下次预测 + 这时generate_with_callback会被阻塞 + 主线程Iteratorize对象的__next__方法调用获取阻塞消息并消费 + 1、消息为上游checkpoint包装的AnswerResult消息体,返回下游处理 + 2、消息为self.sentinel标识,抛出StopIteration异常 + 主线程Iteratorize对象__exit__收到消息,最终stop_now属性会被更新 + 异步线程检测stop_now属性被更新,抛出异常结束预测行为 + 迭代行为结束 + :param val: + :return: + """ + if self.stop_now: + raise ValueError + self.q.put(val) + + def gen(): + try: + ret = self.mfunc(callback=_callback, **self.kwargs) + except ValueError: + pass + except: + traceback.print_exc() + pass + + self.q.put(self.sentinel) + + self.thread = Thread(target=gen) + self.thread.start() + + def __iter__(self): + return self + + def __next__(self): + obj = self.q.get(True, None) + if obj is self.sentinel: + raise StopIteration + else: + return obj + + def __del__(self): + """ + 暂无实现 + :return: + """ + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ break 后会执行 """ + self.stop_now = True + + class BaseAnswer(ABC): """上层业务包装器.用于结果生成统一api调用""" @@ -25,17 +155,23 @@ class BaseAnswer(ABC): @abstractmethod def _check_point(self) -> LoaderCheckPoint: """Return _check_point of llm.""" + def generatorAnswer(self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None,) -> Generator[Any, str, bool]: + def generate_with_callback(callback=None, **kwargs): + kwargs['generate_with_callback'] = AnswerResultStream(callback_func=callback) + self._generate_answer(**kwargs) - @property - @abstractmethod - def _history_len(self) -> int: - """Return _history_len of llm.""" + def generate_with_streaming(**kwargs): + return Iteratorize(generate_with_callback, kwargs) + + with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator: + for answerResult in generator: + yield answerResult @abstractmethod - def set_history_len(self, history_len: int) -> None: - """Return _history_len of llm.""" - - def generatorAnswer(self, prompt: str, - history: List[List[str]] = [], - streaming: bool = False): + def _generate_answer(self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + generate_with_callback: AnswerResultStream = None) -> None: pass diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py index 7da423dc..0d19ee6d 100644 --- a/models/chatglm_llm.py +++ b/models/chatglm_llm.py @@ -1,83 +1,117 @@ from abc import ABC -from langchain.llms.base import LLM -from typing import Optional, List +from langchain.chains.base import Chain +from typing import Any, Dict, List, Optional, Generator +from langchain.callbacks.manager import CallbackManagerForChainRun +# from transformers.generation.logits_process import LogitsProcessor +# from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList from models.loader import LoaderCheckPoint from models.base import (BaseAnswer, - AnswerResult) + AnswerResult, + AnswerResultStream, + AnswerResultQueueSentinelTokenListenerQueue) +# import torch +import transformers -class ChatGLM(BaseAnswer, LLM, ABC): +class ChatGLMLLMChain(BaseAnswer, Chain, ABC): max_token: int = 10000 temperature: float = 0.01 - top_p = 0.9 + # 相关度 + top_p = 0.4 + # 候选词数量 + top_k = 10 checkPoint: LoaderCheckPoint = None # history = [] history_len: int = 10 + streaming_key: str = "streaming" #: :meta private: + history_key: str = "history" #: :meta private: + prompt_key: str = "prompt" #: :meta private: + output_key: str = "answer_result_stream" #: :meta private: def __init__(self, checkPoint: LoaderCheckPoint = None): super().__init__() self.checkPoint = checkPoint @property - def _llm_type(self) -> str: - return "ChatGLM" + def _chain_type(self) -> str: + return "ChatGLMLLMChain" @property def _check_point(self) -> LoaderCheckPoint: return self.checkPoint @property - def _history_len(self) -> int: - return self.history_len + def input_keys(self) -> List[str]: + """Will be whatever keys the prompt expects. - def set_history_len(self, history_len: int = 10) -> None: - self.history_len = history_len + :meta private: + """ + return [self.prompt_key] - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + @property + def output_keys(self) -> List[str]: + """Will always return text key. + + :meta private: + """ + return [self.output_key] + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Generator]: + generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager) + return {self.output_key: generator} + + def _generate_answer(self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + generate_with_callback: AnswerResultStream = None) -> None: + history = inputs[self.history_key] + streaming = inputs[self.streaming_key] + prompt = inputs[self.prompt_key] print(f"__call:{prompt}") - response, _ = self.checkPoint.model.chat( - self.checkPoint.tokenizer, - prompt, - history=[], - max_length=self.max_token, - temperature=self.temperature - ) - print(f"response:{response}") - print(f"+++++++++++++++++++++++++++++++++++") - return response - - def generatorAnswer(self, prompt: str, - history: List[List[str]] = [], - streaming: bool = False): - + # Create the StoppingCriteriaList with the stopping strings + stopping_criteria_list = transformers.StoppingCriteriaList() + # 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult + listenerQueue = AnswerResultQueueSentinelTokenListenerQueue() + stopping_criteria_list.append(listenerQueue) if streaming: history += [[]] for inum, (stream_resp, _) in enumerate(self.checkPoint.model.stream_chat( self.checkPoint.tokenizer, prompt, - history=history[-self.history_len:-1] if self.history_len > 1 else [], + history=history[-self.history_len:-1] if self.history_len > 0 else [], max_length=self.max_token, - temperature=self.temperature + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + stopping_criteria=stopping_criteria_list )): # self.checkPoint.clear_torch_cache() history[-1] = [prompt, stream_resp] answer_result = AnswerResult() answer_result.history = history answer_result.llm_output = {"answer": stream_resp} - yield answer_result + generate_with_callback(answer_result) + self.checkPoint.clear_torch_cache() else: response, _ = self.checkPoint.model.chat( self.checkPoint.tokenizer, prompt, history=history[-self.history_len:] if self.history_len > 0 else [], max_length=self.max_token, - temperature=self.temperature + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + stopping_criteria=stopping_criteria_list ) self.checkPoint.clear_torch_cache() history += [[prompt, response]] answer_result = AnswerResult() answer_result.history = history answer_result.llm_output = {"answer": response} - yield answer_result + generate_with_callback(answer_result) diff --git a/models/fastchat_openai_llm.py b/models/fastchat_openai_llm.py index df66add6..217910ac 100644 --- a/models/fastchat_openai_llm.py +++ b/models/fastchat_openai_llm.py @@ -1,15 +1,37 @@ from abc import ABC -import requests -from typing import Optional, List -from langchain.llms.base import LLM +from langchain.chains.base import Chain +from typing import ( + Any, Dict, List, Optional, Generator, Collection, Set, + Callable, + Tuple, + Union) from models.loader import LoaderCheckPoint -from models.base import (RemoteRpcModel, - AnswerResult) -from typing import ( - Collection, - Dict +from langchain.callbacks.manager import CallbackManagerForChainRun +from models.base import (BaseAnswer, + RemoteRpcModel, + AnswerResult, + AnswerResultStream, + AnswerResultQueueSentinelTokenListenerQueue) +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, ) +from pydantic import Extra, Field, root_validator + +from openai import ( + ChatCompletion +) + +import openai +import logging +import torch +import transformers + +logger = logging.getLogger(__name__) def _build_message_template() -> Dict[str, str]: @@ -22,34 +44,88 @@ def _build_message_template() -> Dict[str, str]: } -class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC): +# 将历史对话数组转换为文本格式 +def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]: + build_messages: Collection[Dict[str, str]] = [] + + system_build_message = _build_message_template() + system_build_message['role'] = 'system' + system_build_message['content'] = "You are a helpful assistant." + build_messages.append(system_build_message) + if history: + for i, (user, assistant) in enumerate(history): + if user: + + user_build_message = _build_message_template() + user_build_message['role'] = 'user' + user_build_message['content'] = user + build_messages.append(user_build_message) + + if not assistant: + raise RuntimeError("历史数据结构不正确") + system_build_message = _build_message_template() + system_build_message['role'] = 'assistant' + system_build_message['content'] = assistant + build_messages.append(system_build_message) + + user_build_message = _build_message_template() + user_build_message['role'] = 'user' + user_build_message['content'] = query + build_messages.append(user_build_message) + return build_messages + + +class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC): + client: Any + """Timeout for requests to OpenAI completion API. Default is 600 seconds.""" + max_retries: int = 6 api_base_url: str = "http://localhost:8000/v1" model_name: str = "chatglm-6b" max_token: int = 10000 temperature: float = 0.01 top_p = 0.9 checkPoint: LoaderCheckPoint = None - history = [] + # history = [] history_len: int = 10 + api_key: str = "" - def __init__(self, checkPoint: LoaderCheckPoint = None): + streaming_key: str = "streaming" #: :meta private: + history_key: str = "history" #: :meta private: + prompt_key: str = "prompt" #: :meta private: + output_key: str = "answer_result_stream" #: :meta private: + + def __init__(self, + checkPoint: LoaderCheckPoint = None, + # api_base_url:str="http://localhost:8000/v1", + # model_name:str="chatglm-6b", + # api_key:str="" + ): super().__init__() self.checkPoint = checkPoint @property - def _llm_type(self) -> str: - return "FastChat" + def _chain_type(self) -> str: + return "LLamaLLMChain" @property def _check_point(self) -> LoaderCheckPoint: return self.checkPoint @property - def _history_len(self) -> int: - return self.history_len + def input_keys(self) -> List[str]: + """Will be whatever keys the prompt expects. - def set_history_len(self, history_len: int = 10) -> None: - self.history_len = history_len + :meta private: + """ + return [self.prompt_key] + + @property + def output_keys(self) -> List[str]: + """Will always return text key. + + :meta private: + """ + return [self.output_key] @property def _api_key(self) -> str: @@ -60,7 +136,7 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC): return self.api_base_url def set_api_key(self, api_key: str): - pass + self.api_key = api_key def set_api_base_url(self, api_base_url: str): self.api_base_url = api_base_url @@ -68,70 +144,116 @@ class FastChatOpenAILLM(RemoteRpcModel, LLM, ABC): def call_model_name(self, model_name): self.model_name = model_name - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _create_retry_decorator(self) -> Callable[[Any], Any]: + min_seconds = 1 + max_seconds = 60 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(openai.error.Timeout) + | retry_if_exception_type(openai.error.APIError) + | retry_if_exception_type(openai.error.APIConnectionError) + | retry_if_exception_type(openai.error.RateLimitError) + | retry_if_exception_type(openai.error.ServiceUnavailableError) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + def completion_with_retry(self, **kwargs: Any) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = self._create_retry_decorator() + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + return self.client.create(**kwargs) + + return _completion_with_retry(**kwargs) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Generator]: + generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager) + return {self.output_key: generator} + + def _generate_answer(self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + generate_with_callback: AnswerResultStream = None) -> None: + + history = inputs.get(self.history_key, []) + streaming = inputs.get(self.streaming_key, False) + prompt = inputs[self.prompt_key] + stop = inputs.get("stop", "stop") print(f"__call:{prompt}") try: - import openai + # Not support yet - openai.api_key = "EMPTY" + # openai.api_key = "EMPTY" + openai.api_key = self.api_key openai.api_base = self.api_base_url - except ImportError: + self.client = openai.ChatCompletion + except AttributeError: raise ValueError( - "Could not import openai python package. " - "Please install it with `pip install openai`." + "`openai` has no `ChatCompletion` attribute, this is likely " + "due to an old version of the openai package. Try upgrading it " + "with `pip install --upgrade openai`." ) - # create a chat completion - completion = openai.ChatCompletion.create( - model=self.model_name, - messages=self.build_message_list(prompt) - ) - print(f"response:{completion.choices[0].message.content}") - print(f"+++++++++++++++++++++++++++++++++++") - return completion.choices[0].message.content + msg = build_message_list(prompt, history=history) - # 将历史对话数组转换为文本格式 - def build_message_list(self, query) -> Collection[Dict[str, str]]: - build_message_list: Collection[Dict[str, str]] = [] - history = self.history[-self.history_len:] if self.history_len > 0 else [] - for i, (old_query, response) in enumerate(history): - user_build_message = _build_message_template() - user_build_message['role'] = 'user' - user_build_message['content'] = old_query - system_build_message = _build_message_template() - system_build_message['role'] = 'system' - system_build_message['content'] = response - build_message_list.append(user_build_message) - build_message_list.append(system_build_message) + if streaming: + params = {"stream": streaming, + "model": self.model_name, + "stop": stop} + out_str = "" + for stream_resp in self.completion_with_retry( + messages=msg, + **params + ): + role = stream_resp["choices"][0]["delta"].get("role", "") + token = stream_resp["choices"][0]["delta"].get("content", "") + out_str += token + history[-1] = [prompt, out_str] + answer_result = AnswerResult() + answer_result.history = history + answer_result.llm_output = {"answer": out_str} + generate_with_callback(answer_result) + else: - user_build_message = _build_message_template() - user_build_message['role'] = 'user' - user_build_message['content'] = query - build_message_list.append(user_build_message) - return build_message_list - - def generatorAnswer(self, prompt: str, - history: List[List[str]] = [], - streaming: bool = False): - - try: - import openai - # Not support yet - openai.api_key = "EMPTY" - openai.api_base = self.api_base_url - except ImportError: - raise ValueError( - "Could not import openai python package. " - "Please install it with `pip install openai`." + params = {"stream": streaming, + "model": self.model_name, + "stop": stop} + response = self.completion_with_retry( + messages=msg, + **params ) - # create a chat completion - completion = openai.ChatCompletion.create( - model=self.model_name, - messages=self.build_message_list(prompt) - ) + role = response["choices"][0]["message"].get("role", "") + content = response["choices"][0]["message"].get("content", "") + history += [[prompt, content]] + answer_result = AnswerResult() + answer_result.history = history + answer_result.llm_output = {"answer": content} + generate_with_callback(answer_result) - history += [[prompt, completion.choices[0].message.content]] - answer_result = AnswerResult() - answer_result.history = history - answer_result.llm_output = {"answer": completion.choices[0].message.content} - yield answer_result +if __name__ == "__main__": + + chain = FastChatOpenAILLMChain() + + chain.set_api_key("EMPTY") + # chain.set_api_base_url("https://api.openai.com/v1") + # chain.call_model_name("gpt-3.5-turbo") + + answer_result_stream_result = chain({"streaming": True, + "prompt": "你好", + "history": [] + }) + + for answer_result in answer_result_stream_result['answer_result_stream']: + resp = answer_result.llm_output["answer"] + print(resp) diff --git a/models/llama_llm.py b/models/llama_llm.py index 69fde56b..014fd81d 100644 --- a/models/llama_llm.py +++ b/models/llama_llm.py @@ -1,26 +1,32 @@ -from abc import ABC -from langchain.llms.base import LLM -import random -import torch -import transformers +from abc import ABC +from langchain.chains.base import Chain +from typing import Any, Dict, List, Optional, Generator, Union +from langchain.callbacks.manager import CallbackManagerForChainRun from transformers.generation.logits_process import LogitsProcessor from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList -from typing import Optional, List, Dict, Any from models.loader import LoaderCheckPoint from models.base import (BaseAnswer, - AnswerResult) + AnswerResult, + AnswerResultStream, + AnswerResultQueueSentinelTokenListenerQueue) +import torch +import transformers class InvalidScoreLogitsProcessor(LogitsProcessor): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: Union[torch.LongTensor, list], + scores: Union[torch.FloatTensor, list]) -> torch.FloatTensor: + # llama-cpp模型返回的是list,为兼容性考虑,需要判断input_ids和scores的类型,将list转换为torch.Tensor + input_ids = torch.tensor(input_ids) if isinstance(input_ids, list) else input_ids + scores = torch.tensor(scores) if isinstance(scores, list) else scores if torch.isnan(scores).any() or torch.isinf(scores).any(): scores.zero_() scores[..., 5] = 5e4 return scores -class LLamaLLM(BaseAnswer, LLM, ABC): +class LLamaLLMChain(BaseAnswer, Chain, ABC): checkPoint: LoaderCheckPoint = None # history = [] history_len: int = 3 @@ -34,32 +40,34 @@ class LLamaLLM(BaseAnswer, LLM, ABC): min_length: int = 0 logits_processor: LogitsProcessorList = None stopping_criteria: Optional[StoppingCriteriaList] = None - eos_token_id: Optional[int] = [2] - - state: object = {'max_new_tokens': 50, - 'seed': 1, - 'temperature': 0, 'top_p': 0.1, - 'top_k': 40, 'typical_p': 1, - 'repetition_penalty': 1.2, - 'encoder_repetition_penalty': 1, - 'no_repeat_ngram_size': 0, - 'min_length': 0, - 'penalty_alpha': 0, - 'num_beams': 1, - 'length_penalty': 1, - 'early_stopping': False, 'add_bos_token': True, 'ban_eos_token': False, - 'truncation_length': 2048, 'custom_stopping_strings': '', - 'cpu_memory': 0, 'auto_devices': False, 'disk': False, 'cpu': False, 'bf16': False, - 'load_in_8bit': False, 'wbits': 'None', 'groupsize': 'None', 'model_type': 'None', - 'pre_layer': 0, 'gpu_memory_0': 0} + streaming_key: str = "streaming" #: :meta private: + history_key: str = "history" #: :meta private: + prompt_key: str = "prompt" #: :meta private: + output_key: str = "answer_result_stream" #: :meta private: def __init__(self, checkPoint: LoaderCheckPoint = None): super().__init__() self.checkPoint = checkPoint @property - def _llm_type(self) -> str: - return "LLamaLLM" + def _chain_type(self) -> str: + return "LLamaLLMChain" + + @property + def input_keys(self) -> List[str]: + """Will be whatever keys the prompt expects. + + :meta private: + """ + return [self.prompt_key] + + @property + def output_keys(self) -> List[str]: + """Will always return text key. + + :meta private: + """ + return [self.output_key] @property def _check_point(self) -> LoaderCheckPoint: @@ -104,35 +112,31 @@ class LLamaLLM(BaseAnswer, LLM, ABC): formatted_history += "### Human:{}\n### Assistant:".format(query) return formatted_history - def prepare_inputs_for_generation(self, - input_ids: torch.LongTensor): - """ - 预生成注意力掩码和 输入序列中每个位置的索引的张量 - # TODO 没有思路 - :return: - """ + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Generator]: + generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager) + return {self.output_key: generator} - mask_positions = torch.zeros((1, input_ids.shape[1]), dtype=input_ids.dtype).to(self.checkPoint.model.device) + def _generate_answer(self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + generate_with_callback: AnswerResultStream = None) -> None: - attention_mask = self.get_masks(input_ids, input_ids.device) - - position_ids = self.get_position_ids( - input_ids, - device=input_ids.device, - mask_positions=mask_positions - ) - - return input_ids, position_ids, attention_mask - - @property - def _history_len(self) -> int: - return self.history_len - - def set_history_len(self, history_len: int = 10) -> None: - self.history_len = history_len - - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + history = inputs[self.history_key] + streaming = inputs[self.streaming_key] + prompt = inputs[self.prompt_key] print(f"__call:{prompt}") + + # Create the StoppingCriteriaList with the stopping strings + self.stopping_criteria = transformers.StoppingCriteriaList() + # 定义模型stopping_criteria 队列,在每次响应时将 torch.LongTensor, torch.FloatTensor同步到AnswerResult + listenerQueue = AnswerResultQueueSentinelTokenListenerQueue() + self.stopping_criteria.append(listenerQueue) + # TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现 + soft_prompt = self.history_to_text(query=prompt, history=history) if self.logits_processor is None: self.logits_processor = LogitsProcessorList() self.logits_processor.append(InvalidScoreLogitsProcessor()) @@ -151,35 +155,36 @@ class LLamaLLM(BaseAnswer, LLM, ABC): "logits_processor": self.logits_processor} # 向量转换 - input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'], truncation_length=self.max_new_tokens) - # input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids) - + input_ids = self.encode(soft_prompt, add_bos_token=self.checkPoint.tokenizer.add_bos_token, + truncation_length=self.max_new_tokens) gen_kwargs.update({'inputs': input_ids}) - # 注意力掩码 - # gen_kwargs.update({'attention_mask': attention_mask}) - # gen_kwargs.update({'position_ids': position_ids}) - if self.stopping_criteria is None: - self.stopping_criteria = transformers.StoppingCriteriaList() # 观测输出 gen_kwargs.update({'stopping_criteria': self.stopping_criteria}) + # llama-cpp模型的参数与transformers的参数字段有较大差异,直接调用会返回不支持的字段错误 + # 因此需要先判断模型是否是llama-cpp模型,然后取gen_kwargs与模型generate方法字段的交集 + # 仅将交集字段传给模型以保证兼容性 + # todo llama-cpp模型在本框架下兼容性较差,后续可以考虑重写一个llama_cpp_llm.py模块 + if "llama_cpp" in self.checkPoint.model.__str__(): + import inspect - output_ids = self.checkPoint.model.generate(**gen_kwargs) + common_kwargs_keys = set(inspect.getfullargspec(self.checkPoint.model.generate).args) & set( + gen_kwargs.keys()) + common_kwargs = {key: gen_kwargs[key] for key in common_kwargs_keys} + # ? llama-cpp模型的generate方法似乎只接受.cpu类型的输入,响应很慢,慢到哭泣 + # ?为什么会不支持GPU呢,不应该啊? + output_ids = torch.tensor( + [list(self.checkPoint.model.generate(input_id_i.cpu(), **common_kwargs)) for input_id_i in input_ids]) + + else: + output_ids = self.checkPoint.model.generate(**gen_kwargs) new_tokens = len(output_ids[0]) - len(input_ids[0]) reply = self.decode(output_ids[0][-new_tokens:]) print(f"response:{reply}") print(f"+++++++++++++++++++++++++++++++++++") - return reply - - def generatorAnswer(self, prompt: str, - history: List[List[str]] = [], - streaming: bool = False): - - # TODO 需要实现chat对话模块和注意力模型,目前_call为langchain的LLM拓展的api,默认为无提示词模式,如果需要操作注意力模型,可以参考chat_glm的实现 - softprompt = self.history_to_text(prompt,history=history) - response = self._call(prompt=softprompt, stop=['\n###']) answer_result = AnswerResult() - answer_result.history = history + [[prompt, response]] - answer_result.llm_output = {"answer": response} - yield answer_result + history += [[prompt, reply]] + answer_result.history = history + answer_result.llm_output = {"answer": reply} + generate_with_callback(answer_result) diff --git a/models/loader/args.py b/models/loader/args.py index b15ad5e4..cd3e78b8 100644 --- a/models/loader/args.py +++ b/models/loader/args.py @@ -1,3 +1,4 @@ + import argparse import os from configs.model_config import * @@ -43,7 +44,8 @@ parser.add_argument('--no-remote-model', action='store_true', help='remote in th parser.add_argument('--model-name', type=str, default=LLM_MODEL, help='Name of the model to load by default.') parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.') parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras") - +parser.add_argument('--use-ptuning-v2',action='store_true',help="whether use ptuning-v2 checkpoint") +parser.add_argument("--ptuning-dir",type=str,default=PTUNING_DIR,help="the dir of ptuning-v2 checkpoint") # Accelerate/transformers parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT, help='Load the model with 8-bit precision.') diff --git a/models/loader/loader.py b/models/loader/loader.py index f315e6ce..f43bb1d2 100644 --- a/models/loader/loader.py +++ b/models/loader/loader.py @@ -20,6 +20,7 @@ class LoaderCheckPoint: no_remote_model: bool = False # 模型名称 model_name: str = None + pretrained_model_name: str = None tokenizer: object = None # 模型全路径 model_path: str = None @@ -35,11 +36,11 @@ class LoaderCheckPoint: # 因此主要的解决思路是清理环境变量里PATH下的不匹配的cuda版本,一劳永逸的方法是: # 0. 在终端执行`pip uninstall bitsandbytes` # 1. 删除.bashrc文件下关于PATH的条目 - # 2. 在终端执行 `echo $PATH >> .bashrc` + # 2. 在终端执行 `echo $PATH >> .bashrc` # 3. 删除.bashrc文件下PATH中关于不匹配的cuda版本路径 # 4. 在终端执行`source .bashrc` # 5. 再执行`pip install bitsandbytes` - + load_in_8bit: bool = False is_llamacpp: bool = False bf16: bool = False @@ -67,43 +68,49 @@ class LoaderCheckPoint: self.load_in_8bit = params.get('load_in_8bit', False) self.bf16 = params.get('bf16', False) - def _load_model_config(self, model_name): + def _load_model_config(self): if self.model_path: + self.model_path = re.sub("\s", "", self.model_path) checkpoint = Path(f'{self.model_path}') else: - if not self.no_remote_model: - checkpoint = model_name - else: + if self.no_remote_model: raise ValueError( "本地模型local_model_path未配置路径" ) + else: + checkpoint = self.pretrained_model_name - model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True) + print(f"load_model_config {checkpoint}...") + try: - return model_config + model_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True) + return model_config + except Exception as e: + print(e) + return checkpoint - def _load_model(self, model_name): + def _load_model(self): """ 加载自定义位置的model - :param model_name: :return: """ - print(f"Loading {model_name}...") t0 = time.time() if self.model_path: + self.model_path = re.sub("\s", "", self.model_path) checkpoint = Path(f'{self.model_path}') else: - if not self.no_remote_model: - checkpoint = model_name - else: + if self.no_remote_model: raise ValueError( "本地模型local_model_path未配置路径" ) + else: + checkpoint = self.pretrained_model_name + print(f"Loading {checkpoint}...") self.is_llamacpp = len(list(Path(f'{checkpoint}').glob('ggml*.bin'))) > 0 - if 'chatglm' in model_name.lower(): + if 'chatglm' in self.model_name.lower() or "chatyuan" in self.model_name.lower(): LoaderClass = AutoModel else: LoaderClass = AutoModelForCausalLM @@ -126,8 +133,14 @@ class LoaderCheckPoint: .half() .cuda() ) + # 支持自定义cuda设备 + elif ":" in self.llm_device: + model = LoaderClass.from_pretrained(checkpoint, + config=self.model_config, + torch_dtype=torch.bfloat16 if self.bf16 else torch.float16, + trust_remote_code=True).half().to(self.llm_device) else: - from accelerate import dispatch_model + from accelerate import dispatch_model, infer_auto_device_map model = LoaderClass.from_pretrained(checkpoint, config=self.model_config, @@ -135,12 +148,22 @@ class LoaderCheckPoint: trust_remote_code=True).half() # 可传入device_map自定义每张卡的部署情况 if self.device_map is None: - if 'chatglm' in model_name.lower(): + if 'chatglm' in self.model_name.lower() and not "chatglm2" in self.model_name.lower(): self.device_map = self.chatglm_auto_configure_device_map(num_gpus) - elif 'moss' in model_name.lower(): - self.device_map = self.moss_auto_configure_device_map(num_gpus, model_name) + elif 'moss' in self.model_name.lower(): + self.device_map = self.moss_auto_configure_device_map(num_gpus, checkpoint) else: - self.device_map = self.chatglm_auto_configure_device_map(num_gpus) + # 基于如下方式作为默认的多卡加载方案针对新模型基本不会失败 + # 在chatglm2-6b,bloom-3b,blooz-7b1上进行了测试,GPU负载也相对均衡 + from accelerate.utils import get_balanced_memory + max_memory = get_balanced_memory(model, + dtype=torch.int8 if self.load_in_8bit else None, + low_zero=False, + no_split_module_classes=model._no_split_modules) + self.device_map = infer_auto_device_map(model, + dtype=torch.float16 if not self.load_in_8bit else torch.int8, + max_memory=max_memory, + no_split_module_classes=model._no_split_modules) model = dispatch_model(model, device_map=self.device_map) else: @@ -156,7 +179,7 @@ class LoaderCheckPoint: elif self.is_llamacpp: try: - from models.extensions.llamacpp_model_alternative import LlamaCppModel + from llama_cpp import Llama except ImportError as exc: raise ValueError( @@ -167,7 +190,16 @@ class LoaderCheckPoint: model_file = list(checkpoint.glob('ggml*.bin'))[0] print(f"llama.cpp weights detected: {model_file}\n") - model, tokenizer = LlamaCppModel.from_pretrained(model_file) + model = Llama(model_path=model_file._str) + + # 实测llama-cpp-vicuna13b-q5_1的AutoTokenizer加载tokenizer的速度极慢,应存在优化空间 + # 但需要对huggingface的AutoTokenizer进行优化 + + # tokenizer = model.tokenizer + # todo 此处调用AutoTokenizer的tokenizer,但后续可以测试自带tokenizer是不是兼容 + # * -> 自带的tokenizer不与transoformers的tokenizer兼容,无法使用 + + tokenizer = AutoTokenizer.from_pretrained(self.model_name) return model, tokenizer elif self.load_in_8bit: @@ -194,7 +226,7 @@ class LoaderCheckPoint: llm_int8_enable_fp32_cpu_offload=False) with init_empty_weights(): - model = LoaderClass.from_config(self.model_config,trust_remote_code = True) + model = LoaderClass.from_config(self.model_config, trust_remote_code=True) model.tie_weights() if self.device_map is not None: params['device_map'] = self.device_map @@ -257,10 +289,21 @@ class LoaderCheckPoint: # 在调用chat或者stream_chat时,input_ids会被放到model.device上 # 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError # 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上 - device_map = {f'{layer_prefix}.word_embeddings': 0, - f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0, - f'base_model.model.lm_head': 0, } + encode = "" + if 'chatglm2' in self.model_name: + device_map = { + f"{layer_prefix}.embedding.word_embeddings": 0, + f"{layer_prefix}.rotary_pos_emb": 0, + f"{layer_prefix}.output_layer": 0, + f"{layer_prefix}.encoder.final_layernorm": 0, + f"base_model.model.output_layer": 0 + } + encode = ".encoder" + else: + device_map = {f'{layer_prefix}.word_embeddings': 0, + f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0, + f'base_model.model.lm_head': 0, } used = 2 gpu_target = 0 for i in range(num_trans_layers): @@ -268,12 +311,12 @@ class LoaderCheckPoint: gpu_target += 1 used = 0 assert gpu_target < num_gpus - device_map[f'{layer_prefix}.layers.{i}'] = gpu_target + device_map[f'{layer_prefix}{encode}.layers.{i}'] = gpu_target used += 1 return device_map - def moss_auto_configure_device_map(self, num_gpus: int, model_name) -> Dict[str, int]: + def moss_auto_configure_device_map(self, num_gpus: int, checkpoint) -> Dict[str, int]: try: from accelerate import init_empty_weights @@ -288,16 +331,6 @@ class LoaderCheckPoint: "`pip install bitsandbytes``pip install accelerate`." ) from exc - if self.model_path: - checkpoint = Path(f'{self.model_path}') - else: - if not self.no_remote_model: - checkpoint = model_name - else: - raise ValueError( - "本地模型local_model_path未配置路径" - ) - cls = get_class_from_dynamic_module(class_reference="fnlp/moss-moon-003-sft--modeling_moss.MossForCausalLM", pretrained_model_name_or_path=checkpoint) @@ -385,7 +418,7 @@ class LoaderCheckPoint: print( "如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。") elif torch.has_cuda: - device_id = "0" if torch.cuda.is_available() else None + device_id = "0" if torch.cuda.is_available() and (":" not in self.llm_device) else None CUDA_DEVICE = f"{self.llm_device}:{device_id}" if device_id else self.llm_device with torch.cuda.device(CUDA_DEVICE): torch.cuda.empty_cache() @@ -404,33 +437,37 @@ class LoaderCheckPoint: def reload_model(self): self.unload_model() - self.model_config = self._load_model_config(self.model_name) + self.model_config = self._load_model_config() if self.use_ptuning_v2: try: - prefix_encoder_file = open(Path(f'{self.ptuning_dir}/config.json'), 'r') + prefix_encoder_file = open(Path(f'{os.path.abspath(self.ptuning_dir)}/config.json'), 'r') prefix_encoder_config = json.loads(prefix_encoder_file.read()) prefix_encoder_file.close() self.model_config.pre_seq_len = prefix_encoder_config['pre_seq_len'] self.model_config.prefix_projection = prefix_encoder_config['prefix_projection'] except Exception as e: + print(e) print("加载PrefixEncoder config.json失败") - self.model, self.tokenizer = self._load_model(self.model_name) + self.model, self.tokenizer = self._load_model() if self.lora: self._add_lora_to_model([self.lora]) if self.use_ptuning_v2: try: - prefix_state_dict = torch.load(Path(f'{self.ptuning_dir}/pytorch_model.bin')) + prefix_state_dict = torch.load(Path(f'{os.path.abspath(self.ptuning_dir)}/pytorch_model.bin')) new_prefix_state_dict = {} for k, v in prefix_state_dict.items(): if k.startswith("transformer.prefix_encoder."): new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) self.model.transformer.prefix_encoder.float() + print("加载ptuning检查点成功!") except Exception as e: + print(e) print("加载PrefixEncoder模型参数失败") - - self.model = self.model.eval() + # llama-cpp模型(至少vicuna-13b)的eval方法就是自身,其没有eval方法 + if not self.is_llamacpp: + self.model = self.model.eval() diff --git a/models/moss_llm.py b/models/moss_llm.py index 80a86877..f6b112d9 100644 --- a/models/moss_llm.py +++ b/models/moss_llm.py @@ -1,12 +1,20 @@ from abc import ABC -from langchain.llms.base import LLM -from typing import Optional, List +from langchain.chains.base import Chain +from typing import Any, Dict, List, Optional, Generator, Union +from langchain.callbacks.manager import CallbackManagerForChainRun +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList from models.loader import LoaderCheckPoint from models.base import (BaseAnswer, - AnswerResult) + AnswerResult, + AnswerResultStream, + AnswerResultQueueSentinelTokenListenerQueue) +import torch +import transformers import torch +# todo 建议重写instruction,在该instruction下,各模型的表现比较差 META_INSTRUCTION = \ """You are an AI assistant whose name is MOSS. - MOSS is a conversational language model that is developed by Fudan University. It is designed to be helpful, honest, and harmless. @@ -21,49 +29,76 @@ META_INSTRUCTION = \ """ -class MOSSLLM(BaseAnswer, LLM, ABC): +# todo 在MOSSLLM类下,各模型的响应速度很慢,后续要检查一下原因 +class MOSSLLMChain(BaseAnswer, Chain, ABC): max_token: int = 2048 temperature: float = 0.7 top_p = 0.8 # history = [] checkPoint: LoaderCheckPoint = None history_len: int = 10 + streaming_key: str = "streaming" #: :meta private: + history_key: str = "history" #: :meta private: + prompt_key: str = "prompt" #: :meta private: + output_key: str = "answer_result_stream" #: :meta private: def __init__(self, checkPoint: LoaderCheckPoint = None): super().__init__() self.checkPoint = checkPoint @property - def _llm_type(self) -> str: - return "MOSS" + def _chain_type(self) -> str: + return "MOSSLLMChain" + + @property + def input_keys(self) -> List[str]: + """Will be whatever keys the prompt expects. + + :meta private: + """ + return [self.prompt_key] + + @property + def output_keys(self) -> List[str]: + """Will always return text key. + + :meta private: + """ + return [self.output_key] @property def _check_point(self) -> LoaderCheckPoint: return self.checkPoint - @property - def set_history_len(self) -> int: - return self.history_len + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Generator]: + generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager) + return {self.output_key: generator} - def _set_history_len(self, history_len: int) -> None: - self.history_len = history_len + def _generate_answer(self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + generate_with_callback: AnswerResultStream = None) -> None: - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: - pass - - def generatorAnswer(self, prompt: str, - history: List[List[str]] = [], - streaming: bool = False): + history = inputs[self.history_key] + streaming = inputs[self.streaming_key] + prompt = inputs[self.prompt_key] + print(f"__call:{prompt}") if len(history) > 0: history = history[-self.history_len:] if self.history_len > 0 else [] prompt_w_history = str(history) prompt_w_history += '<|Human|>: ' + prompt + '' else: - prompt_w_history = META_INSTRUCTION + prompt_w_history = META_INSTRUCTION.replace("MOSS", self.checkPoint.model_name.split("/")[-1]) prompt_w_history += '<|Human|>: ' + prompt + '' inputs = self.checkPoint.tokenizer(prompt_w_history, return_tensors="pt") with torch.no_grad(): + # max_length似乎可以设的小一些,而repetion_penalty应大一些,否则chatyuan,bloom等模型为满足max会重复输出 + # outputs = self.checkPoint.model.generate( inputs.input_ids.cuda(), attention_mask=inputs.attention_mask.cuda(), @@ -76,13 +111,12 @@ class MOSSLLM(BaseAnswer, LLM, ABC): num_return_sequences=1, eos_token_id=106068, pad_token_id=self.checkPoint.tokenizer.pad_token_id) - response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) + response = self.checkPoint.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], + skip_special_tokens=True) self.checkPoint.clear_torch_cache() history += [[prompt, response]] answer_result = AnswerResult() answer_result.history = history answer_result.llm_output = {"answer": response} - yield answer_result - - + generate_with_callback(answer_result) diff --git a/models/shared.py b/models/shared.py index 8a76edb5..3ccf2502 100644 --- a/models/shared.py +++ b/models/shared.py @@ -24,13 +24,12 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_ if use_ptuning_v2: loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2 + # 如果指定了参数,则使用参数的配置 if llm_model: llm_model_info = llm_model_dict[llm_model] - if loaderCheckPoint.no_remote_model: - loaderCheckPoint.model_name = llm_model_info['name'] - else: - loaderCheckPoint.model_name = llm_model_info['pretrained_model_name'] + loaderCheckPoint.model_name = llm_model_info['name'] + loaderCheckPoint.pretrained_model_name = llm_model_info['pretrained_model_name'] loaderCheckPoint.model_path = llm_model_info["local_model_path"] @@ -44,4 +43,5 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_ if 'FastChatOpenAILLM' in llm_model_info["provides"]: modelInsLLM.set_api_base_url(llm_model_info['api_base_url']) modelInsLLM.call_model_name(llm_model_info['name']) + modelInsLLM.set_api_key(llm_model_info['api_key']) return modelInsLLM diff --git a/requirements.txt b/requirements.txt index 9f962dd5..7f97f67f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ beautifulsoup4 icetk cpm_kernels faiss-cpu -gradio==3.28.3 +gradio==3.37.0 fastapi~=0.95.0 uvicorn~=0.21.1 pypinyin~=0.48.0 @@ -23,9 +23,13 @@ openai #accelerate~=0.18.0 #peft~=0.3.0 #bitsandbytes; platform_system != "Windows" -#llama-cpp-python==0.1.34; platform_system != "Windows" -#https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows" +# 要调用llama-cpp模型,如vicuma-13b量化模型需要安装llama-cpp-python库 +# but!!! 实测pip install 不好使,需要手动从ttps://github.com/abetlen/llama-cpp-python/releases/下载 +# 而且注意不同时期的ggml格式并不!兼!容!!!因此需要安装的llama-cpp-python版本也不一致,需要手动测试才能确定 +# 实测ggml-vicuna-13b-1.1在llama-cpp-python 0.1.63上可正常兼容 +# 不过!!!本项目模型加载的方式控制的比较严格,与llama-cpp-python的兼容性较差,很多参数设定不能使用, +# 建议如非必要还是不要使用llama-cpp torch~=2.0.0 pydantic~=1.10.7 starlette~=0.26.1 diff --git a/test/models/test_fastchat_openai_llm.py b/test/models/test_fastchat_openai_llm.py deleted file mode 100644 index a0312be3..00000000 --- a/test/models/test_fastchat_openai_llm.py +++ /dev/null @@ -1,39 +0,0 @@ -import sys -import os - -sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../../') -import asyncio -from argparse import Namespace -from models.loader.args import parser -from models.loader import LoaderCheckPoint - - -import models.shared as shared - - - -async def dispatch(args: Namespace): - args_dict = vars(args) - - shared.loaderCheckPoint = LoaderCheckPoint(args_dict) - - llm_model_ins = shared.loaderLLM() - - history = [ - ("which city is this?", "tokyo"), - ("why?", "she's japanese"), - - ] - for answer_result in llm_model_ins.generatorAnswer(prompt="你好? ", history=history, - streaming=False): - resp = answer_result.llm_output["answer"] - - print(resp) - -if __name__ == '__main__': - args = None - args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'fastchat-chatglm-6b', '--no-remote-model']) - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(dispatch(args)) diff --git a/views/src/api/chat.ts b/views/src/api/chat.ts index 7be86c10..4d834ed3 100644 --- a/views/src/api/chat.ts +++ b/views/src/api/chat.ts @@ -16,6 +16,24 @@ export const chatfile = (params: any) => { }) } +export const getKbsList = () => { + return api({ + url: '/local_doc_qa/list_knowledge_base', + method: 'get', + + }) +} + +export const deleteKb = (knowledge_base_id: any) => { + return api({ + url: '/local_doc_qa/delete_knowledge_base', + method: 'delete', + params: { + knowledge_base_id, + }, + }) +} + export const getfilelist = (knowledge_base_id: any) => { return api({ url: '/local_doc_qa/list_files', @@ -35,8 +53,8 @@ export const bing_search = (params: any) => { export const deletefile = (params: any) => { return api({ url: '/local_doc_qa/delete_file', - method: 'post', - data: JSON.stringify(params), + method: 'delete', + params, }) } export const web_url = () => { diff --git a/views/src/views/chat/index.vue b/views/src/views/chat/index.vue index b9eac9b9..814005cc 100644 --- a/views/src/views/chat/index.vue +++ b/views/src/views/chat/index.vue @@ -555,7 +555,7 @@ const options = computed(() => { return common }) -function handleSelect(key: 'copyText' | 'delete' | 'toggleRenderType') { +function handleSelect(key: string) { if (key == '清除会话') { handleClear() } @@ -658,7 +658,6 @@ function searchfun() { diff --git a/views/src/views/chat/layout/sider/knowledge-base/index.vue b/views/src/views/chat/layout/sider/knowledge-base/index.vue index 180efd41..64387676 100644 --- a/views/src/views/chat/layout/sider/knowledge-base/index.vue +++ b/views/src/views/chat/layout/sider/knowledge-base/index.vue @@ -3,15 +3,16 @@ import { NButton, NForm, NFormItem, NInput, NPopconfirm } from 'naive-ui' import { onMounted, ref } from 'vue' import filelist from './filelist.vue' import { SvgIcon } from '@/components/common' -import { deletekb, getkblist } from '@/api/chat' +import { deleteKb, getKbsList } from '@/api/chat' import { idStore } from '@/store/modules/knowledgebaseid/id' + const items = ref([]) const choice = ref('') const store = idStore() onMounted(async () => { choice.value = store.knowledgeid - const res = await getkblist({}) + const res = await getKbsList() res.data.data.forEach((item: any) => { items.value.push({ value: item, @@ -52,8 +53,8 @@ const handleClick = () => { } } async function handleDelete(item: any) { - await deletekb(item.value) - const res = await getkblist({}) + await deleteKb(item.value) + const res = await getKbsList() items.value = [] res.data.data.forEach((item: any) => { items.value.push({ diff --git a/webui.py b/webui.py index 0a96e4c5..c7e78803 100644 --- a/webui.py +++ b/webui.py @@ -85,8 +85,11 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR yield history + [[query, "请选择知识库后进行测试,当前未选择知识库。"]], "" else: - for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history, - streaming=streaming): + + answer_result_stream_result = local_doc_qa.llm_model_chain( + {"prompt": query, "history": history, "streaming": streaming}) + + for answer_result in answer_result_stream_result['answer_result_stream']: resp = answer_result.llm_output["answer"] history = answer_result.history history[-1][-1] = resp @@ -101,11 +104,13 @@ def init_model(): args_dict = vars(args) shared.loaderCheckPoint = LoaderCheckPoint(args_dict) llm_model_ins = shared.loaderLLM() - llm_model_ins.set_history_len(LLM_HISTORY_LEN) + llm_model_ins.history_len = LLM_HISTORY_LEN try: local_doc_qa.init_cfg(llm_model=llm_model_ins) - generator = local_doc_qa.llm.generatorAnswer("你好") - for answer_result in generator: + answer_result_stream_result = local_doc_qa.llm_model_chain( + {"prompt": "你好", "history": [], "streaming": False}) + + for answer_result in answer_result_stream_result['answer_result_stream']: print(answer_result.llm_output) reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话""" logger.info(reply) @@ -141,7 +146,7 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation): vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") filelist = [] - if local_doc_qa.llm and local_doc_qa.embeddings: + if local_doc_qa.llm_model_chain and local_doc_qa.embeddings: if isinstance(files, list): for file in files: filename = os.path.split(file.name)[-1] @@ -165,8 +170,8 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte def change_vs_name_input(vs_id, history): if vs_id == "新建知识库": - return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history,\ - gr.update(choices=[]), gr.update(visible=False) + return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history, \ + gr.update(choices=[]), gr.update(visible=False) else: vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") if "index.faiss" in os.listdir(vs_path): @@ -218,7 +223,12 @@ def change_chunk_conent(mode, label_conent, history): def add_vs_name(vs_name, chatbot): - if vs_name in get_vs_list(): + if vs_name is None or vs_name.strip() == "": + vs_status = "知识库名称不能为空,请重新填写知识库名称" + chatbot = chatbot + [[None, vs_status]] + return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update( + visible=False), chatbot, gr.update(visible=False) + elif vs_name in get_vs_list(): vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交" chatbot = chatbot + [[None, vs_status]] return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update( @@ -257,6 +267,7 @@ def reinit_vector_store(vs_id, history): def refresh_vs_list(): return gr.update(choices=get_vs_list()), gr.update(choices=get_vs_list()) + def delete_file(vs_id, files_to_delete, chatbot): vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") content_path = os.path.join(KB_ROOT_PATH, vs_id, "content") @@ -270,11 +281,11 @@ def delete_file(vs_id, files_to_delete, chatbot): rested_files = local_doc_qa.list_file_from_vector_store(vs_path) if "fail" in status: vs_status = "文件删除失败。" - elif len(rested_files)>0: + elif len(rested_files) > 0: vs_status = "文件删除成功。" else: vs_status = f"文件删除成功,知识库{vs_id}中无已上传文件,请先上传文件后,再开始提问。" - logger.info(",".join(files_to_delete)+vs_status) + logger.info(",".join(files_to_delete) + vs_status) chatbot = chatbot + [[None, vs_status]] return gr.update(choices=local_doc_qa.list_file_from_vector_store(vs_path), value=[]), chatbot @@ -285,7 +296,8 @@ def delete_vs(vs_id, chatbot): status = f"成功删除知识库{vs_id}" logger.info(status) chatbot = chatbot + [[None, status]] - return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update(visible=True), \ + return gr.update(choices=get_vs_list(), value=get_vs_list()[0]), gr.update(visible=True), gr.update( + visible=True), \ gr.update(visible=False), chatbot, gr.update(visible=False) except Exception as e: logger.error(e) @@ -328,7 +340,8 @@ default_theme_args = dict( with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo: vs_path, file_status, model_status = gr.State( - os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State( + os.path.join(KB_ROOT_PATH, get_vs_list()[0], "vector_store") if len(get_vs_list()) > 1 else ""), gr.State( + ""), gr.State( model_status) gr.Markdown(webui_title) with gr.Tab("对话"): @@ -381,8 +394,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as load_folder_button = gr.Button("上传文件夹并加载知识库") with gr.Tab("删除文件"): files_to_delete = gr.CheckboxGroup(choices=[], - label="请从知识库已有文件中选择要删除的文件", - interactive=True) + label="请从知识库已有文件中选择要删除的文件", + interactive=True) delete_file_button = gr.Button("从知识库中删除选中文件") vs_refresh.click(fn=refresh_vs_list, inputs=[], @@ -450,9 +463,9 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as with vs_setting: vs_refresh = gr.Button("更新已有知识库选项") select_vs_test = gr.Dropdown(get_vs_list(), - label="请选择要加载的知识库", - interactive=True, - value=get_vs_list()[0] if len(get_vs_list()) > 0 else None) + label="请选择要加载的知识库", + interactive=True, + value=get_vs_list()[0] if len(get_vs_list()) > 0 else None) vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文", lines=1, interactive=True, @@ -492,8 +505,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as inputs=[vs_name, chatbot], outputs=[select_vs_test, vs_name, vs_add, file2vs, chatbot]) select_vs_test.change(fn=change_vs_name_input, - inputs=[select_vs_test, chatbot], - outputs=[vs_name, vs_add, file2vs, vs_path, chatbot]) + inputs=[select_vs_test, chatbot], + outputs=[vs_name, vs_add, file2vs, vs_path, chatbot]) load_file_button.click(get_vector_store, show_progress=True, inputs=[select_vs_test, files, sentence_size, chatbot, vs_add, vs_add], diff --git a/webui_st.py b/webui_st.py index 6d1265e9..1584a55a 100644 --- a/webui_st.py +++ b/webui_st.py @@ -1,5 +1,5 @@ import streamlit as st -# from st_btn_select import st_btn_select +from streamlit_chatbox import st_chatbox import tempfile ###### 从webui借用的代码 ##### ###### 做了少量修改 ##### @@ -23,6 +23,7 @@ def get_vs_list(): if not os.path.exists(KB_ROOT_PATH): return lst_default lst = os.listdir(KB_ROOT_PATH) + lst = [x for x in lst if os.path.isdir(os.path.join(KB_ROOT_PATH, x))] if not lst: return lst_default lst.sort() @@ -31,7 +32,6 @@ def get_vs_list(): embedding_model_dict_list = list(embedding_model_dict.keys()) llm_model_dict_list = list(llm_model_dict.keys()) -# flag_csv_logger = gr.CSVLogger() def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD, @@ -50,6 +50,9 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR history[-1][-1] += source yield history, "" elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path): + local_doc_qa.top_k = vector_search_top_k + local_doc_qa.chunk_conent = chunk_conent + local_doc_qa.chunk_size = chunk_size for resp, history in local_doc_qa.get_knowledge_based_answer( query=query, vs_path=vs_path, chat_history=history, streaming=streaming): source = "\n\n" @@ -85,62 +88,16 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR yield history + [[query, "请选择知识库后进行测试,当前未选择知识库。"]], "" else: - for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history, - streaming=streaming): + answer_result_stream_result = local_doc_qa.llm_model_chain( + {"prompt": query, "history": history, "streaming": streaming}) + for answer_result in answer_result_stream_result['answer_result_stream']: resp = answer_result.llm_output["answer"] history = answer_result.history history[-1][-1] = resp + ( "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "") yield history, "" logger.info(f"flagging: username={FLAG_USER_NAME},query={query},vs_path={vs_path},mode={mode},history={history}") - # flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME) - - -def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'): - local_doc_qa = LocalDocQA() - # 初始化消息 - args = parser.parse_args() - args_dict = vars(args) - args_dict.update(model=llm_model) - shared.loaderCheckPoint = LoaderCheckPoint(args_dict) - llm_model_ins = shared.loaderLLM() - llm_model_ins.set_history_len(LLM_HISTORY_LEN) - - try: - local_doc_qa.init_cfg(llm_model=llm_model_ins, - embedding_model=embedding_model) - generator = local_doc_qa.llm.generatorAnswer("你好") - for answer_result in generator: - print(answer_result.llm_output) - reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话""" - logger.info(reply) - except Exception as e: - logger.error(e) - reply = """模型未成功加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮""" - if str(e) == "Unknown platform: darwin": - logger.info("该报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI,具体方法请参考项目 README 中本地部署方法及常见问题:" - " https://github.com/imClumsyPanda/langchain-ChatGLM") - else: - logger.info(reply) - return local_doc_qa - - -# 暂未使用到,先保留 -# def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history): -# try: -# llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2) -# llm_model_ins.history_len = llm_history_len -# local_doc_qa.init_cfg(llm_model=llm_model_ins, -# embedding_model=embedding_model, -# top_k=top_k) -# model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话""" -# logger.info(model_status) -# except Exception as e: -# logger.error(e) -# model_status = """模型未成功重新加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮""" -# logger.info(model_status) -# return history + [[None, model_status]] def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation): @@ -148,7 +105,8 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte filelist = [] if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")): os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content")) - if local_doc_qa.llm and local_doc_qa.embeddings: + qa = st.session_state.local_doc_qa + if qa.llm_model_chain and qa.embeddings: if isinstance(files, list): for file in files: filename = os.path.split(file.name)[-1] @@ -156,10 +114,10 @@ def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_conte KB_ROOT_PATH, vs_id, "content", filename)) filelist.append(os.path.join( KB_ROOT_PATH, vs_id, "content", filename)) - vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store( + vs_path, loaded_files = qa.init_knowledge_vector_store( filelist, vs_path, sentence_size) else: - vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation, + vs_path, loaded_files = qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation, sentence_size) if len(loaded_files): file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问" @@ -177,10 +135,7 @@ knowledge_base_test_mode_info = ("【注意】\n\n" "并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询。\n\n" "2. 知识相关度 Score 经测试,建议设置为 500 或更低,具体设置情况请结合实际使用调整。" """3. 使用"添加单条数据"添加文本至知识库时,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。\n\n""" - "4. 单条内容长度建议设置在100-150左右。\n\n" - "5. 本界面用于知识入库及知识匹配相关参数设定,但当前版本中," - "本界面中修改的参数并不会直接修改对话界面中参数,仍需前往`configs/model_config.py`修改后生效。" - "相关参数将在后续版本中支持本界面直接修改。") + "4. 单条内容长度建议设置在100-150左右。") webui_title = """ @@ -192,7 +147,7 @@ webui_title = """ ###### todo ##### # 1. streamlit运行方式与一般web服务器不同,使用模块是无法实现单例模式的,所以shared和local_doc_qa都需要进行全局化处理。 -# 目前已经实现了local_doc_qa的全局化,后面要考虑shared。 +# 目前已经实现了local_doc_qa和shared.loaderCheckPoint的全局化。 # 2. 当前local_doc_qa是一个全局变量,一方面:任何一个session对其做出修改,都会影响所有session的对话;另一方面,如何处理所有session的请求竞争也是问题。 # 这个暂时无法避免,在配置普通的机器上暂时也无需考虑。 # 3. 目前只包含了get_answer对应的参数,以后可以添加其他参数,如temperature。 @@ -201,25 +156,11 @@ webui_title = """ ###### 配置项 ##### class ST_CONFIG: - user_bg_color = '#77ff77' - user_icon = 'https://tse2-mm.cn.bing.net/th/id/OIP-C.LTTKrxNWDr_k74wz6jKqBgHaHa?w=203&h=203&c=7&r=0&o=5&pid=1.7' - robot_bg_color = '#ccccee' - robot_icon = 'https://ts1.cn.mm.bing.net/th/id/R-C.5302e2cc6f5c7c4933ebb3394e0c41bc?rik=z4u%2b7efba5Mgxw&riu=http%3a%2f%2fcomic-cons.xyz%2fwp-content%2fuploads%2fStar-Wars-avatar-icon-C3PO.png&ehk=kBBvCvpJMHPVpdfpw1GaH%2brbOaIoHjY5Ua9PKcIs%2bAc%3d&risl=&pid=ImgRaw&r=0' - default_mode = '知识库问答' - defalut_kb = '' + default_mode = "知识库问答" + default_kb = "" ###### ##### -class MsgType: - ''' - 目前仅支持文本类型的输入输出,为以后多模态模型预留图像、视频、音频支持。 - ''' - TEXT = 1 - IMAGE = 2 - VIDEO = 3 - AUDIO = 4 - - class TempFile: ''' 为保持与get_vector_store的兼容性,需要将streamlit上传文件转化为其可以接受的方式 @@ -229,132 +170,54 @@ class TempFile: self.name = path -def init_session(): - st.session_state.setdefault('history', []) - - -# def get_query_params(): -# ''' -# 可以用url参数传递配置参数:llm_model, embedding_model, kb, mode。 -# 该参数将覆盖model_config中的配置。处于安全考虑,目前只支持kb和mode -# 方便将固定的配置分享给特定的人。 -# ''' -# params = st.experimental_get_query_params() -# return {k: v[0] for k, v in params.items() if v} - - -def robot_say(msg, kb=''): - st.session_state['history'].append( - {'is_user': False, 'type': MsgType.TEXT, 'content': msg, 'kb': kb}) - - -def user_say(msg): - st.session_state['history'].append( - {'is_user': True, 'type': MsgType.TEXT, 'content': msg}) - - -def format_md(msg, is_user=False, bg_color='', margin='10%'): - ''' - 将文本消息格式化为markdown文本 - ''' - if is_user: - bg_color = bg_color or ST_CONFIG.user_bg_color - text = f''' -
- {msg} -
- ''' - else: - bg_color = bg_color or ST_CONFIG.robot_bg_color - text = f''' -
- {msg} -
- ''' - return text - - -def message(msg, - is_user=False, - msg_type=MsgType.TEXT, - icon='', - bg_color='', - margin='10%', - kb='', - ): - ''' - 渲染单条消息。目前仅支持文本 - ''' - cols = st.columns([1, 10, 1]) - empty = cols[1].empty() - if is_user: - icon = icon or ST_CONFIG.user_icon - bg_color = bg_color or ST_CONFIG.user_bg_color - cols[2].image(icon, width=40) - if msg_type == MsgType.TEXT: - text = format_md(msg, is_user, bg_color, margin) - empty.markdown(text, unsafe_allow_html=True) - else: - raise RuntimeError('only support text message now.') - else: - icon = icon or ST_CONFIG.robot_icon - bg_color = bg_color or ST_CONFIG.robot_bg_color - cols[0].image(icon, width=40) - if kb: - cols[0].write(f'({kb})') - if msg_type == MsgType.TEXT: - text = format_md(msg, is_user, bg_color, margin) - empty.markdown(text, unsafe_allow_html=True) - else: - raise RuntimeError('only support text message now.') - return empty - - -def output_messages( - user_bg_color='', - robot_bg_color='', - user_icon='', - robot_icon='', -): - with chat_box.container(): - last_response = None - for msg in st.session_state['history']: - bg_color = user_bg_color if msg['is_user'] else robot_bg_color - icon = user_icon if msg['is_user'] else robot_icon - empty = message(msg['content'], - is_user=msg['is_user'], - icon=icon, - msg_type=msg['type'], - bg_color=bg_color, - kb=msg.get('kb', '') - ) - if not msg['is_user']: - last_response = empty - return last_response - - @st.cache_resource(show_spinner=False, max_entries=1) -def load_model(llm_model: str, embedding_model: str): +def load_model( + llm_model: str = LLM_MODEL, + embedding_model: str = EMBEDDING_MODEL, + use_ptuning_v2: bool = USE_PTUNING_V2, +): ''' 对应init_model,利用streamlit cache避免模型重复加载 ''' - local_doc_qa = init_model(llm_model, embedding_model) - robot_say('模型已成功加载,可以开始对话,或从左侧选择模式后开始对话。\n请尽量不要刷新页面,以免模型出错或重复加载。') + local_doc_qa = LocalDocQA() + # 初始化消息 + args = parser.parse_args() + args_dict = vars(args) + args_dict.update(model=llm_model) + if shared.loaderCheckPoint is None: # avoid checkpoint reloading when reinit model + shared.loaderCheckPoint = LoaderCheckPoint(args_dict) + # shared.loaderCheckPoint.model_name is different by no_remote_model. + # if it is not set properly error occurs when reinit llm model(issue#473). + # as no_remote_model is removed from model_config, need workaround to set it automaticlly. + local_model_path = llm_model_dict.get(llm_model, {}).get('local_model_path') or '' + no_remote_model = os.path.isdir(local_model_path) + llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2) + llm_model_ins.history_len = LLM_HISTORY_LEN + + try: + local_doc_qa.init_cfg(llm_model=llm_model_ins, + embedding_model=embedding_model) + answer_result_stream_result = local_doc_qa.llm_model_chain( + {"prompt": "你好", "history": [], "streaming": False}) + + for answer_result in answer_result_stream_result['answer_result_stream']: + print(answer_result.llm_output) + reply = """模型已成功加载,可以开始对话,或从右侧选择模式后开始对话""" + logger.info(reply) + except Exception as e: + logger.error(e) + reply = """模型未成功加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮""" + if str(e) == "Unknown platform: darwin": + logger.info("该报错可能因为您使用的是 macOS 操作系统,需先下载模型至本地后执行 Web UI,具体方法请参考项目 README 中本地部署方法及常见问题:" + " https://github.com/imClumsyPanda/langchain-ChatGLM") + else: + logger.info(reply) return local_doc_qa # @st.cache_data def answer(query, vs_path='', history=[], mode='', score_threshold=0, - vector_search_top_k=5, chunk_conent=True, chunk_size=100, qa=None + vector_search_top_k=5, chunk_conent=True, chunk_size=100 ): ''' 对应get_answer,--利用streamlit cache缓存相同问题的答案-- @@ -363,48 +226,24 @@ def answer(query, vs_path='', history=[], mode='', score_threshold=0, vector_search_top_k, chunk_conent, chunk_size) -def load_vector_store( - vs_id, - files, - sentence_size=100, - history=[], - one_conent=None, - one_content_segmentation=None, -): - return get_vector_store( - local_doc_qa, - vs_id, - files, - sentence_size, - history, - one_conent, - one_content_segmentation, - ) +def use_kb_mode(m): + return m in ["知识库问答", "知识库测试"] # main ui st.set_page_config(webui_title, layout='wide') -init_session() -# params = get_query_params() -# llm_model = params.get('llm_model', LLM_MODEL) -# embedding_model = params.get('embedding_model', EMBEDDING_MODEL) - -with st.spinner(f'正在加载模型({LLM_MODEL} + {EMBEDDING_MODEL}),请耐心等候...'): - local_doc_qa = load_model(LLM_MODEL, EMBEDDING_MODEL) - - -def use_kb_mode(m): - return m in ['知识库问答', '知识库测试'] +chat_box = st_chatbox(greetings=["模型已成功加载,可以开始对话,或从左侧选择模式后开始对话。"]) +# 使用 help(st_chatbox) 查看自定义参数 # sidebar modes = ['LLM 对话', '知识库问答', 'Bing搜索问答', '知识库测试'] with st.sidebar: def on_mode_change(): m = st.session_state.mode - robot_say(f'已切换到"{m}"模式') + chat_box.robot_say(f'已切换到"{m}"模式') if m == '知识库测试': - robot_say(knowledge_base_test_mode_info) + chat_box.robot_say(knowledge_base_test_mode_info) index = 0 try: @@ -414,7 +253,7 @@ with st.sidebar: mode = st.selectbox('对话模式', modes, index, on_change=on_mode_change, key='mode') - with st.expander('模型配置', '知识' not in mode): + with st.expander('模型配置', not use_kb_mode(mode)): with st.form('model_config'): index = 0 try: @@ -423,9 +262,8 @@ with st.sidebar: pass llm_model = st.selectbox('LLM模型', llm_model_dict_list, index) - no_remote_model = st.checkbox('加载本地模型', False) use_ptuning_v2 = st.checkbox('使用p-tuning-v2微调过的模型', False) - use_lora = st.checkbox('使用lora微调的权重', False) + try: index = embedding_model_dict_list.index(EMBEDDING_MODEL) except: @@ -435,42 +273,52 @@ with st.sidebar: btn_load_model = st.form_submit_button('重新加载模型') if btn_load_model: - local_doc_qa = load_model(llm_model, embedding_model) + local_doc_qa = load_model(llm_model, embedding_model, use_ptuning_v2) - if mode in ['知识库问答', '知识库测试']: + history_len = st.slider( + "LLM对话轮数", 1, 50, LLM_HISTORY_LEN) + + if use_kb_mode(mode): vs_list = get_vs_list() vs_list.remove('新建知识库') def on_new_kb(): name = st.session_state.kb_name - if name in vs_list: - st.error(f'名为“{name}”的知识库已存在。') + if not name: + st.sidebar.error(f'新建知识库名称不能为空!') + elif name in vs_list: + st.sidebar.error(f'名为“{name}”的知识库已存在。') else: - vs_list.append(name) st.session_state.vs_path = name + st.session_state.kb_name = '' + new_kb_dir = os.path.join(KB_ROOT_PATH, name) + if not os.path.exists(new_kb_dir): + os.makedirs(new_kb_dir) + st.sidebar.success(f'名为“{name}”的知识库创建成功,您可以开始添加文件。') def on_vs_change(): - robot_say(f'已加载知识库: {st.session_state.vs_path}') + chat_box.robot_say(f'已加载知识库: {st.session_state.vs_path}') with st.expander('知识库配置', True): cols = st.columns([12, 10]) kb_name = cols[0].text_input( - '新知识库名称', placeholder='新知识库名称', label_visibility='collapsed') + '新知识库名称', placeholder='新知识库名称', label_visibility='collapsed', key='kb_name') cols[1].button('新建知识库', on_click=on_new_kb) + index = 0 + try: + index = vs_list.index(ST_CONFIG.default_kb) + except: + pass vs_path = st.selectbox( - '选择知识库', vs_list, on_change=on_vs_change, key='vs_path') + '选择知识库', vs_list, index, on_change=on_vs_change, key='vs_path') st.text('') score_threshold = st.slider( '知识相关度阈值', 0, 1000, VECTOR_SEARCH_SCORE_THRESHOLD) top_k = st.slider('向量匹配数量', 1, 20, VECTOR_SEARCH_TOP_K) - history_len = st.slider( - 'LLM对话轮数', 1, 50, LLM_HISTORY_LEN) # 也许要跟知识库分开设置 - local_doc_qa.llm.set_history_len(history_len) chunk_conent = st.checkbox('启用上下文关联', False) - st.text('') - # chunk_conent = st.checkbox('分割文本', True) # 知识库文本分割入库 chunk_size = st.slider('上下文关联长度', 1, 1000, CHUNK_SIZE) + st.text('') sentence_size = st.slider('文本入库分句长度限制', 1, 1000, SENTENCE_SIZE) files = st.file_uploader('上传知识文件', ['docx', 'txt', 'md', 'csv', 'xlsx', 'pdf'], @@ -483,56 +331,61 @@ with st.sidebar: with open(file, 'wb') as fp: fp.write(f.getvalue()) file_list.append(TempFile(file)) - _, _, history = load_vector_store( + _, _, history = get_vector_store( vs_path, file_list, sentence_size, [], None, None) st.session_state.files = [] -# main body -chat_box = st.empty() +# load model after params rendered +with st.spinner(f"正在加载模型({llm_model} + {embedding_model}),请耐心等候..."): + local_doc_qa = load_model( + llm_model, + embedding_model, + use_ptuning_v2, + ) + local_doc_qa.llm_model_chain.history_len = history_len + if use_kb_mode(mode): + local_doc_qa.chunk_conent = chunk_conent + local_doc_qa.chunk_size = chunk_size + # local_doc_qa.llm_model_chain.temperature = temperature # 这样设置temperature似乎不起作用 + st.session_state.local_doc_qa = local_doc_qa -with st.form('my_form', clear_on_submit=True): +# input form +with st.form("my_form", clear_on_submit=True): cols = st.columns([8, 1]) - question = cols[0].text_input( + question = cols[0].text_area( 'temp', key='input_question', label_visibility='collapsed') - def on_send(): - q = st.session_state.input_question - if q: - user_say(q) + if cols[1].form_submit_button("发送"): + chat_box.user_say(question) + history = [] + if mode == "LLM 对话": + chat_box.robot_say("正在思考...") + chat_box.output_messages() + for history, _ in answer(question, + history=[], + mode=mode): + chat_box.update_last_box_text(history[-1][-1]) + elif use_kb_mode(mode): + chat_box.robot_say(f"正在查询 [{vs_path}] ...") + chat_box.output_messages() + for history, _ in answer(question, + vs_path=os.path.join( + KB_ROOT_PATH, vs_path, 'vector_store'), + history=[], + mode=mode, + score_threshold=score_threshold, + vector_search_top_k=top_k, + chunk_conent=chunk_conent, + chunk_size=chunk_size): + chat_box.update_last_box_text(history[-1][-1]) + else: + chat_box.robot_say(f"正在执行Bing搜索...") + chat_box.output_messages() + for history, _ in answer(question, + history=[], + mode=mode): + chat_box.update_last_box_text(history[-1][-1]) - if mode == 'LLM 对话': - robot_say('正在思考...') - last_response = output_messages() - for history, _ in answer(q, - history=[], - mode=mode): - last_response.markdown( - format_md(history[-1][-1], False), - unsafe_allow_html=True - ) - elif use_kb_mode(mode): - robot_say('正在思考...', vs_path) - last_response = output_messages() - for history, _ in answer(q, - vs_path=os.path.join( - KB_ROOT_PATH, vs_path, "vector_store"), - history=[], - mode=mode, - score_threshold=score_threshold, - vector_search_top_k=top_k, - chunk_conent=chunk_conent, - chunk_size=chunk_size): - last_response.markdown( - format_md(history[-1][-1], False, 'ligreen'), - unsafe_allow_html=True - ) - else: - robot_say('正在思考...') - last_response = output_messages() - st.session_state['history'][-1]['content'] = history[-1][-1] - submit = cols[1].form_submit_button('发送', on_click=on_send) - -output_messages() - -# st.write(st.session_state['history']) +# st.write(chat_box.history) +chat_box.output_messages()