From c3627de005e21a0c4e6307e292558841cf66d440 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Fri, 18 Aug 2023 08:48:02 +0800 Subject: [PATCH 01/13] fix startup.py: add log info before server starting --- startup.py | 48 +++++++++++++++++++++++++++--------------------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/startup.py b/startup.py index 830c3229..b3e85800 100644 --- a/startup.py +++ b/startup.py @@ -315,6 +315,31 @@ def parse_args() -> argparse.ArgumentParser: return args +def dump_server_info(after_start=False): + print("\n\n") + print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) + print(f"操作系统:{platform.platform()}.") + print(f"python版本:{sys.version}") + print(f"项目版本:{VERSION}") + print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}") + print("\n") + print(f"当前LLM模型:{LLM_MODEL} @ {LLM_DEVICE}") + pprint(llm_model_dict[LLM_MODEL]) + print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}") + if after_start: + print("\n") + print(f"服务端运行信息:") + if args.openai_api: + print(f" OpenAI API Server: {fschat_openai_api_address()}/v1") + print(" (请确认llm_model_dict中配置的api_base_url与上面地址一致。)") + if args.api: + print(f" Chatchat API Server: {api_address()}") + if args.webui: + print(f" Chatchat WEBUI Server: {webui_address()}") + print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) + print("\n\n") + + if __name__ == "__main__": import platform import time @@ -343,6 +368,7 @@ if __name__ == "__main__": args.api = False args.webui = False + dump_server_info() logger.info(f"正在启动服务:") logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}") @@ -403,27 +429,7 @@ if __name__ == "__main__": no = queue.get() if no == len(processes): time.sleep(0.5) - print("\n\n") - print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) - print(f"操作系统:{platform.platform()}.") - print(f"python版本:{sys.version}") - print(f"项目版本:{VERSION}") - print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}") - print("\n") - print(f"当前LLM模型:{LLM_MODEL} @ {LLM_DEVICE}") - pprint(llm_model_dict[LLM_MODEL]) - print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}") - print("\n") - print(f"服务端运行信息:") - if args.openai_api: - print(f" OpenAI API Server: {fschat_openai_api_address()}/v1") - print("请确认llm_model_dict中配置的api_base_url与上面地址一致。") - if args.api: - print(f" Chatchat API Server: {api_address()}") - if args.webui: - print(f" Chatchat WEBUI Server: {webui_address()}") - print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) - print("\n\n") + dump_server_info(True) break else: queue.put(no) From d4cf77170a7378c392dc2328540839e27c525fbc Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Fri, 18 Aug 2023 11:47:49 +0800 Subject: [PATCH 02/13] fix webui: correct error messages --- webui_pages/knowledge_base/knowledge_base.py | 14 +++++------ webui_pages/utils.py | 25 ++++++++++++++++---- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index f0594754..89e274ca 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -118,7 +118,7 @@ def knowledge_base_page(api: ApiRequest): vector_store_type=vs_type, embed_model=embed_model, ) - st.toast(ret["msg"]) + st.toast(ret.get("msg", " ")) st.session_state["selected_kb_name"] = kb_name st.experimental_rerun() @@ -140,10 +140,10 @@ def knowledge_base_page(api: ApiRequest): ): for f in files: ret = api.upload_kb_doc(f, kb) - if ret["code"] == 200: - st.toast(ret["msg"], icon="✔") - else: - st.toast(ret["msg"], icon="✖") + if msg := check_success_msg(ret): + st.toast(msg, icon="✔") + elif msg := check_error_msg(ret): + st.toast(msg, icon="✖") st.session_state.files = [] st.divider() @@ -235,7 +235,7 @@ def knowledge_base_page(api: ApiRequest): ): for row in selected_rows: ret = api.delete_kb_doc(kb, row["file_name"], True) - st.toast(ret["msg"]) + st.toast(ret.get("msg", " ")) st.experimental_rerun() st.divider() @@ -262,6 +262,6 @@ def knowledge_base_page(api: ApiRequest): use_container_width=True, ): ret = api.delete_knowledge_base(kb) - st.toast(ret["msg"]) + st.toast(ret.get("msg", " ")) time.sleep(1) st.experimental_rerun() diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 3e67ed72..cc38ef56 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -232,15 +232,15 @@ class ApiRequest: msg = f"无法连接API服务器,请确认已执行python server\\api.py" logger.error(msg) logger.error(e) - yield {"code": 500, "errorMsg": msg} + yield {"code": 500, "msg": msg} except httpx.ReadTimeout as e: msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')" logger.error(msg) logger.error(e) - yield {"code": 500, "errorMsg": msg} + yield {"code": 500, "msg": msg} except Exception as e: logger.error(e) - yield {"code": 500, "errorMsg": str(e)} + yield {"code": 500, "msg": str(e)} # 对话相关操作 @@ -394,7 +394,7 @@ class ApiRequest: return response.json() except Exception as e: logger.error(e) - return {"code": 500, "errorMsg": errorMsg or str(e)} + return {"code": 500, "msg": errorMsg or str(e)} def list_knowledge_bases( self, @@ -626,7 +626,22 @@ def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str: ''' return error message if error occured when requests API ''' - if isinstance(data, dict) and key in data: + if isinstance(data, dict): + if key in data: + return data[key] + if "code" in data and data["code"] != 200: + return data["msg"] + return "" + + +def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str: + ''' + return error message if error occured when requests API + ''' + if (isinstance(data, dict) + and key in data + and "code" in data + and data["code"] == 200): return data[key] return "" From 95d9fb0ee968f969b2c44f997105a019425db050 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E9=B9=8F?= <37921589+shutter-cp@users.noreply.github.com> Date: Fri, 18 Aug 2023 14:41:11 +0800 Subject: [PATCH 03/13] fix bug 1159 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复使用使用pg存储向量时,向量检索错误 {TypeError}TypeError("'Document' object is not subscriptable") --- server/knowledge_base/kb_service/pg_kb_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index a3126ece..a3f9318e 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -46,7 +46,7 @@ class PGKBService(KBService): def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]: # todo: support score threshold self._load_pg_vector(embeddings=embeddings) - return self.pg_vector.similarity_search(query, top_k) + return self.pg_vector.similarity_search_with_score(query, top_k) def add_doc(self, kb_file: KnowledgeFile): """ From fe9f2df17daeae4a659e7da170c61855509c4da1 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Fri, 18 Aug 2023 16:46:59 +0800 Subject: [PATCH 04/13] fix startup.py: correct command help info --- startup.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/startup.py b/startup.py index b3e85800..10d53813 100644 --- a/startup.py +++ b/startup.py @@ -252,6 +252,7 @@ def run_webui(q: Queue, run_seq: int = 5): def parse_args() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument( + "-a", "--all-webui", action="store_true", help="run fastchat's controller/model_worker/openai_api servers, run api.py and webui.py", @@ -260,13 +261,13 @@ def parse_args() -> argparse.ArgumentParser: parser.add_argument( "--all-api", action="store_true", - help="run fastchat's controller/model_worker/openai_api servers, run api.py and webui.py", + help="run fastchat's controller/model_worker/openai_api servers, run api.py", dest="all_api", ) parser.add_argument( "--llm-api", action="store_true", - help="run fastchat's controller/model_worker/openai_api servers, run api.py and webui.py", + help="run fastchat's controller/model_worker/openai_api serversy", dest="llm_api", ) parser.add_argument( From 62d6f44b28aa24793800108c78d05191e4378da9 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Fri, 18 Aug 2023 21:30:50 +0800 Subject: [PATCH 05/13] fix startup.py --- startup.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/startup.py b/startup.py index 10d53813..88bdd1d4 100644 --- a/startup.py +++ b/startup.py @@ -255,33 +255,33 @@ def parse_args() -> argparse.ArgumentParser: "-a", "--all-webui", action="store_true", - help="run fastchat's controller/model_worker/openai_api servers, run api.py and webui.py", + help="run fastchat's controller/openai_api/model_worker servers, run api.py and webui.py", dest="all_webui", ) parser.add_argument( "--all-api", action="store_true", - help="run fastchat's controller/model_worker/openai_api servers, run api.py", + help="run fastchat's controller/openai_api/model_worker servers, run api.py", dest="all_api", ) parser.add_argument( "--llm-api", action="store_true", - help="run fastchat's controller/model_worker/openai_api serversy", + help="run fastchat's controller/openai_api/model_worker servers", dest="llm_api", ) parser.add_argument( "-o", "--openai-api", action="store_true", - help="run fastchat controller/openai_api servers", + help="run fastchat's controller/openai_api servers", dest="openai_api", ) parser.add_argument( "-m", "--model-worker", action="store_true", - help="run fastchat model_worker server with specified model name. specify --model-name if not using default LLM_MODEL", + help="run fastchat's model_worker server with specified model name. specify --model-name if not using default LLM_MODEL", dest="model_worker", ) parser.add_argument( From 69627a2fa34717660974612d2b94e14ac76152de Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Sat, 19 Aug 2023 15:14:45 +0800 Subject: [PATCH 06/13] update chat and knowledge base api: unify exception processing and return types --- server/chat/github_chat.py | 109 ++++++++++++++ server/knowledge_base/kb_api.py | 12 +- server/knowledge_base/kb_doc_api.py | 141 +++++++++++-------- server/utils.py | 4 +- webui_pages/knowledge_base/knowledge_base.py | 8 +- webui_pages/utils.py | 2 +- 6 files changed, 209 insertions(+), 67 deletions(-) create mode 100644 server/chat/github_chat.py diff --git a/server/chat/github_chat.py b/server/chat/github_chat.py new file mode 100644 index 00000000..b1615488 --- /dev/null +++ b/server/chat/github_chat.py @@ -0,0 +1,109 @@ +from langchain.document_loaders.github import GitHubIssuesLoader +from fastapi import Body +from fastapi.responses import StreamingResponse +from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE) +from server.chat.utils import wrap_done +from server.utils import BaseResponse +from langchain.chat_models import ChatOpenAI +from langchain import LLMChain +from langchain.callbacks import AsyncIteratorCallbackHandler +from typing import AsyncIterable +import asyncio +from langchain.prompts.chat import ChatPromptTemplate +from typing import List, Optional, Literal +from server.chat.utils import History +from langchain.docstore.document import Document +import json +import os +from functools import lru_cache +from datetime import datetime + + +GITHUB_PERSONAL_ACCESS_TOKEN = os.environ.get("GITHUB_PERSONAL_ACCESS_TOKEN") + + +@lru_cache(1) +def load_issues(tick: str): + ''' + set tick to a periodic value to refresh cache + ''' + loader = GitHubIssuesLoader( + repo="chatchat-space/langchain-chatglm", + access_token=GITHUB_PERSONAL_ACCESS_TOKEN, + include_prs=True, + state="all", + ) + docs = loader.load() + return docs + + +def +def github_chat(query: str = Body(..., description="用户输入", examples=["本项目最新进展"]), + top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"), + include_prs: bool = Body(True, description="是否包含PR"), + state: Literal['open', 'closed', 'all'] = Body(None, description="Issue/PR状态"), + creator: str = Body(None, description="创建者"), + history: List[History] = Body([], + description="历史对话", + examples=[[ + {"role": "user", + "content": "介绍一下本项目"}, + {"role": "assistant", + "content": "LangChain-Chatchat (原 Langchain-ChatGLM): 基于 Langchain 与 ChatGLM 等大语言模型的本地知识库问答应用实现。"}]] + ), + stream: bool = Body(False, description="流式输出"), + ): + if GITHUB_PERSONAL_ACCESS_TOKEN is None: + return BaseResponse(code=404, msg=f"使用本功能需要 GITHUB_PERSONAL_ACCESS_TOKEN") + + async def chat_iterator(query: str, + search_engine_name: str, + top_k: int, + history: Optional[List[History]], + ) -> AsyncIterable[str]: + callback = AsyncIteratorCallbackHandler() + model = ChatOpenAI( + streaming=True, + verbose=True, + callbacks=[callback], + openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], + openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], + model_name=LLM_MODEL + ) + + docs = lookup_search_engine(query, search_engine_name, top_k) + context = "\n".join([doc.page_content for doc in docs]) + + chat_prompt = ChatPromptTemplate.from_messages( + [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)]) + + chain = LLMChain(prompt=chat_prompt, llm=model) + + # Begin a task that runs in the background. + task = asyncio.create_task(wrap_done( + chain.acall({"context": context, "question": query}), + callback.done), + ) + + source_documents = [ + f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n""" + for inum, doc in enumerate(docs) + ] + + if stream: + async for token in callback.aiter(): + # Use server-sent-events to stream the response + yield json.dumps({"answer": token, + "docs": source_documents}, + ensure_ascii=False) + else: + answer = "" + async for token in callback.aiter(): + answer += token + yield json.dumps({"answer": token, + "docs": source_documents}, + ensure_ascii=False) + await task + + return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history), + media_type="text/event-stream") diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py index 4753ba4d..b9151b89 100644 --- a/server/knowledge_base/kb_api.py +++ b/server/knowledge_base/kb_api.py @@ -15,7 +15,7 @@ async def list_kbs(): async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), vector_store_type: str = Body("faiss"), embed_model: str = Body(EMBEDDING_MODEL), - ): + ) -> BaseResponse: # Create selected knowledge base if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -27,13 +27,18 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model) - kb.create_kb() + try: + kb.create_kb() + except Exception as e: + print(e) + return BaseResponse(code=500, msg=f"创建知识库出错: {e}") + return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") async def delete_kb( knowledge_base_name: str = Body(..., examples=["samples"]) - ): + ) -> BaseResponse: # Delete selected knowledge base if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -51,5 +56,6 @@ async def delete_kb( return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}") except Exception as e: print(e) + return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}") return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}") diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 0bf2cb76..0d74fd6c 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -22,7 +22,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=[" ) -> List[DocumentWithScore]: kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is None: - return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []} + return [] docs = kb.search_docs(query, top_k, score_threshold) data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs] @@ -31,7 +31,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=[" async def list_docs( knowledge_base_name: str -): +) -> ListResponse: if not validate_kb_name(knowledge_base_name): return ListResponse(code=403, msg="Don't attack me", data=[]) @@ -41,13 +41,13 @@ async def list_docs( return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[]) else: all_doc_names = kb.list_docs() - return ListResponse(data=all_doc_names) + return ListResponse(data=all_doc_names) async def upload_doc(file: UploadFile = File(..., description="上传文件"), knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]), override: bool = Form(False, description="覆盖已有文件"), - ): + ) -> BaseResponse: if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -57,31 +57,37 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"), file_content = await file.read() # 读取上传文件的内容 - kb_file = KnowledgeFile(filename=file.filename, - knowledge_base_name=knowledge_base_name) - - if (os.path.exists(kb_file.filepath) - and not override - and os.path.getsize(kb_file.filepath) == len(file_content) - ): - # TODO: filesize 不同后的处理 - file_status = f"文件 {kb_file.filename} 已存在。" - return BaseResponse(code=404, msg=file_status) - try: + kb_file = KnowledgeFile(filename=file.filename, + knowledge_base_name=knowledge_base_name) + + if (os.path.exists(kb_file.filepath) + and not override + and os.path.getsize(kb_file.filepath) == len(file_content) + ): + # TODO: filesize 不同后的处理 + file_status = f"文件 {kb_file.filename} 已存在。" + return BaseResponse(code=404, msg=file_status) + with open(kb_file.filepath, "wb") as f: f.write(file_content) except Exception as e: + print(e) return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}") - kb.add_doc(kb_file) + try: + kb.add_doc(kb_file) + except Exception as e: + print(e) + return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}") + return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}") async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), doc_name: str = Body(..., examples=["file_name.md"]), delete_content: bool = Body(False), - ): + ) -> BaseResponse: if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -92,17 +98,22 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), if not kb.exist_doc(doc_name): return BaseResponse(code=404, msg=f"未找到文件 {doc_name}") - kb_file = KnowledgeFile(filename=doc_name, - knowledge_base_name=knowledge_base_name) - kb.delete_doc(kb_file, delete_content) + + try: + kb_file = KnowledgeFile(filename=doc_name, + knowledge_base_name=knowledge_base_name) + kb.delete_doc(kb_file, delete_content) + except Exception as e: + print(e) + return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}") + return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功") - # return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败") async def update_doc( knowledge_base_name: str = Body(..., examples=["samples"]), file_name: str = Body(..., examples=["file_name"]), - ): + ) -> BaseResponse: ''' 更新知识库文档 ''' @@ -113,14 +124,17 @@ async def update_doc( if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - kb_file = KnowledgeFile(filename=file_name, - knowledge_base_name=knowledge_base_name) + try: + kb_file = KnowledgeFile(filename=file_name, + knowledge_base_name=knowledge_base_name) + if os.path.exists(kb_file.filepath): + kb.update_doc(kb_file) + return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}") + except Exception as e: + print(e) + return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}") - if os.path.exists(kb_file.filepath): - kb.update_doc(kb_file) - return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}") - else: - return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败") + return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败") async def download_doc( @@ -137,18 +151,20 @@ async def download_doc( if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - kb_file = KnowledgeFile(filename=file_name, - knowledge_base_name=knowledge_base_name) - - if os.path.exists(kb_file.filepath): - return FileResponse( - path=kb_file.filepath, - filename=kb_file.filename, - media_type="multipart/form-data") - else: - return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败") + try: + kb_file = KnowledgeFile(filename=file_name, + knowledge_base_name=knowledge_base_name) + if os.path.exists(kb_file.filepath): + return FileResponse( + path=kb_file.filepath, + filename=kb_file.filename, + media_type="multipart/form-data") + except Exception as e: + print(e) + return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}") + return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败") async def recreate_vector_store( @@ -163,24 +179,33 @@ async def recreate_vector_store( by default, get_service_by_name only return knowledge base in the info.db and having document files in it. set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents. ''' - kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model) - if not kb.exists() and not allow_empty_kb: - return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - async def output(kb): - kb.create_kb() - kb.clear_vs() - docs = list_docs_from_folder(knowledge_base_name) - for i, doc in enumerate(docs): - try: - kb_file = KnowledgeFile(doc, knowledge_base_name) - yield json.dumps({ - "total": len(docs), - "finished": i, - "doc": doc, - }, ensure_ascii=False) - kb.add_doc(kb_file) - except Exception as e: - print(e) + async def output(): + kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model) + if not kb.exists() and not allow_empty_kb: + yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"} + else: + kb.create_kb() + kb.clear_vs() + docs = list_docs_from_folder(knowledge_base_name) + for i, doc in enumerate(docs): + try: + kb_file = KnowledgeFile(doc, knowledge_base_name) + yield json.dumps({ + "code": 200, + "msg": f"({i + 1} / {len(docs)}): {doc}", + "total": len(docs), + "finished": i, + "doc": doc, + }, ensure_ascii=False) + kb.add_doc(kb_file) + except Exception as e: + print(e) + yield json.dumps({ + "code": 500, + "msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。", + }) + import asyncio + await asyncio.sleep(5) - return StreamingResponse(output(kb), media_type="text/event-stream") + return StreamingResponse(output(), media_type="text/event-stream") diff --git a/server/utils.py b/server/utils.py index c0f11a5f..4a887225 100644 --- a/server/utils.py +++ b/server/utils.py @@ -9,8 +9,8 @@ from typing import Any, Optional class BaseResponse(BaseModel): - code: int = pydantic.Field(200, description="HTTP status code") - msg: str = pydantic.Field("success", description="HTTP status message") + code: int = pydantic.Field(200, description="API status code") + msg: str = pydantic.Field("success", description="API status message") class Config: schema_extra = { diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index 89e274ca..3bd531a9 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -249,12 +249,14 @@ def knowledge_base_page(api: ApiRequest): use_container_width=True, type="primary", ): - with st.spinner("向量库重构中"): + with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"): empty = st.empty() empty.progress(0.0, "") for d in api.recreate_vector_store(kb): - print(d) - empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}") + if msg := check_error_msg(d): + st.toast(msg) + else: + empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}") st.experimental_rerun() if cols[2].button( diff --git a/webui_pages/utils.py b/webui_pages/utils.py index cc38ef56..18a24e4e 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -229,7 +229,7 @@ class ApiRequest: elif chunk.strip(): yield chunk except httpx.ConnectError as e: - msg = f"无法连接API服务器,请确认已执行python server\\api.py" + msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。" logger.error(msg) logger.error(e) yield {"code": 500, "msg": msg} From 956237feac6ebd9c98a76531b3281170c843841f Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Sat, 19 Aug 2023 15:19:01 +0800 Subject: [PATCH 07/13] add api tests --- tests/api/test_kb_api.py | 200 ++++++++++++++++++++++++++++++ tests/api/test_stream_chat_api.py | 108 ++++++++++++++++ 2 files changed, 308 insertions(+) create mode 100644 tests/api/test_kb_api.py create mode 100644 tests/api/test_stream_chat_api.py diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py new file mode 100644 index 00000000..c09d57f0 --- /dev/null +++ b/tests/api/test_kb_api.py @@ -0,0 +1,200 @@ +from doctest import testfile +import requests +import json +import sys +from pathlib import Path + +root_path = Path(__file__).parent.parent.parent +sys.path.append(str(root_path)) +from configs.server_config import api_address +from configs.model_config import VECTOR_SEARCH_TOP_K + +from pprint import pprint + + +api_base_url = api_address() + +kb = "kb_for_api_test" +test_files = { + "README.MD": str(root_path / "README.MD"), + "FAQ.MD": str(root_path / "docs" / "FAQ.MD") +} + + +def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"): + url = api_base_url + api + print("\n删除知识库") + r = requests.post(url, json=kb) + data = r.json() + pprint(data) + + # check kb not exists anymore + url = api_base_url + "/knowledge_base/list_knowledge_bases" + print("\n获取知识库列表:") + r = requests.get(url) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert isinstance(data["data"], list) and len(data["data"]) > 0 + assert kb not in data["data"] + + +def test_create_kb(api="/knowledge_base/create_knowledge_base"): + url = api_base_url + api + + print(f"\n尝试用空名称创建知识库:") + r = requests.post(url, json={"knowledge_base_name": " "}) + data = r.json() + pprint(data) + assert data["code"] == 404 + assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称" + + print(f"\n创建新知识库: {kb}") + r = requests.post(url, json={"knowledge_base_name": kb}) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert data["msg"] == f"已新增知识库 {kb}" + + print(f"\n尝试创建同名知识库: {kb}") + r = requests.post(url, json={"knowledge_base_name": kb}) + data = r.json() + pprint(data) + assert data["code"] == 404 + assert data["msg"] == f"已存在同名知识库 {kb}" + + +def test_list_kbs(api="/knowledge_base/list_knowledge_bases"): + url = api_base_url + api + print("\n获取知识库列表:") + r = requests.get(url) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert isinstance(data["data"], list) and len(data["data"]) > 0 + assert kb in data["data"] + + +def test_upload_doc(api="/knowledge_base/upload_doc"): + url = api_base_url + api + for name, path in test_files.items(): + print(f"\n上传知识文件: {name}") + data = {"knowledge_base_name": kb, "override": True} + files = {"file": (name, open(path, "rb"))} + r = requests.post(url, data=data, files=files) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert data["msg"] == f"成功上传文件 {name}" + + for name, path in test_files.items(): + print(f"\n尝试重新上传知识文件: {name}, 不覆盖") + data = {"knowledge_base_name": kb, "override": False} + files = {"file": (name, open(path, "rb"))} + r = requests.post(url, data=data, files=files) + data = r.json() + pprint(data) + assert data["code"] == 404 + assert data["msg"] == f"文件 {name} 已存在。" + + for name, path in test_files.items(): + print(f"\n尝试重新上传知识文件: {name}, 覆盖") + data = {"knowledge_base_name": kb, "override": True} + files = {"file": (name, open(path, "rb"))} + r = requests.post(url, data=data, files=files) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert data["msg"] == f"成功上传文件 {name}" + + +def test_list_docs(api="/knowledge_base/list_docs"): + url = api_base_url + api + print("\n获取知识库中文件列表:") + r = requests.get(url, params={"knowledge_base_name": kb}) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert isinstance(data["data"], list) + for name in test_files: + assert name in data["data"] + + +def test_search_docs(api="/knowledge_base/search_docs"): + url = api_base_url + api + query = "介绍一下langchain-chatchat项目" + print("\n检索知识库:") + print(query) + r = requests.post(url, json={"knowledge_base_name": kb, "query": query}) + data = r.json() + pprint(data) + assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K + + +def test_update_doc(api="/knowledge_base/update_doc"): + url = api_base_url + api + for name, path in test_files.items(): + print(f"\n更新知识文件: {name}") + r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name}) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert data["msg"] == f"成功更新文件 {name}" + + +def test_delete_doc(api="/knowledge_base/delete_doc"): + url = api_base_url + api + for name, path in test_files.items(): + print(f"\n删除知识文件: {name}") + r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name}) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert data["msg"] == f"{name} 文件删除成功" + + url = api_base_url + "/knowledge_base/search_docs" + query = "介绍一下langchain-chatchat项目" + print("\n尝试检索删除后的检索知识库:") + print(query) + r = requests.post(url, json={"knowledge_base_name": kb, "query": query}) + data = r.json() + pprint(data) + assert isinstance(data, list) and len(data) == 0 + + +def test_recreate_vs(api="/knowledge_base/recreate_vector_store"): + url = api_base_url + api + print("\n重建知识库:") + r = requests.post(url, json={"knowledge_base_name": kb}, stream=True) + for chunk in r.iter_content(None): + data = json.loads(chunk) + assert isinstance(data, dict) + assert data["code"] == 200 + print(data["msg"]) + + url = api_base_url + "/knowledge_base/search_docs" + query = "本项目支持哪些文件格式?" + print("\n尝试检索重建后的检索知识库:") + print(query) + r = requests.post(url, json={"knowledge_base_name": kb, "query": query}) + data = r.json() + pprint(data) + assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K + + +def test_delete_kb_after(api="/knowledge_base/delete_knowledge_base"): + url = api_base_url + api + print("\n删除知识库") + r = requests.post(url, json=kb) + data = r.json() + pprint(data) + + # check kb not exists anymore + url = api_base_url + "/knowledge_base/list_knowledge_bases" + print("\n获取知识库列表:") + r = requests.get(url) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert isinstance(data["data"], list) and len(data["data"]) > 0 + assert kb not in data["data"] diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py new file mode 100644 index 00000000..56d32375 --- /dev/null +++ b/tests/api/test_stream_chat_api.py @@ -0,0 +1,108 @@ +import requests +import json +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from configs.server_config import API_SERVER, api_address + +from pprint import pprint + + +api_base_url = api_address() + + +def dump_input(d, title): + print("\n") + print("=" * 30 + title + " input " + "="*30) + pprint(d) + + +def dump_output(r, title): + print("\n") + print("=" * 30 + title + " output" + "="*30) + for line in r.iter_content(None, decode_unicode=True): + print(line, end="", flush=True) + + +headers = { + 'accept': 'application/json', + 'Content-Type': 'application/json', +} + +data = { + "query": "请用100字左右的文字介绍自己", + "history": [ + { + "role": "user", + "content": "你好" + }, + { + "role": "assistant", + "content": "你好,我是 ChatGLM" + } + ], + "stream": True +} + + + +def test_chat_fastchat(api="/chat/fastchat"): + url = f"{api_base_url}{api}" + data2 = { + "stream": True, + "messages": data["history"] + [{"role": "user", "content": "推荐一部科幻电影"}] + } + dump_input(data2, api) + response = requests.post(url, headers=headers, json=data2, stream=True) + dump_output(response, api) + assert response.status_code == 200 + + +def test_chat_chat(api="/chat/chat"): + url = f"{api_base_url}{api}" + dump_input(data, api) + response = requests.post(url, headers=headers, json=data, stream=True) + dump_output(response, api) + assert response.status_code == 200 + + +def test_knowledge_chat(api="/chat/knowledge_base_chat"): + url = f"{api_base_url}{api}" + data = { + "query": "如何提问以获得高质量答案", + "knowledge_base_name": "samples", + "history": [ + { + "role": "user", + "content": "你好" + }, + { + "role": "assistant", + "content": "你好,我是 ChatGLM" + } + ], + "stream": True + } + dump_input(data, api) + response = requests.post(url, headers=headers, json=data, stream=True) + print("\n") + print("=" * 30 + api + " output" + "="*30) + first = True + for line in response.iter_content(None, decode_unicode=True): + data = json.loads(line) + if first: + for doc in data["docs"]: + print(doc) + first = False + print(data["answer"], end="", flush=True) + assert response.status_code == 200 + + +def test_search_engine_chat(api="/chat/search_engine_chat"): + url = f"{api_base_url}{api}" + for se in ["bing", "duckduckgo"]: + dump_input(data, api) + response = requests.post(url, json=data, stream=True) + dump_output(response, api) + assert response.status_code == 200 From d694652b874d7a7dd4f89de04a294c8b4ee93755 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Sun, 20 Aug 2023 10:40:31 +0800 Subject: [PATCH 08/13] update VERSION --- configs/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/__init__.py b/configs/__init__.py index 6e0ad134..dc9dd40a 100644 --- a/configs/__init__.py +++ b/configs/__init__.py @@ -1,4 +1,4 @@ from .model_config import * from .server_config import * -VERSION = "v0.2.1" +VERSION = "v0.2.2-preview" From 150a78bfd9fc47457ded3f161475c838a1d3d6ca Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Sun, 20 Aug 2023 16:52:49 +0800 Subject: [PATCH 09/13] update kb_doc_api:make faiss cache working; delete vector store docs before add duplicate docs --- server/knowledge_base/kb_service/base.py | 1 + .../kb_service/faiss_kb_service.py | 58 ++++++++++++------- tests/api/test_kb_api.py | 6 +- 3 files changed, 43 insertions(+), 22 deletions(-) diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index d506f633..09766e68 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -77,6 +77,7 @@ class KBService(ABC): """ docs = kb_file.file2text() if docs: + self.delete_doc(kb_file) embeddings = self._load_embeddings() self.do_add_doc(docs, embeddings) status = add_doc_to_db(kb_file) diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 5c8376fd..5953c3a7 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -41,7 +41,23 @@ def load_vector_store( vs_path = get_vs_path(knowledge_base_name) if embeddings is None: embeddings = load_embeddings(embed_model, embed_device) - search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True) + + if not os.path.exists(vs_path): + os.makedirs(vs_path) + + if "index.faiss" in os.listdir(vs_path): + search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True) + else: + # create an empty vector store + doc = Document(page_content="init", metadata={}) + search_index = FAISS.from_documents([doc], embeddings, normalize_L2=True) + ids = [k for k, v in search_index.docstore._dict.items()] + search_index.delete(ids) + search_index.save_local(vs_path) + + if tick == 0: # vector store is loaded first time + _VECTOR_STORE_TICKS[knowledge_base_name] = 0 + return search_index @@ -74,8 +90,10 @@ class FaissKBService(KBService): def do_create_kb(self): if not os.path.exists(self.vs_path): os.makedirs(self.vs_path) + load_vector_store(self.kb_name) def do_drop_kb(self): + self.clear_vs() shutil.rmtree(self.kb_path) def do_search(self, @@ -94,37 +112,35 @@ class FaissKBService(KBService): docs: List[Document], embeddings: Embeddings, ): - if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path): - vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True) - vector_store.add_documents(docs) - torch_gc() - else: - if not os.path.exists(self.vs_path): - os.makedirs(self.vs_path) - vector_store = FAISS.from_documents( - docs, embeddings, normalize_L2=True) # docs 为Document列表 - torch_gc() + vector_store = load_vector_store(self.kb_name, + embeddings=embeddings, + tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0)) + vector_store.add_documents(docs) + torch_gc() vector_store.save_local(self.vs_path) refresh_vs_cache(self.kb_name) def do_delete_doc(self, kb_file: KnowledgeFile): embeddings = self._load_embeddings() - if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path): - vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True) - ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] - if len(ids) == 0: - return None - vector_store.delete(ids) - vector_store.save_local(self.vs_path) - refresh_vs_cache(self.kb_name) - return True - else: + vector_store = load_vector_store(self.kb_name, + embeddings=embeddings, + tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0)) + + ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] + if len(ids) == 0: return None + vector_store.delete(ids) + vector_store.save_local(self.vs_path) + refresh_vs_cache(self.kb_name) + + return True + def do_clear_vs(self): shutil.rmtree(self.vs_path) os.makedirs(self.vs_path) + refresh_vs_cache(self.kb_name) def exist_doc(self, file_name: str): if super().exist_doc(file_name): diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py index c09d57f0..5a8b97d2 100644 --- a/tests/api/test_kb_api.py +++ b/tests/api/test_kb_api.py @@ -8,6 +8,7 @@ root_path = Path(__file__).parent.parent.parent sys.path.append(str(root_path)) from configs.server_config import api_address from configs.model_config import VECTOR_SEARCH_TOP_K +from server.knowledge_base.utils import get_kb_path from pprint import pprint @@ -22,8 +23,11 @@ test_files = { def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"): + if not Path(get_kb_path(kb)).exists(): + return + url = api_base_url + api - print("\n删除知识库") + print("\n测试知识库存在,需要删除") r = requests.post(url, json=kb) data = r.json() pprint(data) From f0bcb3105a1829b6d7c33545722a82de067c07b5 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Sun, 20 Aug 2023 16:52:56 +0800 Subject: [PATCH 10/13] update langchain version requirements --- requirements.txt | 2 +- requirements_api.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 646a5c71..93908dde 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -langchain==0.0.257 +langchain==0.0.266 openai sentence_transformers fschat==0.2.24 diff --git a/requirements_api.txt b/requirements_api.txt index 1e13587c..f567f9f7 100644 --- a/requirements_api.txt +++ b/requirements_api.txt @@ -1,4 +1,4 @@ -langchain==0.0.257 +langchain==0.0.266 openai sentence_transformers fschat==0.2.24 From adbee9f77770a88238b73633cd2695a698cc4ad3 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Sun, 20 Aug 2023 17:29:50 +0800 Subject: [PATCH 11/13] fix add_argument fault in startup.py --- startup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/startup.py b/startup.py index 88bdd1d4..519f4fc9 100644 --- a/startup.py +++ b/startup.py @@ -285,7 +285,7 @@ def parse_args() -> argparse.ArgumentParser: dest="model_worker", ) parser.add_argument( - "-n" + "-n", "--model-name", type=str, default=LLM_MODEL, @@ -293,7 +293,7 @@ def parse_args() -> argparse.ArgumentParser: dest="model_name", ) parser.add_argument( - "-c" + "-c", "--controller", type=str, help="specify controller address the worker is registered to. default is server_config.FSCHAT_CONTROLLER", From f40bb69224df37db171a6dbd05afba11b30a353d Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Sun, 20 Aug 2023 19:10:29 +0800 Subject: [PATCH 12/13] =?UTF-8?q?=E4=BC=98=E5=8C=96FAISS=E5=90=91=E9=87=8F?= =?UTF-8?q?=E5=BA=93=E5=A4=9A=E6=96=87=E4=BB=B6=E6=93=8D=E4=BD=9C=EF=BC=9B?= =?UTF-8?q?=E4=BF=AE=E5=A4=8Drecreate=5Fvector=5Fstore=EF=BC=8C=E5=A4=A7?= =?UTF-8?q?=E9=87=8F=E6=96=87=E4=BB=B6=E6=97=B6=E4=B8=8D=E5=86=8D=E8=B6=85?= =?UTF-8?q?=E6=97=B6=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/knowledge_base/kb_doc_api.py | 11 ++++++----- server/knowledge_base/kb_service/base.py | 14 +++++++------- .../kb_service/faiss_kb_service.py | 15 ++++++++++----- webui_pages/knowledge_base/knowledge_base.py | 6 ++++-- webui_pages/utils.py | 18 +++++++++++++++--- 5 files changed, 42 insertions(+), 22 deletions(-) diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 0d74fd6c..74edc98b 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -47,6 +47,7 @@ async def list_docs( async def upload_doc(file: UploadFile = File(..., description="上传文件"), knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]), override: bool = Form(False, description="覆盖已有文件"), + not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"), ) -> BaseResponse: if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -76,7 +77,7 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"), return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}") try: - kb.add_doc(kb_file) + kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) except Exception as e: print(e) return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}") @@ -87,6 +88,7 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"), async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), doc_name: str = Body(..., examples=["file_name.md"]), delete_content: bool = Body(False), + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), ) -> BaseResponse: if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -102,7 +104,7 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), try: kb_file = KnowledgeFile(filename=doc_name, knowledge_base_name=knowledge_base_name) - kb.delete_doc(kb_file, delete_content) + kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache) except Exception as e: print(e) return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}") @@ -113,6 +115,7 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), async def update_doc( knowledge_base_name: str = Body(..., examples=["samples"]), file_name: str = Body(..., examples=["file_name"]), + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), ) -> BaseResponse: ''' 更新知识库文档 @@ -128,7 +131,7 @@ async def update_doc( kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name) if os.path.exists(kb_file.filepath): - kb.update_doc(kb_file) + kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}") except Exception as e: print(e) @@ -205,7 +208,5 @@ async def recreate_vector_store( "code": 500, "msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。", }) - import asyncio - await asyncio.sleep(5) return StreamingResponse(output(), media_type="text/event-stream") diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 09766e68..9af5b0e3 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -71,7 +71,7 @@ class KBService(ABC): status = delete_kb_from_db(self.kb_name) return status - def add_doc(self, kb_file: KnowledgeFile): + def add_doc(self, kb_file: KnowledgeFile, **kwargs): """ 向知识库添加文件 """ @@ -79,29 +79,29 @@ class KBService(ABC): if docs: self.delete_doc(kb_file) embeddings = self._load_embeddings() - self.do_add_doc(docs, embeddings) + self.do_add_doc(docs, embeddings, **kwargs) status = add_doc_to_db(kb_file) else: status = False return status - def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False): + def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs): """ 从知识库删除文件 """ - self.do_delete_doc(kb_file) + self.do_delete_doc(kb_file, **kwargs) status = delete_file_from_db(kb_file) if delete_content and os.path.exists(kb_file.filepath): os.remove(kb_file.filepath) return status - def update_doc(self, kb_file: KnowledgeFile): + def update_doc(self, kb_file: KnowledgeFile, **kwargs): """ 使用content中的文件更新向量库 """ if os.path.exists(kb_file.filepath): - self.delete_doc(kb_file) - return self.add_doc(kb_file) + self.delete_doc(kb_file, **kwargs) + return self.add_doc(kb_file, **kwargs) def exist_doc(self, file_name: str): return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name, diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 5953c3a7..9fccfa23 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -66,6 +66,7 @@ def refresh_vs_cache(kb_name: str): make vector store cache refreshed when next loading """ _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1 + print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}") class FaissKBService(KBService): @@ -111,17 +112,20 @@ class FaissKBService(KBService): def do_add_doc(self, docs: List[Document], embeddings: Embeddings, + **kwargs, ): vector_store = load_vector_store(self.kb_name, embeddings=embeddings, tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0)) vector_store.add_documents(docs) torch_gc() - vector_store.save_local(self.vs_path) - refresh_vs_cache(self.kb_name) + if not kwargs.get("not_refresh_vs_cache"): + vector_store.save_local(self.vs_path) + refresh_vs_cache(self.kb_name) def do_delete_doc(self, - kb_file: KnowledgeFile): + kb_file: KnowledgeFile, + **kwargs): embeddings = self._load_embeddings() vector_store = load_vector_store(self.kb_name, embeddings=embeddings, @@ -132,8 +136,9 @@ class FaissKBService(KBService): return None vector_store.delete(ids) - vector_store.save_local(self.vs_path) - refresh_vs_cache(self.kb_name) + if not kwargs.get("not_refresh_vs_cache"): + vector_store.save_local(self.vs_path) + refresh_vs_cache(self.kb_name) return True diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index 3bd531a9..4351e956 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -138,8 +138,10 @@ def knowledge_base_page(api: ApiRequest): # use_container_width=True, disabled=len(files) == 0, ): - for f in files: - ret = api.upload_kb_doc(f, kb) + data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files] + data[-1]["not_refresh_vs_cache"]=False + for k in data: + ret = api.upload_kb_doc(**k) if msg := check_success_msg(ret): st.toast(msg, icon="✔") elif msg := check_error_msg(ret): diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 18a24e4e..c666d458 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -496,6 +496,7 @@ class ApiRequest: knowledge_base_name: str, filename: str = None, override: bool = False, + not_refresh_vs_cache: bool = False, no_remote_api: bool = None, ): ''' @@ -529,7 +530,11 @@ class ApiRequest: else: response = self.post( "/knowledge_base/upload_doc", - data={"knowledge_base_name": knowledge_base_name, "override": override}, + data={ + "knowledge_base_name": knowledge_base_name, + "override": override, + "not_refresh_vs_cache": not_refresh_vs_cache, + }, files={"file": (filename, file)}, ) return self._check_httpx_json_response(response) @@ -539,6 +544,7 @@ class ApiRequest: knowledge_base_name: str, doc_name: str, delete_content: bool = False, + not_refresh_vs_cache: bool = False, no_remote_api: bool = None, ): ''' @@ -551,6 +557,7 @@ class ApiRequest: "knowledge_base_name": knowledge_base_name, "doc_name": doc_name, "delete_content": delete_content, + "not_refresh_vs_cache": not_refresh_vs_cache, } if no_remote_api: @@ -568,6 +575,7 @@ class ApiRequest: self, knowledge_base_name: str, file_name: str, + not_refresh_vs_cache: bool = False, no_remote_api: bool = None, ): ''' @@ -583,7 +591,11 @@ class ApiRequest: else: response = self.post( "/knowledge_base/update_doc", - json={"knowledge_base_name": knowledge_base_name, "file_name": file_name}, + json={ + "knowledge_base_name": knowledge_base_name, + "file_name": file_name, + "not_refresh_vs_cache": not_refresh_vs_cache, + }, ) return self._check_httpx_json_response(response) @@ -617,7 +629,7 @@ class ApiRequest: "/knowledge_base/recreate_vector_store", json=data, stream=True, - timeout=False, + timeout=None, ) return self._httpx_stream2generator(response, as_json=True) From c571585ffd7216c78cfd34395d07c1f12f95ac44 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Mon, 21 Aug 2023 08:50:15 +0800 Subject: [PATCH 13/13] optimize recreate vector store: save vector store once after all docs parsed for FAISS --- init_database.py | 4 ++++ server/knowledge_base/kb_doc_api.py | 6 +++++- server/knowledge_base/migrate.py | 18 +++++++++++++++--- startup.py | 9 +++++---- 4 files changed, 29 insertions(+), 8 deletions(-) diff --git a/init_database.py b/init_database.py index 61d00e1f..7fc84940 100644 --- a/init_database.py +++ b/init_database.py @@ -2,6 +2,8 @@ from server.knowledge_base.migrate import create_tables, folder2db, recreate_all from configs.model_config import NLTK_DATA_PATH import nltk nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path +from startup import dump_server_info + if __name__ == "__main__": import argparse @@ -21,6 +23,8 @@ if __name__ == "__main__": ) args = parser.parse_args() + dump_server_info() + create_tables() print("database talbes created") diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 74edc98b..ae027c12 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -201,7 +201,11 @@ async def recreate_vector_store( "finished": i, "doc": doc, }, ensure_ascii=False) - kb.add_doc(kb_file) + if i == len(docs) - 1: + not_refresh_vs_cache = False + else: + not_refresh_vs_cache = True + kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) except Exception as e: print(e) yield json.dumps({ diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index 1c023fa7..c96d3867 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -43,7 +43,11 @@ def folder2db( kb_file = KnowledgeFile(doc, kb_name) if callable(callback_before): callback_before(kb_file, i, docs) - kb.add_doc(kb_file) + if i == len(docs) - 1: + not_refresh_vs_cache = False + else: + not_refresh_vs_cache = True + kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) if callable(callback_after): callback_after(kb_file, i, docs) except Exception as e: @@ -67,7 +71,11 @@ def folder2db( kb_file = KnowledgeFile(doc, kb_name) if callable(callback_before): callback_before(kb_file, i, docs) - kb.update_doc(kb_file) + if i == len(docs) - 1: + not_refresh_vs_cache = False + else: + not_refresh_vs_cache = True + kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) if callable(callback_after): callback_after(kb_file, i, docs) except Exception as e: @@ -81,7 +89,11 @@ def folder2db( kb_file = KnowledgeFile(doc, kb_name) if callable(callback_before): callback_before(kb_file, i, docs) - kb.add_doc(kb_file) + if i == len(docs) - 1: + not_refresh_vs_cache = False + else: + not_refresh_vs_cache = True + kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) if callable(callback_after): callback_after(kb_file, i, docs) except Exception as e: diff --git a/startup.py b/startup.py index 519f4fc9..9207307c 100644 --- a/startup.py +++ b/startup.py @@ -317,6 +317,11 @@ def parse_args() -> argparse.ArgumentParser: def dump_server_info(after_start=False): + import platform + import langchain + import fastchat + from configs.server_config import api_address, webui_address + print("\n\n") print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30) print(f"操作系统:{platform.platform()}.") @@ -342,11 +347,7 @@ def dump_server_info(after_start=False): if __name__ == "__main__": - import platform import time - import langchain - import fastchat - from configs.server_config import api_address, webui_address mp.set_start_method("spawn") queue = Queue()