diff --git a/README.md b/README.md index 2883636e..67aa6ef8 100644 --- a/README.md +++ b/README.md @@ -148,7 +148,7 @@ $ python startup.py -a [![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white "langchain-chatglm")](https://t.me/+RjliQ3jnJ1YyN2E9) ### 项目交流群 -二维码 +二维码 🎉 Langchain-Chatchat 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。 diff --git a/configs/__init__.py b/configs/__init__.py index a4bf7665..0c862507 100644 --- a/configs/__init__.py +++ b/configs/__init__.py @@ -5,4 +5,4 @@ from .server_config import * from .prompt_config import * -VERSION = "v0.2.8-preview" +VERSION = "v0.2.9-preview" diff --git a/configs/kb_config.py.example b/configs/kb_config.py.example index 9c727b77..c9ff4318 100644 --- a/configs/kb_config.py.example +++ b/configs/kb_config.py.example @@ -106,7 +106,7 @@ kbs_config = { # TextSplitter配置项,如果你不明白其中的含义,就不要修改。 text_splitter_dict = { "ChineseRecursiveTextSplitter": { - "source": "huggingface", ## 选择tiktoken则使用openai的方法 + "source": "huggingface", # 选择tiktoken则使用openai的方法 "tokenizer_name_or_path": "", }, "SpacyTextSplitter": { diff --git a/configs/model_config.py.example b/configs/model_config.py.example index e08b5a00..19084142 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -15,9 +15,11 @@ EMBEDDING_DEVICE = "auto" EMBEDDING_KEYWORD_FILE = "keywords.txt" EMBEDDING_MODEL_OUTPUT_PATH = "output" -# 要运行的 LLM 名称,可以包括本地模型和在线模型。 -# 第一个将作为 API 和 WEBUI 的默认模型 -LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"] +# 要运行的 LLM 名称,可以包括本地模型和在线模型。列表中本地模型将在启动项目时全部加载。 +# 列表中第一个模型将作为 API 和 WEBUI 的默认模型。 +# 在这里,我们使用目前主流的两个离线模型,其中,chatglm3-6b 为默认加载模型。 +# 如果你的显存不足,可使用 Qwen-1_8B-Chat, 该模型 FP16 仅需 3.8G显存。 +LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"] # "Qwen-1_8B-Chat", # AgentLM模型的名称 (可以不指定,指定之后就锁定进入Agent之后的Chain的模型,不指定就是LLM_MODELS[0]) Agent_MODEL = None @@ -112,10 +114,10 @@ ONLINE_LLM_MODEL = { "api_key": "", "provider": "AzureWorker", }, - + # 昆仑万维天工 API https://model-platform.tiangong.cn/ "tiangong-api": { - "version":"SkyChat-MegaVerse", + "version": "SkyChat-MegaVerse", "api_key": "", "secret_key": "", "provider": "TianGongWorker", @@ -163,6 +165,25 @@ MODEL_PATH = { "chatglm3-6b": "THUDM/chatglm3-6b", "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k", + "chatglm3-6b-base": "THUDM/chatglm3-6b-base", + + "Qwen-1_8B": "Qwen/Qwen-1_8B", + "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat", + "Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8", + "Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4", + + "Qwen-7B": "Qwen/Qwen-7B", + "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", + + "Qwen-14B": "Qwen/Qwen-14B", + "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat", + "Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8", + "Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4", + + "Qwen-72B": "Qwen/Qwen-72B", + "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat", + "Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8", + "Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4", "baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat", "baichuan2-7b": "baichuan-inc/Baichuan2-7B-Chat", @@ -204,18 +225,11 @@ MODEL_PATH = { "opt-66b": "facebook/opt-66b", "opt-iml-max-30b": "facebook/opt-iml-max-30b", - "Qwen-7B": "Qwen/Qwen-7B", - "Qwen-14B": "Qwen/Qwen-14B", - "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", - "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat", - "Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8", # 确保已经安装了auto-gptq optimum flash-attn - "Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4", # 确保已经安装了auto-gptq optimum flash-attn - "agentlm-7b": "THUDM/agentlm-7b", "agentlm-13b": "THUDM/agentlm-13b", 
"agentlm-70b": "THUDM/agentlm-70b", - "Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat", # 更多01-ai模型尚未进行测试。如果需要使用,请自行测试。 + "Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat", }, } @@ -242,11 +256,11 @@ VLLM_MODEL_DICT = { "BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k", # 注意:bloom系列的tokenizer与model是分离的,因此虽然vllm支持,但与fschat框架不兼容 - # "bloom":"bigscience/bloom", - # "bloomz":"bigscience/bloomz", - # "bloomz-560m":"bigscience/bloomz-560m", - # "bloomz-7b1":"bigscience/bloomz-7b1", - # "bloomz-1b7":"bigscience/bloomz-1b7", + # "bloom": "bigscience/bloom", + # "bloomz": "bigscience/bloomz", + # "bloomz-560m": "bigscience/bloomz-560m", + # "bloomz-7b1": "bigscience/bloomz-7b1", + # "bloomz-1b7": "bigscience/bloomz-1b7", "internlm-7b": "internlm/internlm-7b", "internlm-chat-7b": "internlm/internlm-chat-7b", @@ -273,10 +287,23 @@ VLLM_MODEL_DICT = { "opt-66b": "facebook/opt-66b", "opt-iml-max-30b": "facebook/opt-iml-max-30b", + "Qwen-1_8B": "Qwen/Qwen-1_8B", + "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat", + "Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8", + "Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4", + "Qwen-7B": "Qwen/Qwen-7B", - "Qwen-14B": "Qwen/Qwen-14B", "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", + + "Qwen-14B": "Qwen/Qwen-14B", "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat", + "Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8", + "Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4", + + "Qwen-72B": "Qwen/Qwen-72B", + "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat", + "Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8", + "Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4", "agentlm-7b": "THUDM/agentlm-7b", "agentlm-13b": "THUDM/agentlm-13b", diff --git a/configs/prompt_config.py.example b/configs/prompt_config.py.example index dd0dc2a0..6fb6996c 100644 --- a/configs/prompt_config.py.example +++ b/configs/prompt_config.py.example @@ -1,7 +1,6 @@ # prompt模板使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号 # 本配置文件支持热加载,修改prompt模板后无需重启服务。 - # LLM对话支持的变量: # - input: 用户输入内容 @@ -17,125 +16,112 @@ # - input: 用户输入内容 # - agent_scratchpad: Agent的思维记录 -PROMPT_TEMPLATES = {} +PROMPT_TEMPLATES = { + "llm_chat": { + "default": + '{{ input }}', -PROMPT_TEMPLATES["llm_chat"] = { -"default": "{{ input }}", -"with_history": -"""The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. + "with_history": + 'The following is a friendly conversation between a human and an AI. ' + 'The AI is talkative and provides lots of specific details from its context. 
' + 'If the AI does not know the answer to a question, it truthfully says it does not know.\n\n' + 'Current conversation:\n' + '{history}\n' + 'Human: {input}\n' + 'AI:', -Current conversation: -{history} -Human: {input} -AI:""", -"py": -""" -你是一个聪明的代码助手,请你给我写出简单的py代码。 \n -{{ input }} -""", -} - -PROMPT_TEMPLATES["knowledge_base_chat"] = { -"default": -""" -<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 -<已知信息>{{ context }}、 -<问题>{{ question }} -""", -"text": -""" -<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 -<已知信息>{{ context }}、 -<问题>{{ question }} -""", -"Empty": # 搜不到知识库的时候使用 -""" -请你回答我的问题: -{{ question }} -\n -""", -} -PROMPT_TEMPLATES["search_engine_chat"] = { -"default": -""" -<指令>这是我搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 -<已知信息>{{ context }} -<问题>{{ question }} -""", -"search": -""" -<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 -<已知信息>{{ context }}、 -<问题>{{ question }} -""", -} -PROMPT_TEMPLATES["agent_chat"] = { -"default": -""" -Answer the following questions as best you can. If it is in order, you can use some tools appropriately.You have access to the following tools: - -{tools} - -Please note that the "知识库查询工具" is information about the "西交利物浦大学" ,and if a question is asked about it, you must answer with the knowledge base, -Please note that the "天气查询工具" can only be used once since Question begin. - -Use the following format: -Question: the input question you must answer1 -Thought: you should always think about what to do and what tools to use. -Action: the action to take, should be one of [{tool_names}] -Action Input: the input to the action -Observation: the result of the action -... (this Thought/Action/Action Input/Observation can be repeated zero or more times) -Thought: I now know the final answer -Final Answer: the final answer to the original input question -Begin! - -history: {history} - -Question: {input} - -Thought: {agent_scratchpad} -""", - -"ChatGLM3": -""" -You can answer using the tools, or answer directly using your knowledge without using the tools.Respond to the human as helpfully and accurately as possible. -You have access to the following tools: -{tools} -Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). -Valid "action" values: "Final Answer" or [{tool_names}] -Provide only ONE action per $JSON_BLOB, as shown: - -``` -{{{{ - "action": $TOOL_NAME, - "action_input": $INPUT -}}}} -``` - -Follow this format: - -Question: input question to answer -Thought: consider previous and subsequent steps -Action: -``` -$JSON_BLOB -``` -Observation: action result -... (repeat Thought/Action/Observation N times) -Thought: I know what to respond -Action: -``` -{{{{ - "action": "Final Answer", - "action_input": "Final response to human" -}}}} -Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. 
- -history: {history} - -Question: {input} - -Thought: {agent_scratchpad} -""", + "py": + '你是一个聪明的代码助手,请你给我写出简单的py代码。 \n' + '{{ input }}', + }, + + + "knowledge_base_chat": { + "default": + '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,' + '不允许在答案中添加编造成分,答案请使用中文。 \n' + '<已知信息>{{ context }}\n' + '<问题>{{ question }}\n', + + "text": + '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 \n' + '<已知信息>{{ context }}\n' + '<问题>{{ question }}\n', + + "empty": # 搜不到知识库的时候使用 + '请你回答我的问题:\n' + '{{ question }}\n\n', + }, + + + "search_engine_chat": { + "default": + '<指令>这是我搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。' + '如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 \n' + '<已知信息>{{ context }}\n' + '<问题>{{ question }}\n', + + "search": + '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 \n' + '<已知信息>{{ context }}\n' + '<问题>{{ question }}\n', + }, + + + "agent_chat": { + "default": + 'Answer the following questions as best you can. If it is in order, you can use some tools appropriately. ' + 'You have access to the following tools:\n\n' + '{tools}\n\n' + 'Use the following format:\n' + 'Question: the input question you must answer1\n' + 'Thought: you should always think about what to do and what tools to use.\n' + 'Action: the action to take, should be one of [{tool_names}]\n' + 'Action Input: the input to the action\n' + 'Observation: the result of the action\n' + '... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n' + 'Thought: I now know the final answer\n' + 'Final Answer: the final answer to the original input question\n' + 'Begin!\n\n' + 'history: {history}\n\n' + 'Question: {input}\n\n' + 'Thought: {agent_scratchpad}\n', + + "ChatGLM3": + 'You can answer using the tools, or answer directly using your knowledge without using the tools. ' + 'Respond to the human as helpfully and accurately as possible.\n' + 'You have access to the following tools:\n' + '{tools}\n' + 'Use a json blob to specify a tool by providing an action key (tool name) ' + 'and an action_input key (tool input).\n' + 'Valid "action" values: "Final Answer" or [{tool_names}]' + 'Provide only ONE action per $JSON_BLOB, as shown:\n\n' + '```\n' + '{{{{\n' + ' "action": $TOOL_NAME,\n' + ' "action_input": $INPUT\n' + '}}}}\n' + '```\n\n' + 'Follow this format:\n\n' + 'Question: input question to answer\n' + 'Thought: consider previous and subsequent steps\n' + 'Action:\n' + '```\n' + '$JSON_BLOB\n' + '```\n' + 'Observation: action result\n' + '... (repeat Thought/Action/Observation N times)\n' + 'Thought: I know what to respond\n' + 'Action:\n' + '```\n' + '{{{{\n' + ' "action": "Final Answer",\n' + ' "action_input": "Final response to human"\n' + '}}}}\n' + 'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. ' + 'Respond directly if appropriate. 
Format is Action:```$JSON_BLOB```then Observation:.\n' + 'history: {history}\n\n' + 'Question: {input}\n\n' + 'Thought: {agent_scratchpad}', + } } diff --git a/configs/server_config.py.example b/configs/server_config.py.example index 2e27c400..7fa0c412 100644 --- a/configs/server_config.py.example +++ b/configs/server_config.py.example @@ -66,7 +66,7 @@ FSCHAT_MODEL_WORKERS = { # "no_register": False, # "embed_in_truncate": False, - # 以下为vllm_woker配置参数,注意使用vllm必须有gpu,仅在Linux测试通过 + # 以下为vllm_worker配置参数,注意使用vllm必须有gpu,仅在Linux测试通过 # tokenizer = model_path # 如果tokenizer与model_path不一致在此处添加 # 'tokenizer_mode':'auto', @@ -93,14 +93,14 @@ FSCHAT_MODEL_WORKERS = { }, # 可以如下示例方式更改默认配置 - # "Qwen-7B-Chat": { # 使用default中的IP和端口 + # "Qwen-1_8B-Chat": { # 使用default中的IP和端口 # "device": "cpu", # }, - "chatglm3-6b": { # 使用default中的IP和端口 + "chatglm3-6b": { # 使用default中的IP和端口 "device": "cuda", }, - #以下配置可以不用修改,在model_config中设置启动的模型 + # 以下配置可以不用修改,在model_config中设置启动的模型 "zhipu-api": { "port": 21001, }, diff --git a/document_loaders/myimgloader.py b/document_loaders/myimgloader.py index 86481924..e09c6172 100644 --- a/document_loaders/myimgloader.py +++ b/document_loaders/myimgloader.py @@ -1,13 +1,13 @@ from typing import List from langchain.document_loaders.unstructured import UnstructuredFileLoader +from document_loaders.ocr import get_ocr class RapidOCRLoader(UnstructuredFileLoader): def _get_elements(self) -> List: def img2text(filepath): - from rapidocr_onnxruntime import RapidOCR resp = "" - ocr = RapidOCR() + ocr = get_ocr() result, _ = ocr(filepath) if result: ocr_result = [line[1] for line in result] diff --git a/document_loaders/mypdfloader.py b/document_loaders/mypdfloader.py index 6cb77267..51778b89 100644 --- a/document_loaders/mypdfloader.py +++ b/document_loaders/mypdfloader.py @@ -1,5 +1,6 @@ from typing import List from langchain.document_loaders.unstructured import UnstructuredFileLoader +from document_loaders.ocr import get_ocr import tqdm @@ -7,9 +8,8 @@ class RapidOCRPDFLoader(UnstructuredFileLoader): def _get_elements(self) -> List: def pdf2text(filepath): import fitz # pyMuPDF里面的fitz包,不要与pip install fitz混淆 - from rapidocr_onnxruntime import RapidOCR import numpy as np - ocr = RapidOCR() + ocr = get_ocr() doc = fitz.open(filepath) resp = "" diff --git a/document_loaders/ocr.py b/document_loaders/ocr.py new file mode 100644 index 00000000..2b66dd35 --- /dev/null +++ b/document_loaders/ocr.py @@ -0,0 +1,18 @@ +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + try: + from rapidocr_paddle import RapidOCR + except ImportError: + from rapidocr_onnxruntime import RapidOCR + + +def get_ocr(use_cuda: bool = True) -> "RapidOCR": + try: + from rapidocr_paddle import RapidOCR + ocr = RapidOCR(det_use_cuda=use_cuda, cls_use_cuda=use_cuda, rec_use_cuda=use_cuda) + except ImportError: + from rapidocr_onnxruntime import RapidOCR + ocr = RapidOCR() + return ocr diff --git a/img/qr_code_71.jpg b/img/qr_code_71.jpg deleted file mode 100644 index 78828f7b..00000000 Binary files a/img/qr_code_71.jpg and /dev/null differ diff --git a/img/qr_code_72.jpg b/img/qr_code_72.jpg deleted file mode 100644 index 10a504b0..00000000 Binary files a/img/qr_code_72.jpg and /dev/null differ diff --git a/img/qr_code_73.jpg b/img/qr_code_73.jpg deleted file mode 100644 index 3e5ec45a..00000000 Binary files a/img/qr_code_73.jpg and /dev/null differ diff --git a/img/qr_code_74.jpg b/img/qr_code_74.jpg deleted file mode 100644 index 1f665f65..00000000 Binary files a/img/qr_code_74.jpg and /dev/null differ diff --git 
a/img/qr_code_76.jpg b/img/qr_code_76.jpg new file mode 100644 index 00000000..d952dc19 Binary files /dev/null and b/img/qr_code_76.jpg differ diff --git a/knowledge_base/samples/content/wiki b/knowledge_base/samples/content/wiki index f789e5dd..9a3fa7a7 160000 --- a/knowledge_base/samples/content/wiki +++ b/knowledge_base/samples/content/wiki @@ -1 +1 @@ -Subproject commit f789e5dde10f91136012f3470c020c8d34572436 +Subproject commit 9a3fa7a77f8748748b1c656fe8919ad5c4c63e3f diff --git a/requirements.txt b/requirements.txt index 036489e8..b5cd996d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,63 +1,75 @@ # API requirements -langchain>=0.0.334 +langchain==0.0.344 langchain-experimental>=0.0.42 -fschat[model_worker]>=0.2.33 +pydantic==1.10.13 +fschat>=0.2.33 xformers>=0.0.22.post7 -openai~=0.28.1 +openai>=1.3.7 sentence_transformers transformers>=4.35.2 -torch==2.1.0 ##on win, install the cuda version manually if you want use gpu -torchvision #on win, install the cuda version manually if you want use gpu -torchaudio #on win, install the cuda version manually if you want use gpu +torch==2.1.0 ##on Windows system, install the cuda version manually from https://pytorch.org/ +torchvision #on Windows system, install the cuda version manually from https://pytorch.org/ +torchaudio #on Windows system, install the cuda version manually from https://pytorch.org/ fastapi>=0.104 nltk>=3.8.1 -uvicorn~=0.23.1 +uvicorn>=0.24.0.post1 starlette~=0.27.0 -pydantic<2 -unstructured[all-docs]>=0.11.0 +unstructured[all-docs]==0.11.0 python-magic-bin; sys_platform == 'win32' SQLAlchemy==2.0.19 -faiss-cpu -accelerate -spacy +faiss-cpu # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus +accelerate>=0.24.1 +spacy>=3.7.2 PyMuPDF rapidocr_onnxruntime - -requests -pathlib -pytest -numexpr -strsimpy -markdownify -tiktoken -tqdm +requests>=2.31.0 +pathlib>=1.0.1 +pytest>=7.4.3 +numexpr>=2.8.7 +strsimpy>=0.2.1 +markdownify>=0.11.6 +tiktoken>=0.5.1 +tqdm>=4.66.1 websockets numpy~=1.24.4 pandas~=2.0.3 -einops +einops>=0.7.0 transformers_stream_generator==0.0.4 - vllm==0.2.2; sys_platform == "linux" -# online api libs dependencies +# optional document loaders +# rapidocr_paddle[gpu] # gpu accelleration for ocr of pdf and image files +# jq # for .json and .jsonl files. 
suggest `conda install jq` on windows +# html2text # for .enex files +# beautifulsoup4 # for .mhtml files +# pysrt # for .srt files + +# Online api libs dependencies # zhipuai>=1.0.7 # dashscope>=1.10.0 # qianfan>=0.2.0 # volcengine>=1.0.106 - -# uncomment libs if you want to use corresponding vector store -# pymilvus==2.1.3 # requires milvus==2.1.3 +# pymilvus>=2.3.3 # psycopg2 -# pgvector +# pgvector>=0.2.4 + +# Agent and Search Tools + +arxiv>=2.0.0 +youtube-search>=2.1.2 +duckduckgo-search>=3.9.9 +metaphor-python>=0.1.23 # WebUI requirements -streamlit~=1.28.2 # # on win, make sure write its path in environment variable +streamlit>=1.29.0 # do remember to add streamlit to environment variables if you use windows streamlit-option-menu>=0.3.6 streamlit-antd-components>=0.2.3 streamlit-chatbox>=1.1.11 +streamlit-modal>=0.1.0 streamlit-aggrid>=0.3.4.post3 -httpx[brotli,http2,socks]~=0.24.1 -watchdog +httpx[brotli,http2,socks]>=0.25.2 +watchdog>=3.0.0 + diff --git a/requirements_api.txt b/requirements_api.txt index 85b23072..ec1005f9 100644 --- a/requirements_api.txt +++ b/requirements_api.txt @@ -1,52 +1,65 @@ # API requirements -langchain>=0.0.334 +langchain==0.0.344 langchain-experimental>=0.0.42 -fschat[model_worker]>=0.2.33 +pydantic==1.10.13 +fschat>=0.2.33 xformers>=0.0.22.post7 -openai~=0.28.1 +openai>=1.3.7 sentence_transformers transformers>=4.35.2 -torch==2.1.0 -torchvision -torchaudio +torch==2.1.0 ##on Windows system, install the cuda version manually from https://pytorch.org/ +torchvision #on Windows system, install the cuda version manually from https://pytorch.org/ +torchaudio #on Windows system, install the cuda version manually from https://pytorch.org/ fastapi>=0.104 nltk>=3.8.1 -uvicorn~=0.23.1 +uvicorn>=0.24.0.post1 starlette~=0.27.0 -pydantic<2 -unstructured[all-docs]>=0.11.0 +unstructured[all-docs]==0.11.0 python-magic-bin; sys_platform == 'win32' SQLAlchemy==2.0.19 faiss-cpu accelerate>=0.24.1 -spacy +spacy>=3.7.2 PyMuPDF rapidocr_onnxruntime - -requests -pathlib -pytest -numexpr -strsimpy -markdownify -tiktoken -tqdm +requests>=2.31.0 +pathlib>=1.0.1 +pytest>=7.4.3 +numexpr>=2.8.7 +strsimpy>=0.2.1 +markdownify>=0.11.6 +tiktoken>=0.5.1 +tqdm>=4.66.1 websockets numpy~=1.24.4 pandas~=2.0.3 -einops -transformers_stream_generator>=0.0.4 +einops>=0.7.0 +transformers_stream_generator==0.0.4 +vllm==0.2.2; sys_platform == "linux" +httpx[brotli,http2,socks]>=0.25.2 -vllm>=0.2.0; sys_platform == "linux" -# online api libs -zhipuai -dashscope>=1.10.0 # qwen -qianfan -# volcengine>=1.0.106 # fangzhou +# optional document loaders +# rapidocr_paddle[gpu] # gpu accelleration for ocr of pdf and image files +# jq # for .json and .jsonl files. 
suggest `conda install jq` on windows +# html2text # for .enex files +# beautifulsoup4 # for .mhtml files +# pysrt # for .srt files -# uncomment libs if you want to use corresponding vector store -# pymilvus==2.1.3 # requires milvus==2.1.3 +# Online api libs dependencies + +# zhipuai>=1.0.7 +# dashscope>=1.10.0 +# qianfan>=0.2.0 +# volcengine>=1.0.106 +# pymilvus>=2.3.3 # psycopg2 -# pgvector +# pgvector>=0.2.4 + +# Agent and Search Tools + +arxiv>=2.0.0 +youtube-search>=2.1.2 +duckduckgo-search>=3.9.9 +metaphor-python>=0.1.23 \ No newline at end of file diff --git a/requirements_lite.txt b/requirements_lite.txt index 1cfcb81b..ad01376c 100644 --- a/requirements_lite.txt +++ b/requirements_lite.txt @@ -1,62 +1,70 @@ -langchain>=0.0.334 -fschat>=0.2.32 -openai -# sentence_transformers -# transformers>=4.35.0 -# torch>=2.0.1 -# torchvision -# torchaudio +langchain==0.0.344 +pydantic==1.10.13 +fschat>=0.2.33 +openai>=1.3.7 fastapi>=0.104.1 python-multipart nltk~=3.8.1 -uvicorn~=0.23.1 +uvicorn>=0.24.0.post1 starlette~=0.27.0 -pydantic~=1.10.11 -unstructured[docx,csv]>=0.10.4 # add pdf if need +unstructured[docx,csv]==0.11.0 # add pdf if need python-magic-bin; sys_platform == 'win32' SQLAlchemy==2.0.19 +numexpr>=2.8.7 +strsimpy>=0.2.1 + faiss-cpu -# accelerate -# spacy +# accelerate>=0.24.1 +# spacy>=3.7.2 # PyMuPDF==1.22.5 # install if need pdf # rapidocr_onnxruntime>=1.3.2 # install if need pdf +# optional document loaders +# rapidocr_paddle[gpu] # gpu accelleration for ocr of pdf and image files +# jq # for .json and .jsonl files. suggest `conda install jq` on windows +# html2text # for .enex files +# beautifulsoup4 # for .mhtml files +# pysrt # for .srt files + requests pathlib pytest # scikit-learn # numexpr -# vllm==0.1.7; sys_platform == "linux" +# vllm==0.2.2; sys_platform == "linux" # online api libs -zhipuai -dashscope>=1.10.0 # qwen -# qianfan + +zhipuai>=1.0.7 # zhipu +# dashscope>=1.10.0 # qwen # volcengine>=1.0.106 # fangzhou # uncomment libs if you want to use corresponding vector store # pymilvus==2.1.3 # requires milvus==2.1.3 # psycopg2 -# pgvector +# pgvector>=0.2.4 numpy~=1.24.4 pandas~=2.0.3 -streamlit~=1.28.1 +streamlit>=1.29.0 streamlit-option-menu>=0.3.6 -streamlit-antd-components>=0.1.11 -streamlit-chatbox==1.1.11 +streamlit-antd-components>=0.2.3 +streamlit-chatbox>=1.1.11 +streamlit-modal>=0.1.0 streamlit-aggrid>=0.3.4.post3 -httpx~=0.24.1 -watchdog -tqdm +httpx[brotli,http2,socks]>=0.25.2 +watchdog>=3.0.0 +tqdm>=4.66.1 websockets +einops>=0.7.0 + # tiktoken -einops -# scipy +# scipy>=1.11.4 # transformers_stream_generator==0.0.4 -# search engine libs -duckduckgo-search -metaphor-python -strsimpy -markdownify +# Agent and Search Tools + +arxiv>=2.0.0 +youtube-search>=2.1.2 +duckduckgo-search>=3.9.9 +metaphor-python>=0.1.23 \ No newline at end of file diff --git a/requirements_webui.txt b/requirements_webui.txt index 0fe3dc8e..3c16c32e 100644 --- a/requirements_webui.txt +++ b/requirements_webui.txt @@ -1,10 +1,10 @@ # WebUI requirements -streamlit~=1.28.2 +streamlit>=1.29.0 streamlit-option-menu>=0.3.6 streamlit-antd-components>=0.2.3 streamlit-chatbox>=1.1.11 +streamlit-modal>=0.1.0 streamlit-aggrid>=0.3.4.post3 -httpx[brotli,http2,socks]~=0.24.1 -watchdog - +httpx[brotli,http2,socks]>=0.25.2 +watchdog>=3.0.0 diff --git a/server/agent/tools/search_knowledgebase_complex.py b/server/agent/tools/search_knowledgebase_complex.py index 0b7884bd..af4d9116 100644 --- a/server/agent/tools/search_knowledgebase_complex.py +++ 
b/server/agent/tools/search_knowledgebase_complex.py @@ -170,7 +170,6 @@ class LLMKnowledgeChain(LLMChain): queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines] except: queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines] - print(queries) run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", verbose=self.verbose) output = self._evaluate_expression(queries) run_manager.on_text("\nAnswer: ", verbose=self.verbose) diff --git a/server/api.py b/server/api.py index 5fc80c46..7444d4bf 100644 --- a/server/api.py +++ b/server/api.py @@ -81,6 +81,8 @@ def mount_app_routes(app: FastAPI, run_mode: str = None): # 知识库相关接口 mount_knowledge_routes(app) + # 摘要相关接口 + mount_filename_summary_routes(app) # LLM模型相关接口 app.post("/llm_model/list_running_models", @@ -230,6 +232,26 @@ def mount_knowledge_routes(app: FastAPI): )(upload_temp_docs) +def mount_filename_summary_routes(app: FastAPI): + from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store, + summary_doc_ids_to_vector_store) + + app.post("/knowledge_base/kb_summary_api/summary_file_to_vector_store", + tags=["Knowledge kb_summary_api Management"], + summary="单个知识库根据文件名称摘要" + )(summary_file_to_vector_store) + app.post("/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store", + tags=["Knowledge kb_summary_api Management"], + summary="单个知识库根据doc_ids摘要", + response_model=BaseResponse, + )(summary_doc_ids_to_vector_store) + app.post("/knowledge_base/kb_summary_api/recreate_summary_vector_store", + tags=["Knowledge kb_summary_api Management"], + summary="重建单个知识库文件摘要" + )(recreate_summary_vector_store) + + + def run_api(host, port, **kwargs): if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"): uvicorn.run(app, diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py index 8ae38db6..51d4d1f6 100644 --- a/server/chat/agent_chat.py +++ b/server/chat/agent_chat.py @@ -43,7 +43,11 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples model_name: str = LLM_MODELS[0], prompt_name: str = prompt_name, ) -> AsyncIterable[str]: + nonlocal max_tokens callback = CustomAsyncIteratorCallbackHandler() + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None + model = get_ChatOpenAI( model_name=model_name, temperature=temperature, diff --git a/server/chat/chat.py b/server/chat/chat.py index acf3ec0c..bac82f8e 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -19,6 +19,7 @@ from server.callback_handler.conversation_callback_handler import ConversationCa async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), conversation_id: str = Body("", description="对话框ID"), + history_len: int = Body(-1, description="从数据库中取历史消息的数量"), history: Union[int, List[History]] = Body([], description="历史对话,设为一个整数可以从数据库中读取历史消息", examples=[[ @@ -34,18 +35,20 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), ): async def chat_iterator() -> AsyncIterable[str]: - nonlocal history + nonlocal history, max_tokens callback = AsyncIteratorCallbackHandler() callbacks = [callback] memory = None - message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id) # 负责保存llm response到message db + message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id) conversation_callback = 
ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id, chat_type="llm_chat", query=query) callbacks.append(conversation_callback) + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None model = get_ChatOpenAI( model_name=model_name, @@ -54,18 +57,24 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 callbacks=callbacks, ) - if not conversation_id: + if history: # 优先使用前端传入的历史消息 history = [History.from_data(h) for h in history] prompt_template = get_prompt_template("llm_chat", prompt_name) input_msg = History(role="user", content=prompt_template).to_msg_template(False) chat_prompt = ChatPromptTemplate.from_messages( [i.to_msg_template() for i in history] + [input_msg]) - else: + elif conversation_id and history_len > 0: # 前端要求从数据库取历史消息 # 使用memory 时必须 prompt 必须含有memory.memory_key 对应的变量 prompt = get_prompt_template("llm_chat", "with_history") chat_prompt = PromptTemplate.from_template(prompt) # 根据conversation_id 获取message 列表进而拼凑 memory - memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=model) + memory = ConversationBufferDBMemory(conversation_id=conversation_id, + llm=model, + message_limit=history_len) + else: + prompt_template = get_prompt_template("llm_chat", prompt_name) + input_msg = History(role="user", content=prompt_template).to_msg_template(False) + chat_prompt = ChatPromptTemplate.from_messages([input_msg]) chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory) diff --git a/server/chat/completion.py b/server/chat/completion.py index ee5e2d12..6e5827dd 100644 --- a/server/chat/completion.py +++ b/server/chat/completion.py @@ -6,7 +6,8 @@ from langchain.chains import LLMChain from langchain.callbacks import AsyncIteratorCallbackHandler from typing import AsyncIterable, Optional import asyncio -from langchain.prompts.chat import PromptTemplate +from langchain.prompts import PromptTemplate + from server.utils import get_prompt_template @@ -27,7 +28,11 @@ async def completion(query: str = Body(..., description="用户输入", examples prompt_name: str = prompt_name, echo: bool = echo, ) -> AsyncIterable[str]: + nonlocal max_tokens callback = AsyncIteratorCallbackHandler() + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None + model = get_OpenAI( model_name=model_name, temperature=temperature, diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py index ea3475a0..a58cb29b 100644 --- a/server/chat/file_chat.py +++ b/server/chat/file_chat.py @@ -113,7 +113,11 @@ async def file_chat(query: str = Body(..., description="用户输入", examples= history = [History.from_data(h) for h in history] async def knowledge_base_chat_iterator() -> AsyncIterable[str]: + nonlocal max_tokens callback = AsyncIteratorCallbackHandler() + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None + model = get_ChatOpenAI( model_name=model_name, temperature=temperature, @@ -121,7 +125,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples= callbacks=[callback], ) embed_func = EmbeddingsFunAdapter() - embeddings = embed_func.embed_query(query) + embeddings = await embed_func.aembed_query(query) with memo_faiss_pool.acquire(knowledge_id) as vs: docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold) docs = [x[0] for x in docs] diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 0ea99a6e..a3ab68b1 100644 --- a/server/chat/knowledge_base_chat.py +++ 
b/server/chat/knowledge_base_chat.py @@ -1,5 +1,6 @@ from fastapi import Body, Request from fastapi.responses import StreamingResponse +from fastapi.concurrency import run_in_threadpool from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE) from server.utils import wrap_done, get_ChatOpenAI from server.utils import BaseResponse, get_prompt_template @@ -10,55 +11,74 @@ import asyncio from langchain.prompts.chat import ChatPromptTemplate from server.chat.utils import History from server.knowledge_base.kb_service.base import KBServiceFactory -from server.knowledge_base.utils import get_doc_path import json -from pathlib import Path from urllib.parse import urlencode from server.knowledge_base.kb_doc_api import search_docs async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]), - knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), - top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), - score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=2), - history: List[History] = Body([], - description="历史对话", - examples=[[ - {"role": "user", - "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", - "content": "虎头虎脑"}]] - ), - stream: bool = Body(False, description="流式输出"), - model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), - temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), - prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), - request: Request = None, - ): + knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), + top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), + score_threshold: float = Body( + SCORE_THRESHOLD, + description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", + ge=0, + le=2 + ), + history: List[History] = Body( + [], + description="历史对话", + examples=[[ + {"role": "user", + "content": "我们来玩成语接龙,我先来,生龙活虎"}, + {"role": "assistant", + "content": "虎头虎脑"}]] + ), + stream: bool = Body(False, description="流式输出"), + model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), + temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), + max_tokens: Optional[int] = Body( + None, + description="限制LLM生成Token数量,默认None代表模型最大值" + ), + prompt_name: str = Body( + "default", + description="使用的prompt模板名称(在configs/prompt_config.py中配置)" + ), + request: Request = None, + ): kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") history = [History.from_data(h) for h in history] - async def knowledge_base_chat_iterator(query: str, - top_k: int, - history: Optional[List[History]], - model_name: str = LLM_MODELS[0], - prompt_name: str = prompt_name, - ) -> AsyncIterable[str]: + async def knowledge_base_chat_iterator( + query: str, + top_k: int, + history: Optional[List[History]], + model_name: str = LLM_MODELS[0], + prompt_name: str = prompt_name, + ) -> AsyncIterable[str]: + nonlocal max_tokens callback = AsyncIteratorCallbackHandler() + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None + model = get_ChatOpenAI( model_name=model_name, temperature=temperature, max_tokens=max_tokens, callbacks=[callback], ) - docs = search_docs(query, knowledge_base_name, top_k, score_threshold) + 
docs = await run_in_threadpool(search_docs, + query=query, + knowledge_base_name=knowledge_base_name, + top_k=top_k, + score_threshold=score_threshold) context = "\n".join([doc.page_content for doc in docs]) - if len(docs) == 0: ## 如果没有找到相关文档,使用Empty模板 - prompt_template = get_prompt_template("knowledge_base_chat", "Empty") + if len(docs) == 0: # 如果没有找到相关文档,使用empty模板 + prompt_template = get_prompt_template("knowledge_base_chat", "empty") else: prompt_template = get_prompt_template("knowledge_base_chat", prompt_name) input_msg = History(role="user", content=prompt_template).to_msg_template(False) @@ -76,14 +96,14 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", source_documents = [] for inum, doc in enumerate(docs): filename = doc.metadata.get("source") - parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename}) + parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename}) base_url = request.base_url url = f"{base_url}knowledge_base/download_doc?" + parameters text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n""" source_documents.append(text) - if len(source_documents) == 0: # 没有找到相关文档 - source_documents.append(f"""未找到相关文档,该回答为大模型自身能力解答!""") + if len(source_documents) == 0: # 没有找到相关文档 + source_documents.append(f"未找到相关文档,该回答为大模型自身能力解答!") if stream: async for token in callback.aiter(): @@ -104,4 +124,4 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", history=history, model_name=model_name, prompt_name=prompt_name), - media_type="text/event-stream") \ No newline at end of file + media_type="text/event-stream") diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index 8325b4d9..5b77e993 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -147,7 +147,11 @@ async def search_engine_chat(query: str = Body(..., description="用户输入", model_name: str = LLM_MODELS[0], prompt_name: str = prompt_name, ) -> AsyncIterable[str]: + nonlocal max_tokens callback = AsyncIteratorCallbackHandler() + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None + model = get_ChatOpenAI( model_name=model_name, temperature=temperature, diff --git a/server/db/models/knowledge_metadata_model.py b/server/db/models/knowledge_metadata_model.py new file mode 100644 index 00000000..03f42009 --- /dev/null +++ b/server/db/models/knowledge_metadata_model.py @@ -0,0 +1,28 @@ +from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func + +from server.db.base import Base + + +class SummaryChunkModel(Base): + """ + chunk summary模型,用于存储file_doc中每个doc_id的chunk 片段, + 数据来源: + 用户输入: 用户上传文件,可填写文件的描述,生成的file_doc中的doc_id,存入summary_chunk中 + 程序自动切分 对file_doc表meta_data字段信息中存储的页码信息,按每页的页码切分,自定义prompt生成总结文本,将对应页码关联的doc_id存入summary_chunk中 + 后续任务: + 矢量库构建: 对数据库表summary_chunk中summary_context创建索引,构建矢量库,meta_data为矢量库的元数据(doc_ids) + 语义关联: 通过用户输入的描述,自动切分的总结文本,计算 + 语义相似度 + + """ + __tablename__ = 'summary_chunk' + id = Column(Integer, primary_key=True, autoincrement=True, comment='ID') + kb_name = Column(String(50), comment='知识库名称') + summary_context = Column(String(255), comment='总结文本') + summary_id = Column(String(255), comment='总结矢量id') + doc_ids = Column(String(1024), comment="向量库id关联列表") + meta_data = Column(JSON, default={}) + + def __repr__(self): + return (f"") diff --git a/server/db/repository/knowledge_metadata_repository.py b/server/db/repository/knowledge_metadata_repository.py new file mode 100644 
index 00000000..4158e703 --- /dev/null +++ b/server/db/repository/knowledge_metadata_repository.py @@ -0,0 +1,66 @@ +from server.db.models.knowledge_metadata_model import SummaryChunkModel +from server.db.session import with_session +from typing import List, Dict + + +@with_session +def list_summary_from_db(session, + kb_name: str, + metadata: Dict = {}, + ) -> List[Dict]: + ''' + 列出某知识库chunk summary。 + 返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...] + ''' + docs = session.query(SummaryChunkModel).filter_by(kb_name=kb_name) + + for k, v in metadata.items(): + docs = docs.filter(SummaryChunkModel.meta_data[k].as_string() == str(v)) + + return [{"id": x.id, + "summary_context": x.summary_context, + "summary_id": x.summary_id, + "doc_ids": x.doc_ids, + "metadata": x.metadata} for x in docs.all()] + + +@with_session +def delete_summary_from_db(session, + kb_name: str + ) -> List[Dict]: + ''' + 删除知识库chunk summary,并返回被删除的Dchunk summary。 + 返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...] + ''' + docs = list_summary_from_db(kb_name=kb_name) + query = session.query(SummaryChunkModel).filter_by(kb_name=kb_name) + query.delete() + session.commit() + return docs + + +@with_session +def add_summary_to_db(session, + kb_name: str, + summary_infos: List[Dict]): + ''' + 将总结信息添加到数据库。 + summary_infos形式:[{"summary_context": str, "doc_ids": str}, ...] + ''' + for summary in summary_infos: + obj = SummaryChunkModel( + kb_name=kb_name, + summary_context=summary["summary_context"], + summary_id=summary["summary_id"], + doc_ids=summary["doc_ids"], + meta_data=summary["metadata"], + ) + session.add(obj) + + session.commit() + return True + + +@with_session +def count_summary_from_db(session, kb_name: str) -> int: + return session.query(SummaryChunkModel).filter_by(kb_name=kb_name).count() diff --git a/server/embeddings_api.py b/server/embeddings_api.py index 80cd289f..93555a35 100644 --- a/server/embeddings_api.py +++ b/server/embeddings_api.py @@ -3,6 +3,7 @@ from configs import EMBEDDING_MODEL, logger from server.model_workers.base import ApiEmbeddingsParams from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models from fastapi import Body +from fastapi.concurrency import run_in_threadpool from typing import Dict, List @@ -39,6 +40,32 @@ def embed_texts( logger.error(e) return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}") + +async def aembed_texts( + texts: List[str], + embed_model: str = EMBEDDING_MODEL, + to_query: bool = False, +) -> BaseResponse: + ''' + 对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]]) + ''' + try: + if embed_model in list_embed_models(): # 使用本地Embeddings模型 + from server.utils import load_local_embeddings + + embeddings = load_local_embeddings(model=embed_model) + return BaseResponse(data=await embeddings.aembed_documents(texts)) + + if embed_model in list_online_embed_models(): # 使用在线API + return await run_in_threadpool(embed_texts, + texts=texts, + embed_model=embed_model, + to_query=to_query) + except Exception as e: + logger.error(e) + return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}") + + def embed_texts_endpoint( texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]), embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型,除了本地部署的Embedding模型,也支持在线API({online_embed_models})提供的嵌入服务。"), diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py index 6823536c..68b54943 100644 --- 
a/server/knowledge_base/kb_cache/faiss_cache.py +++ b/server/knowledge_base/kb_cache/faiss_cache.py @@ -45,7 +45,7 @@ class _FaissPool(CachePool): # create an empty vector store embeddings = EmbeddingsFunAdapter(embed_model) doc = Document(page_content="init", metadata={}) - vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True) + vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT") ids = list(vector_store.docstore._dict.keys()) vector_store.delete(ids) return vector_store @@ -82,7 +82,7 @@ class KBFaissPool(_FaissPool): if os.path.isfile(os.path.join(vs_path, "index.faiss")): embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model) - vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True) + vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT") elif create: # create an empty vector store if not os.path.exists(vs_path): diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 0af8de5e..1d86d306 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -26,8 +26,8 @@ from server.knowledge_base.utils import ( from typing import List, Union, Dict, Optional -from server.embeddings_api import embed_texts -from server.embeddings_api import embed_documents +from server.embeddings_api import embed_texts, aembed_texts, embed_documents +from server.knowledge_base.model.kb_document_model import DocumentWithVSId def normalize(embeddings: List[List[float]]) -> np.ndarray: @@ -183,12 +183,22 @@ class KBService(ABC): def get_doc_by_ids(self, ids: List[str]) -> List[Document]: return [] - def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]: + def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[DocumentWithVSId]: ''' 通过file_name或metadata检索Document ''' doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata) - docs = self.get_doc_by_ids([x["id"] for x in doc_infos]) + docs = [] + for x in doc_infos: + doc_info_s = self.get_doc_by_ids([x["id"]]) + if doc_info_s is not None and doc_info_s != []: + # 处理非空的情况 + doc_with_id = DocumentWithVSId(**doc_info_s[0].dict(), id=x["id"]) + docs.append(doc_with_id) + else: + # 处理空的情况 + # 可以选择跳过当前循环迭代或执行其他操作 + pass return docs @abstractmethod @@ -394,12 +404,16 @@ class EmbeddingsFunAdapter(Embeddings): normalized_query_embed = normalize(query_embed_2d) return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回 - # TODO: 暂不支持异步 - # async def aembed_documents(self, texts: List[str]) -> List[List[float]]: - # return normalize(await self.embeddings.aembed_documents(texts)) + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + embeddings = (await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)).data + return normalize(embeddings).tolist() - # async def aembed_query(self, text: str) -> List[float]: - # return normalize(await self.embeddings.aembed_query(text)) + async def aembed_query(self, text: str) -> List[float]: + embeddings = (await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)).data + query_embed = embeddings[0] + query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组 + normalized_query_embed = normalize(query_embed_2d) + return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回 def 
score_threshold_process(score_threshold, k, docs): diff --git a/server/knowledge_base/kb_summary/__init__.py b/server/knowledge_base/kb_summary/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/knowledge_base/kb_summary/base.py b/server/knowledge_base/kb_summary/base.py new file mode 100644 index 00000000..00dcea6f --- /dev/null +++ b/server/knowledge_base/kb_summary/base.py @@ -0,0 +1,79 @@ +from typing import List + +from configs import ( + EMBEDDING_MODEL, + KB_ROOT_PATH) + +from abc import ABC, abstractmethod +from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss +import os +import shutil +from server.db.repository.knowledge_metadata_repository import add_summary_to_db, delete_summary_from_db + +from langchain.docstore.document import Document + + +# TODO 暂不考虑文件更新,需要重新删除相关文档,再重新添加 +class KBSummaryService(ABC): + kb_name: str + embed_model: str + vs_path: str + kb_path: str + + def __init__(self, + knowledge_base_name: str, + embed_model: str = EMBEDDING_MODEL + ): + self.kb_name = knowledge_base_name + self.embed_model = embed_model + + self.kb_path = self.get_kb_path() + self.vs_path = self.get_vs_path() + + if not os.path.exists(self.vs_path): + os.makedirs(self.vs_path) + + + def get_vs_path(self): + return os.path.join(self.get_kb_path(), "summary_vector_store") + + def get_kb_path(self): + return os.path.join(KB_ROOT_PATH, self.kb_name) + + def load_vector_store(self) -> ThreadSafeFaiss: + return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, + vector_name="summary_vector_store", + embed_model=self.embed_model, + create=True) + + def add_kb_summary(self, summary_combine_docs: List[Document]): + with self.load_vector_store().acquire() as vs: + ids = vs.add_documents(documents=summary_combine_docs) + vs.save_local(self.vs_path) + + summary_infos = [{"summary_context": doc.page_content, + "summary_id": id, + "doc_ids": doc.metadata.get('doc_ids'), + "metadata": doc.metadata} for id, doc in zip(ids, summary_combine_docs)] + status = add_summary_to_db(kb_name=self.kb_name, summary_infos=summary_infos) + return status + + def create_kb_summary(self): + """ + 创建知识库chunk summary + :return: + """ + + if not os.path.exists(self.vs_path): + os.makedirs(self.vs_path) + + def drop_kb_summary(self): + """ + 删除知识库chunk summary + :param kb_name: + :return: + """ + with kb_faiss_pool.atomic: + kb_faiss_pool.pop(self.kb_name) + shutil.rmtree(self.vs_path) + delete_summary_from_db(kb_name=self.kb_name) diff --git a/server/knowledge_base/kb_summary/summary_chunk.py b/server/knowledge_base/kb_summary/summary_chunk.py new file mode 100644 index 00000000..0b88f233 --- /dev/null +++ b/server/knowledge_base/kb_summary/summary_chunk.py @@ -0,0 +1,247 @@ +from typing import List, Optional + +from langchain.schema.language_model import BaseLanguageModel + +from server.knowledge_base.model.kb_document_model import DocumentWithVSId +from configs import (logger) +from langchain.chains import StuffDocumentsChain, LLMChain +from langchain.prompts import PromptTemplate + +from langchain.docstore.document import Document +from langchain.output_parsers.regex import RegexParser +from langchain.chains.combine_documents.map_reduce import ReduceDocumentsChain, MapReduceDocumentsChain + +import sys +import asyncio + + +class SummaryAdapter: + _OVERLAP_SIZE: int + token_max: int + _separator: str = "\n\n" + chain: MapReduceDocumentsChain + + def __init__(self, overlap_size: int, token_max: int, + chain: MapReduceDocumentsChain): + self._OVERLAP_SIZE = 
overlap_size + self.chain = chain + self.token_max = token_max + + @classmethod + def form_summary(cls, + llm: BaseLanguageModel, + reduce_llm: BaseLanguageModel, + overlap_size: int, + token_max: int = 1300): + """ + 获取实例 + :param reduce_llm: 用于合并摘要的llm + :param llm: 用于生成摘要的llm + :param overlap_size: 重叠部分大小 + :param token_max: 最大的chunk数量,每个chunk长度小于token_max长度,第一次生成摘要时,大于token_max长度的摘要会报错 + :return: + """ + + # This controls how each document will be formatted. Specifically, + document_prompt = PromptTemplate( + input_variables=["page_content"], + template="{page_content}" + ) + + # The prompt here should take as an input variable the + # `document_variable_name` + prompt_template = ( + "根据文本执行任务。以下任务信息" + "{task_briefing}" + "文本内容如下: " + "\r\n" + "{context}" + ) + prompt = PromptTemplate( + template=prompt_template, + input_variables=["task_briefing", "context"] + ) + llm_chain = LLMChain(llm=llm, prompt=prompt) + # We now define how to combine these summaries + reduce_prompt = PromptTemplate.from_template( + "Combine these summaries: {context}" + ) + reduce_llm_chain = LLMChain(llm=reduce_llm, prompt=reduce_prompt) + + document_variable_name = "context" + combine_documents_chain = StuffDocumentsChain( + llm_chain=reduce_llm_chain, + document_prompt=document_prompt, + document_variable_name=document_variable_name + ) + reduce_documents_chain = ReduceDocumentsChain( + token_max=token_max, + combine_documents_chain=combine_documents_chain, + ) + chain = MapReduceDocumentsChain( + llm_chain=llm_chain, + document_variable_name=document_variable_name, + reduce_documents_chain=reduce_documents_chain, + # 返回中间步骤 + return_intermediate_steps=True + ) + return cls(overlap_size=overlap_size, + chain=chain, + token_max=token_max) + + def summarize(self, + file_description: str, + docs: List[DocumentWithVSId] = [] + ) -> List[Document]: + + if sys.version_info < (3, 10): + loop = asyncio.get_event_loop() + else: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + + asyncio.set_event_loop(loop) + # 同步调用协程代码 + return loop.run_until_complete(self.asummarize(file_description=file_description, + docs=docs)) + + async def asummarize(self, + file_description: str, + docs: List[DocumentWithVSId] = []) -> List[Document]: + + logger.info("start summary") + # TODO 暂不处理文档中涉及语义重复、上下文缺失、document was longer than the context length 的问题 + # merge_docs = self._drop_overlap(docs) + # # 将merge_docs中的句子合并成一个文档 + # text = self._join_docs(merge_docs) + # 根据段落于句子的分隔符,将文档分成chunk,每个chunk长度小于token_max长度 + + """ + 这个过程分成两个部分: + 1. 对每个文档进行处理,得到每个文档的摘要 + map_results = self.llm_chain.apply( + # FYI - this is parallelized and so it is fast. + [{self.document_variable_name: d.page_content, **kwargs} for d in docs], + callbacks=callbacks, + ) + 2. 
对每个文档的摘要进行合并,得到最终的摘要,return_intermediate_steps=True,返回中间步骤 + result, extra_return_dict = self.reduce_documents_chain.combine_docs( + result_docs, token_max=token_max, callbacks=callbacks, **kwargs + ) + """ + summary_combine, summary_intermediate_steps = self.chain.combine_docs(docs=docs, + task_briefing="描述不同方法之间的接近度和相似性," + "以帮助读者理解它们之间的关系。") + print(summary_combine) + print(summary_intermediate_steps) + + # if len(summary_combine) == 0: + # # 为空重新生成,数量减半 + # result_docs = [ + # Document(page_content=question_result_key, metadata=docs[i].metadata) + # # This uses metadata from the docs, and the textual results from `results` + # for i, question_result_key in enumerate( + # summary_intermediate_steps["intermediate_steps"][ + # :len(summary_intermediate_steps["intermediate_steps"]) // 2 + # ]) + # ] + # summary_combine, summary_intermediate_steps = self.chain.reduce_documents_chain.combine_docs( + # result_docs, token_max=self.token_max + # ) + logger.info("end summary") + doc_ids = ",".join([doc.id for doc in docs]) + _metadata = { + "file_description": file_description, + "summary_intermediate_steps": summary_intermediate_steps, + "doc_ids": doc_ids + } + summary_combine_doc = Document(page_content=summary_combine, metadata=_metadata) + + return [summary_combine_doc] + + def _drop_overlap(self, docs: List[DocumentWithVSId]) -> List[str]: + """ + # 将文档中page_content句子叠加的部分去掉 + :param docs: + :param separator: + :return: + """ + merge_docs = [] + + pre_doc = None + for doc in docs: + # 第一个文档直接添加 + if len(merge_docs) == 0: + pre_doc = doc.page_content + merge_docs.append(doc.page_content) + continue + + # 列表中上一个结尾与下一个开头重叠的部分,删除下一个开头重叠的部分 + # 迭代递减pre_doc的长度,每次迭代删除前面的字符, + # 查询重叠部分,直到pre_doc的长度小于 self._OVERLAP_SIZE // 2 - 2len(separator) + for i in range(len(pre_doc), self._OVERLAP_SIZE // 2 - 2 * len(self._separator), -1): + # 每次迭代删除前面的字符 + pre_doc = pre_doc[1:] + if doc.page_content[:len(pre_doc)] == pre_doc: + # 删除下一个开头重叠的部分 + merge_docs.append(doc.page_content[len(pre_doc):]) + break + + pre_doc = doc.page_content + + return merge_docs + + def _join_docs(self, docs: List[str]) -> Optional[str]: + text = self._separator.join(docs) + text = text.strip() + if text == "": + return None + else: + return text + + +if __name__ == '__main__': + + docs = [ + + '梦者有特别的作用,也就是说梦是在预卜未来。因此,梦内容的', + + '梦内容的多彩多姿以及对梦者本身所遗留的特殊印象,使他们很难想象', + + '使他们很难想象出一套系统划一的观念,而需要以其个别的价值与可靠性作各', + '值与可靠性作各种不同的分化与聚合。因此,古代哲学家们对梦的评价也就完全' + ] + _OVERLAP_SIZE = 1 + separator: str = "\n\n" + merge_docs = [] + # 将文档中page_content句子叠加的部分去掉, + # 列表中上一个结尾与下一个开头重叠的部分,删除下一个开头重叠的部分 + pre_doc = None + for doc in docs: + # 第一个文档直接添加 + if len(merge_docs) == 0: + pre_doc = doc + merge_docs.append(doc) + continue + + # 列表中上一个结尾与下一个开头重叠的部分,删除下一个开头重叠的部分 + # 迭代递减pre_doc的长度,每次迭代删除前面的字符, + # 查询重叠部分,直到pre_doc的长度小于 _OVERLAP_SIZE-2len(separator) + for i in range(len(pre_doc), _OVERLAP_SIZE - 2 * len(separator), -1): + # 每次迭代删除前面的字符 + pre_doc = pre_doc[1:] + if doc[:len(pre_doc)] == pre_doc: + # 删除下一个开头重叠的部分 + page_content = doc[len(pre_doc):] + merge_docs.append(page_content) + + pre_doc = doc + break + + # 将merge_docs中的句子合并成一个文档 + text = separator.join(merge_docs) + text = text.strip() + + print(text) diff --git a/server/knowledge_base/kb_summary_api.py b/server/knowledge_base/kb_summary_api.py new file mode 100644 index 00000000..aac4de78 --- /dev/null +++ b/server/knowledge_base/kb_summary_api.py @@ -0,0 +1,220 @@ +from fastapi import Body +from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, + OVERLAP_SIZE, + logger, log_verbose, ) +from 
server.knowledge_base.utils import (list_files_from_folder) +from fastapi.responses import StreamingResponse +import json +from server.knowledge_base.kb_service.base import KBServiceFactory +from typing import List, Optional +from server.knowledge_base.kb_summary.base import KBSummaryService +from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter +from server.utils import wrap_done, get_ChatOpenAI, BaseResponse +from configs import LLM_MODELS, TEMPERATURE +from server.knowledge_base.model.kb_document_model import DocumentWithVSId + +def recreate_summary_vector_store( + knowledge_base_name: str = Body(..., examples=["samples"]), + allow_empty_kb: bool = Body(True), + vs_type: str = Body(DEFAULT_VS_TYPE), + embed_model: str = Body(EMBEDDING_MODEL), + file_description: str = Body(''), + model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), + temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), + max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), +): + """ + 重建单个知识库文件摘要 + :param max_tokens: + :param model_name: + :param temperature: + :param file_description: + :param knowledge_base_name: + :param allow_empty_kb: + :param vs_type: + :param embed_model: + :return: + """ + + def output(): + + kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model) + if not kb.exists() and not allow_empty_kb: + yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"} + else: + # 重新创建知识库 + kb_summary = KBSummaryService(knowledge_base_name, embed_model) + kb_summary.drop_kb_summary() + kb_summary.create_kb_summary() + + llm = get_ChatOpenAI( + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + ) + reduce_llm = get_ChatOpenAI( + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + ) + # 文本摘要适配器 + summary = SummaryAdapter.form_summary(llm=llm, + reduce_llm=reduce_llm, + overlap_size=OVERLAP_SIZE) + files = list_files_from_folder(knowledge_base_name) + + i = 0 + for i, file_name in enumerate(files): + + doc_infos = kb.list_docs(file_name=file_name) + docs = summary.summarize(file_description=file_description, + docs=doc_infos) + + status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs) + if status_kb_summary: + logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成") + yield json.dumps({ + "code": 200, + "msg": f"({i + 1} / {len(files)}): {file_name}", + "total": len(files), + "finished": i + 1, + "doc": file_name, + }, ensure_ascii=False) + else: + + msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。" + logger.error(msg) + yield json.dumps({ + "code": 500, + "msg": msg, + }) + i += 1 + + return StreamingResponse(output(), media_type="text/event-stream") + + +def summary_file_to_vector_store( + knowledge_base_name: str = Body(..., examples=["samples"]), + file_name: str = Body(..., examples=["test.pdf"]), + allow_empty_kb: bool = Body(True), + vs_type: str = Body(DEFAULT_VS_TYPE), + embed_model: str = Body(EMBEDDING_MODEL), + file_description: str = Body(''), + model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), + temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), + max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), +): + """ + 单个知识库根据文件名称摘要 + :param model_name: + :param max_tokens: + :param temperature: + :param file_description: + :param file_name: + :param knowledge_base_name: + :param allow_empty_kb: + :param vs_type: + :param embed_model: + :return: + 
""" + + def output(): + kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model) + if not kb.exists() and not allow_empty_kb: + yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"} + else: + # 重新创建知识库 + kb_summary = KBSummaryService(knowledge_base_name, embed_model) + kb_summary.create_kb_summary() + + llm = get_ChatOpenAI( + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + ) + reduce_llm = get_ChatOpenAI( + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + ) + # 文本摘要适配器 + summary = SummaryAdapter.form_summary(llm=llm, + reduce_llm=reduce_llm, + overlap_size=OVERLAP_SIZE) + + doc_infos = kb.list_docs(file_name=file_name) + docs = summary.summarize(file_description=file_description, + docs=doc_infos) + + status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs) + if status_kb_summary: + logger.info(f" {file_name} 总结完成") + yield json.dumps({ + "code": 200, + "msg": f"{file_name} 总结完成", + "doc": file_name, + }, ensure_ascii=False) + else: + + msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。" + logger.error(msg) + yield json.dumps({ + "code": 500, + "msg": msg, + }) + + return StreamingResponse(output(), media_type="text/event-stream") + + +def summary_doc_ids_to_vector_store( + knowledge_base_name: str = Body(..., examples=["samples"]), + doc_ids: List = Body([], examples=[["uuid"]]), + vs_type: str = Body(DEFAULT_VS_TYPE), + embed_model: str = Body(EMBEDDING_MODEL), + file_description: str = Body(''), + model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), + temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), + max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), +) -> BaseResponse: + """ + 单个知识库根据doc_ids摘要 + :param knowledge_base_name: + :param doc_ids: + :param model_name: + :param max_tokens: + :param temperature: + :param file_description: + :param vs_type: + :param embed_model: + :return: + """ + kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model) + if not kb.exists(): + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data={}) + else: + llm = get_ChatOpenAI( + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + ) + reduce_llm = get_ChatOpenAI( + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + ) + # 文本摘要适配器 + summary = SummaryAdapter.form_summary(llm=llm, + reduce_llm=reduce_llm, + overlap_size=OVERLAP_SIZE) + + doc_infos = kb.get_doc_by_ids(ids=doc_ids) + # doc_infos转换成DocumentWithVSId包装的对象 + doc_info_with_ids = [DocumentWithVSId(**doc.dict(), id=with_id) for with_id, doc in zip(doc_ids, doc_infos)] + + docs = summary.summarize(file_description=file_description, + docs=doc_info_with_ids) + + # 将docs转换成dict + resp_summarize = [{**doc.dict()} for doc in docs] + + return BaseResponse(code=200, msg="总结完成", data={"summarize": resp_summarize}) diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index b3601476..bde6e1f8 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -12,6 +12,8 @@ from server.knowledge_base.kb_service.base import KBServiceFactory from server.db.models.conversation_model import ConversationModel from server.db.models.message_model import MessageModel from server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported +from server.db.repository.knowledge_metadata_repository import add_summary_to_db + 
from server.db.base import Base, engine from server.db.session import session_scope import os diff --git a/server/knowledge_base/model/kb_document_model.py b/server/knowledge_base/model/kb_document_model.py new file mode 100644 index 00000000..a5d2c6ab --- /dev/null +++ b/server/knowledge_base/model/kb_document_model.py @@ -0,0 +1,10 @@ + +from langchain.docstore.document import Document + + +class DocumentWithVSId(Document): + """ + 矢量化后的文档 + """ + id: str = None + diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index d5587ea5..41c3ad5a 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -71,7 +71,7 @@ def list_files_from_folder(kb_name: str): for target_entry in target_it: process_entry(target_entry) elif entry.is_file(): - result.append(entry.path) + result.append(os.path.relpath(entry.path, doc_path)) elif entry.is_dir(): with os.scandir(entry.path) as it: for sub_entry in it: @@ -85,6 +85,7 @@ def list_files_from_folder(kb_name: str): LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'], + "MHTMLLoader": ['.mhtml'], "UnstructuredMarkdownLoader": ['.md'], "JSONLoader": [".json"], "JSONLinesLoader": [".jsonl"], @@ -106,6 +107,7 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'], "UnstructuredWordDocumentLoader": ['.docx', 'doc'], "UnstructuredXMLLoader": ['.xml'], "UnstructuredPowerPointLoader": ['.ppt', '.pptx'], + "EverNoteLoader": ['.enex'], "UnstructuredFileLoader": ['.txt'], } SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] @@ -311,6 +313,9 @@ class KnowledgeFile: else: docs = text_splitter.split_documents(docs) + if not docs: + return [] + print(f"文档切分示例:{docs[0]}") if zh_title_enhance: docs = func_zh_title_enhance(docs) diff --git a/server/model_workers/azure.py b/server/model_workers/azure.py index c6b4cbba..70959325 100644 --- a/server/model_workers/azure.py +++ b/server/model_workers/azure.py @@ -1,4 +1,5 @@ import sys +import os from fastchat.conversation import Conversation from server.model_workers.base import * from server.utils import get_httpx_client @@ -19,16 +20,16 @@ class AzureWorker(ApiModelWorker): **kwargs, ): kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - kwargs.setdefault("context_len", 8000) #TODO 16K模型需要改成16384 super().__init__(**kwargs) self.version = version def do_chat(self, params: ApiChatParams) -> Dict: params.load_config(self.model_names[0]) + data = dict( messages=params.messages, temperature=params.temperature, - max_tokens=params.max_tokens, + max_tokens=params.max_tokens if params.max_tokens else None, stream=True, ) url = ("https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}" @@ -47,6 +48,7 @@ class AzureWorker(ApiModelWorker): with get_httpx_client() as client: with client.stream("POST", url, headers=headers, json=data) as response: + print(data) for line in response.iter_lines(): if not line.strip() or "[DONE]" in line: continue @@ -60,6 +62,7 @@ class AzureWorker(ApiModelWorker): "error_code": 0, "text": text } + print(text) else: self.logger.error(f"请求 Azure API 时发生错误:{resp}") @@ -91,4 +94,4 @@ if __name__ == "__main__": ) sys.modules["fastchat.serve.model_worker"].worker = worker MakeFastAPIOffline(app) - uvicorn.run(app, port=21008) + uvicorn.run(app, port=21008) \ No newline at end of file diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py index 45e9a5f9..947782a8 100644 --- a/server/model_workers/minimax.py +++ 
b/server/model_workers/minimax.py @@ -119,7 +119,9 @@ class MiniMaxWorker(ApiModelWorker): with get_httpx_client() as client: result = [] i = 0 - for texts in params.texts[i:i+10]: + batch_size = 10 + while i < len(params.texts): + texts = params.texts[i:i+batch_size] data["texts"] = texts r = client.post(url, headers=headers, json=data).json() if embeddings := r.get("vectors"): @@ -137,7 +139,7 @@ class MiniMaxWorker(ApiModelWorker): } self.logger.error(f"请求 MiniMax API 时发生错误:{data}") return data - i += 10 + i += batch_size return {"code": 200, "data": embeddings} def get_embeddings(self, params): diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py index 95e200ac..2bcce94e 100644 --- a/server/model_workers/qianfan.py +++ b/server/model_workers/qianfan.py @@ -188,9 +188,11 @@ class QianFanWorker(ApiModelWorker): with get_httpx_client() as client: result = [] i = 0 - for texts in params.texts[i:i+10]: + batch_size = 10 + while i < len(params.texts): + texts = params.texts[i:i+batch_size] resp = client.post(url, json={"input": texts}).json() - if "error_cdoe" in resp: + if "error_code" in resp: data = { "code": resp["error_code"], "msg": resp["error_msg"], @@ -206,7 +208,7 @@ class QianFanWorker(ApiModelWorker): else: embeddings = [x["embedding"] for x in resp.get("data", [])] result += embeddings - i += 10 + i += batch_size return {"code": 200, "data": result} # TODO: qianfan支持续写模型 diff --git a/server/utils.py b/server/utils.py index c51c6e86..21b1baf0 100644 --- a/server/utils.py +++ b/server/utils.py @@ -40,11 +40,10 @@ def get_ChatOpenAI( verbose: bool = True, **kwargs: Any, ) -> ChatOpenAI: - ## 以下模型是Langchain原生支持的模型,这些模型不会走Fschat封装 - config_models = list_config_llm_models() - - ## 非Langchain原生支持的模型,走Fschat封装 config = get_model_worker_config(model_name) + if model_name == "openai-api": + model_name = config.get("model_name") + model = ChatOpenAI( streaming=streaming, verbose=verbose, @@ -57,10 +56,8 @@ def get_ChatOpenAI( openai_proxy=config.get("openai_proxy"), **kwargs ) - return model - def get_OpenAI( model_name: str, temperature: float, @@ -71,67 +68,22 @@ def get_OpenAI( verbose: bool = True, **kwargs: Any, ) -> OpenAI: - ## 以下模型是Langchain原生支持的模型,这些模型不会走Fschat封装 - config_models = list_config_llm_models() - if model_name in config_models.get("langchain", {}): - config = config_models["langchain"][model_name] - if model_name == "Azure-OpenAI": - model = AzureOpenAI( - streaming=streaming, - verbose=verbose, - callbacks=callbacks, - deployment_name=config.get("deployment_name"), - model_version=config.get("model_version"), - openai_api_type=config.get("openai_api_type"), - openai_api_base=config.get("api_base_url"), - openai_api_version=config.get("api_version"), - openai_api_key=config.get("api_key"), - openai_proxy=config.get("openai_proxy"), - temperature=temperature, - max_tokens=max_tokens, - echo=echo, - ) - - elif model_name == "OpenAI": - model = OpenAI( - streaming=streaming, - verbose=verbose, - callbacks=callbacks, - model_name=config.get("model_name"), - openai_api_base=config.get("api_base_url"), - openai_api_key=config.get("api_key"), - openai_proxy=config.get("openai_proxy"), - temperature=temperature, - max_tokens=max_tokens, - echo=echo, - ) - elif model_name == "Anthropic": - model = Anthropic( - streaming=streaming, - verbose=verbose, - callbacks=callbacks, - model_name=config.get("model_name"), - anthropic_api_key=config.get("api_key"), - echo=echo, - ) - ## TODO 支持其他的Langchain原生支持的模型 - else: - ## 非Langchain原生支持的模型,走Fschat封装 - config = 
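The MiniMax and QianFan embedding changes above fix the same bug: the old "for texts in params.texts[i:i+10]" iterated over the individual strings of the first ten texts instead of walking the whole list in batches, so most inputs were never embedded. The new while-loop advances i by batch_size until every text has been sent. As a standalone pattern it looks like this (embed_in_batches and embed_batch are hypothetical stand-ins, with embed_batch representing one POST to the provider):

    def embed_in_batches(texts, embed_batch, batch_size=10):
        """Send `texts` to the embedding API `batch_size` items at a time."""
        result = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            result += embed_batch(batch)  # one request per batch; vectors are appended in input order
        return result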
get_model_worker_config(model_name) - model = OpenAI( - streaming=streaming, - verbose=verbose, - callbacks=callbacks, - openai_api_key=config.get("api_key", "EMPTY"), - openai_api_base=config.get("api_base_url", fschat_openai_api_address()), - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - openai_proxy=config.get("openai_proxy"), - echo=echo, - **kwargs - ) - + config = get_model_worker_config(model_name) + if model_name == "openai-api": + model_name = config.get("model_name") + model = OpenAI( + streaming=streaming, + verbose=verbose, + callbacks=callbacks, + openai_api_key=config.get("api_key", "EMPTY"), + openai_api_base=config.get("api_base_url", fschat_openai_api_address()), + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + openai_proxy=config.get("openai_proxy"), + echo=echo, + **kwargs + ) return model @@ -630,7 +582,7 @@ def get_httpx_client( for host in os.environ.get("no_proxy", "").split(","): if host := host.strip(): # default_proxies.update({host: None}) # Origin code - default_proxies.update({'all://' + host: None}) # PR 1838 fix, if not add 'all://', httpx will raise error + default_proxies.update({'all://' + host: None}) # PR 1838 fix, if not add 'all://', httpx will raise error # merge default proxies with user provided proxies if isinstance(proxies, str): @@ -714,7 +666,7 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]: from configs.basic_config import BASE_TEMP_DIR import tempfile - if id is not None: # 如果指定的临时目录已存在,直接返回 + if id is not None: # 如果指定的临时目录已存在,直接返回 path = os.path.join(BASE_TEMP_DIR, id) if os.path.isdir(path): return path, id diff --git a/startup.py b/startup.py index e610cda1..3bb508a0 100644 --- a/startup.py +++ b/startup.py @@ -36,6 +36,7 @@ from server.utils import (fschat_controller_address, fschat_model_worker_address fschat_openai_api_address, set_httpx_config, get_httpx_client, get_model_worker_config, get_all_model_worker_configs, MakeFastAPIOffline, FastAPI, llm_device, embedding_device) +from server.knowledge_base.migrate import create_tables import argparse from typing import Tuple, List, Dict from configs import VERSION @@ -866,6 +867,8 @@ async def start_main_server(): logger.info("Process status: %s", p) if __name__ == "__main__": + # 确保数据库表被创建 + create_tables() if sys.version_info < (3, 10): loop = asyncio.get_event_loop() diff --git a/tests/api/test_kb_summary_api.py b/tests/api/test_kb_summary_api.py new file mode 100644 index 00000000..d59c2036 --- /dev/null +++ b/tests/api/test_kb_summary_api.py @@ -0,0 +1,44 @@ +import requests +import json +import sys +from pathlib import Path + +root_path = Path(__file__).parent.parent.parent +sys.path.append(str(root_path)) +from server.utils import api_address + +api_base_url = api_address() + +kb = "samples" +file_name = "/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/knowledge_base/samples/content/llm/大模型技术栈-实战与应用.md" +doc_ids = [ + "357d580f-fdf7-495c-b58b-595a398284e8", + "c7338773-2e83-4671-b237-1ad20335b0f0", + "6da613d1-327d-466f-8c1a-b32e6f461f47" +] + + +def test_summary_file_to_vector_store(api="/knowledge_base/kb_summary_api/summary_file_to_vector_store"): + url = api_base_url + api + print("\n文件摘要:") + r = requests.post(url, json={"knowledge_base_name": kb, + "file_name": file_name + }, stream=True) + for chunk in r.iter_content(None): + data = json.loads(chunk) + assert isinstance(data, dict) + assert data["code"] == 200 + print(data["msg"]) + + +def 
test_summary_doc_ids_to_vector_store(api="/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store"): + url = api_base_url + api + print("\n文件摘要:") + r = requests.post(url, json={"knowledge_base_name": kb, + "doc_ids": doc_ids + }, stream=True) + for chunk in r.iter_content(None): + data = json.loads(chunk) + assert isinstance(data, dict) + assert data["code"] == 200 + print(data) diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index 3310fb3d..064c8a42 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -1,8 +1,11 @@ import streamlit as st from webui_pages.utils import * from streamlit_chatbox import * +from streamlit_modal import Modal from datetime import datetime import os +import re +import time from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES, DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL) from server.knowledge_base.utils import LOADER_DICT @@ -47,6 +50,53 @@ def upload_temp_docs(files, _api: ApiRequest) -> str: return _api.upload_temp_docs(files).get("data", {}).get("id") +def parse_command(text: str, modal: Modal) -> bool: + ''' + 检查用户是否输入了自定义命令,当前支持: + /new {session_name}。如果未提供名称,默认为“会话X” + /del {session_name}。如果未提供名称,在会话数量>1的情况下,删除当前会话。 + /clear {session_name}。如果未提供名称,默认清除当前会话 + /help。查看命令帮助 + 返回值:输入的是命令返回True,否则返回False + ''' + if m := re.match(r"/([^\s]+)\s*(.*)", text): + cmd, name = m.groups() + name = name.strip() + conv_names = chat_box.get_chat_names() + if cmd == "help": + modal.open() + elif cmd == "new": + if not name: + i = 1 + while True: + name = f"会话{i}" + if name not in conv_names: + break + i += 1 + if name in st.session_state["conversation_ids"]: + st.error(f"该会话名称 “{name}” 已存在") + time.sleep(1) + else: + st.session_state["conversation_ids"][name] = uuid.uuid4().hex + st.session_state["cur_conv_name"] = name + elif cmd == "del": + name = name or st.session_state.get("cur_conv_name") + if len(conv_names) == 1: + st.error("这是最后一个会话,无法删除") + time.sleep(1) + elif not name or name not in st.session_state["conversation_ids"]: + st.error(f"无效的会话名称:“{name}”") + time.sleep(1) + else: + st.session_state["conversation_ids"].pop(name, None) + chat_box.del_chat_name(name) + st.session_state["cur_conv_name"] = "" + elif cmd == "clear": + chat_box.reset_history(name=name or None) + return True + return False + + def dialogue_page(api: ApiRequest, is_lite: bool = False): st.session_state.setdefault("conversation_ids", {}) st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex) @@ -60,25 +110,15 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): ) chat_box.init_session() + # 弹出自定义命令帮助信息 + modal = Modal("自定义命令", key="cmd_help", max_width="500") + if modal.is_open(): + with modal.container(): + cmds = [x for x in parse_command.__doc__.split("\n") if x.strip().startswith("/")] + st.write("\n\n".join(cmds)) + with st.sidebar: # 多会话 - cols = st.columns([3, 1]) - conv_name = cols[0].text_input("会话名称") - with cols[1]: - if st.button("添加"): - if not conv_name or conv_name in st.session_state["conversation_ids"]: - st.error("请指定有效的会话名称") - else: - st.session_state["conversation_ids"][conv_name] = uuid.uuid4().hex - st.session_state["cur_conv_name"] = conv_name - st.session_state["conv_name"] = "" - if st.button("删除"): - if not conv_name or conv_name not in st.session_state["conversation_ids"]: - st.error("请指定有效的会话名称") - else: - st.session_state["conversation_ids"].pop(conv_name, None) - st.session_state["cur_conv_name"] = "" - conv_names 
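parse_command above replaces the old sidebar "添加"/"删除" buttons with slash commands typed directly into the chat box. The command and its argument are split by a single regex; a quick worked example (the input string is only an illustration):

    import re

    text = "/new 项目讨论"
    if m := re.match(r"/([^\s]+)\s*(.*)", text):
        cmd, name = m.groups()   # cmd == "new", name == "项目讨论"
    # "/help" parses to cmd == "help", name == "";
    # input without a leading "/" does not match, so parse_command returns False
    # and the text is handled as a normal chat message.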
= list(st.session_state["conversation_ids"].keys()) index = 0 if st.session_state.get("cur_conv_name") in conv_names: @@ -236,7 +276,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): # Display chat messages from history on app rerun chat_box.output_messages() - chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter " + chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 " def on_feedback( feedback, @@ -256,139 +296,142 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): } if prompt := st.chat_input(chat_input_placeholder, key="prompt"): - history = get_messages_history(history_len) - chat_box.user_say(prompt) - if dialogue_mode == "LLM 对话": - chat_box.ai_say("正在思考...") - text = "" - message_id = "" - r = api.chat_chat(prompt, - history=history, - conversation_id=conversation_id, - model=llm_model, - prompt_name=prompt_template_name, - temperature=temperature) - for t in r: - if error_msg := check_error_msg(t): # check whether error occured - st.error(error_msg) - break - text += t.get("text", "") - chat_box.update_msg(text) - message_id = t.get("message_id", "") + if parse_command(text=prompt, modal=modal): # 用户输入自定义命令 + st.rerun() + else: + history = get_messages_history(history_len) + chat_box.user_say(prompt) + if dialogue_mode == "LLM 对话": + chat_box.ai_say("正在思考...") + text = "" + message_id = "" + r = api.chat_chat(prompt, + history=history, + conversation_id=conversation_id, + model=llm_model, + prompt_name=prompt_template_name, + temperature=temperature) + for t in r: + if error_msg := check_error_msg(t): # check whether error occured + st.error(error_msg) + break + text += t.get("text", "") + chat_box.update_msg(text) + message_id = t.get("message_id", "") - metadata = { - "message_id": message_id, - } - chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标 - chat_box.show_feedback(**feedback_kwargs, - key=message_id, - on_submit=on_feedback, - kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1}) + metadata = { + "message_id": message_id, + } + chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标 + chat_box.show_feedback(**feedback_kwargs, + key=message_id, + on_submit=on_feedback, + kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1}) - elif dialogue_mode == "自定义Agent问答": - if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL): + elif dialogue_mode == "自定义Agent问答": + if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL): + chat_box.ai_say([ + f"正在思考... 
\n\n 该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!\n\n\n", + Markdown("...", in_expander=True, title="思考过程", state="complete"), + + ]) + else: + chat_box.ai_say([ + f"正在思考...", + Markdown("...", in_expander=True, title="思考过程", state="complete"), + + ]) + text = "" + ans = "" + for d in api.agent_chat(prompt, + history=history, + model=llm_model, + prompt_name=prompt_template_name, + temperature=temperature, + ): + try: + d = json.loads(d) + except: + pass + if error_msg := check_error_msg(d): # check whether error occured + st.error(error_msg) + if chunk := d.get("answer"): + text += chunk + chat_box.update_msg(text, element_index=1) + if chunk := d.get("final_answer"): + ans += chunk + chat_box.update_msg(ans, element_index=0) + if chunk := d.get("tools"): + text += "\n\n".join(d.get("tools", [])) + chat_box.update_msg(text, element_index=1) + chat_box.update_msg(ans, element_index=0, streaming=False) + chat_box.update_msg(text, element_index=1, streaming=False) + elif dialogue_mode == "知识库问答": chat_box.ai_say([ - f"正在思考... \n\n 该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!\n\n\n", - Markdown("...", in_expander=True, title="思考过程", state="complete"), - + f"正在查询知识库 `{selected_kb}` ...", + Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"), ]) - else: + text = "" + for d in api.knowledge_base_chat(prompt, + knowledge_base_name=selected_kb, + top_k=kb_top_k, + score_threshold=score_threshold, + history=history, + model=llm_model, + prompt_name=prompt_template_name, + temperature=temperature): + if error_msg := check_error_msg(d): # check whether error occured + st.error(error_msg) + elif chunk := d.get("answer"): + text += chunk + chat_box.update_msg(text, element_index=0) + chat_box.update_msg(text, element_index=0, streaming=False) + chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False) + elif dialogue_mode == "文件对话": + if st.session_state["file_chat_id"] is None: + st.error("请先上传文件再进行对话") + st.stop() chat_box.ai_say([ - f"正在思考...", - Markdown("...", in_expander=True, title="思考过程", state="complete"), - + f"正在查询文件 `{st.session_state['file_chat_id']}` ...", + Markdown("...", in_expander=True, title="文件匹配结果", state="complete"), ]) - text = "" - ans = "" - for d in api.agent_chat(prompt, - history=history, - model=llm_model, - prompt_name=prompt_template_name, - temperature=temperature, - ): - try: - d = json.loads(d) - except: - pass - if error_msg := check_error_msg(d): # check whether error occured - st.error(error_msg) - if chunk := d.get("answer"): - text += chunk - chat_box.update_msg(text, element_index=1) - if chunk := d.get("final_answer"): - ans += chunk - chat_box.update_msg(ans, element_index=0) - if chunk := d.get("tools"): - text += "\n\n".join(d.get("tools", [])) - chat_box.update_msg(text, element_index=1) - chat_box.update_msg(ans, element_index=0, streaming=False) - chat_box.update_msg(text, element_index=1, streaming=False) - elif dialogue_mode == "知识库问答": - chat_box.ai_say([ - f"正在查询知识库 `{selected_kb}` ...", - Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"), - ]) - text = "" - for d in api.knowledge_base_chat(prompt, - knowledge_base_name=selected_kb, - top_k=kb_top_k, - score_threshold=score_threshold, - history=history, - model=llm_model, - prompt_name=prompt_template_name, - temperature=temperature): - if error_msg := check_error_msg(d): # check whether error occured - st.error(error_msg) - elif chunk := d.get("answer"): - text += chunk - chat_box.update_msg(text, element_index=0) - chat_box.update_msg(text, 
element_index=0, streaming=False) - chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False) - elif dialogue_mode == "文件对话": - if st.session_state["file_chat_id"] is None: - st.error("请先上传文件再进行对话") - st.stop() - chat_box.ai_say([ - f"正在查询文件 `{st.session_state['file_chat_id']}` ...", - Markdown("...", in_expander=True, title="文件匹配结果", state="complete"), - ]) - text = "" - for d in api.file_chat(prompt, - knowledge_id=st.session_state["file_chat_id"], - top_k=kb_top_k, - score_threshold=score_threshold, - history=history, - model=llm_model, - prompt_name=prompt_template_name, - temperature=temperature): - if error_msg := check_error_msg(d): # check whether error occured - st.error(error_msg) - elif chunk := d.get("answer"): - text += chunk - chat_box.update_msg(text, element_index=0) - chat_box.update_msg(text, element_index=0, streaming=False) - chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False) - elif dialogue_mode == "搜索引擎问答": - chat_box.ai_say([ - f"正在执行 `{search_engine}` 搜索...", - Markdown("...", in_expander=True, title="网络搜索结果", state="complete"), - ]) - text = "" - for d in api.search_engine_chat(prompt, - search_engine_name=search_engine, - top_k=se_top_k, - history=history, - model=llm_model, - prompt_name=prompt_template_name, - temperature=temperature, - split_result=se_top_k > 1): - if error_msg := check_error_msg(d): # check whether error occured - st.error(error_msg) - elif chunk := d.get("answer"): - text += chunk - chat_box.update_msg(text, element_index=0) - chat_box.update_msg(text, element_index=0, streaming=False) - chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False) + text = "" + for d in api.file_chat(prompt, + knowledge_id=st.session_state["file_chat_id"], + top_k=kb_top_k, + score_threshold=score_threshold, + history=history, + model=llm_model, + prompt_name=prompt_template_name, + temperature=temperature): + if error_msg := check_error_msg(d): # check whether error occured + st.error(error_msg) + elif chunk := d.get("answer"): + text += chunk + chat_box.update_msg(text, element_index=0) + chat_box.update_msg(text, element_index=0, streaming=False) + chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False) + elif dialogue_mode == "搜索引擎问答": + chat_box.ai_say([ + f"正在执行 `{search_engine}` 搜索...", + Markdown("...", in_expander=True, title="网络搜索结果", state="complete"), + ]) + text = "" + for d in api.search_engine_chat(prompt, + search_engine_name=search_engine, + top_k=se_top_k, + history=history, + model=llm_model, + prompt_name=prompt_template_name, + temperature=temperature, + split_result=se_top_k > 1): + if error_msg := check_error_msg(d): # check whether error occured + st.error(error_msg) + elif chunk := d.get("answer"): + text += chunk + chat_box.update_msg(text, element_index=0) + chat_box.update_msg(text, element_index=0, streaming=False) + chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False) if st.session_state.get("need_rerun"): st.session_state["need_rerun"] = False diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index 57d6202d..53484917 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -34,14 +34,19 @@ def config_aggrid( use_checkbox=use_checkbox, # pre_selected_rows=st.session_state.get("selected_rows", [0]), ) + gb.configure_pagination( + enabled=True, + paginationAutoPageSize=False, 
+ paginationPageSize=10 + ) return gb def file_exists(kb: str, selected_rows: List) -> Tuple[str, str]: - ''' + """ check whether a doc file exists in local knowledge base folder. return the file's name and path if it exists. - ''' + """ if selected_rows: file_name = selected_rows[0]["file_name"] file_path = get_file_path(kb, file_name) diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 64401447..75e14199 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -85,7 +85,7 @@ class ApiRequest: ) -> Union[httpx.Response, Iterator[httpx.Response], None]: while retry > 0: try: - print(kwargs) + # print(kwargs) if stream: return self.client.stream("POST", url, data=data, json=json, **kwargs) else: @@ -134,7 +134,7 @@ class ApiRequest: if as_json: try: data = json.loads(chunk) - pprint(data, depth=1) + # pprint(data, depth=1) yield data except Exception as e: msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。" @@ -166,7 +166,7 @@ class ApiRequest: if as_json: try: data = json.loads(chunk) - pprint(data, depth=1) + # pprint(data, depth=1) yield data except Exception as e: msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。" @@ -276,8 +276,8 @@ class ApiRequest: "max_tokens": max_tokens, } - print(f"received input message:") - pprint(data) + # print(f"received input message:") + # pprint(data) response = self.post( "/chat/fastchat", @@ -288,23 +288,25 @@ class ApiRequest: return self._httpx_stream2generator(response) def chat_chat( - self, - query: str, - conversation_id: str = None, - history: List[Dict] = [], - stream: bool = True, - model: str = LLM_MODELS[0], - temperature: float = TEMPERATURE, - max_tokens: int = None, - prompt_name: str = "default", - **kwargs, + self, + query: str, + conversation_id: str = None, + history_len: int = -1, + history: List[Dict] = [], + stream: bool = True, + model: str = LLM_MODELS[0], + temperature: float = TEMPERATURE, + max_tokens: int = None, + prompt_name: str = "default", + **kwargs, ): ''' - 对应api.py/chat/chat接口 #TODO: 考虑是否返回json + 对应api.py/chat/chat接口 ''' data = { "query": query, "conversation_id": conversation_id, + "history_len": history_len, "history": history, "stream": stream, "model_name": model, @@ -313,8 +315,8 @@ class ApiRequest: "prompt_name": prompt_name, } - print(f"received input message:") - pprint(data) + # print(f"received input message:") + # pprint(data) response = self.post("/chat/chat", json=data, stream=True, **kwargs) return self._httpx_stream2generator(response, as_json=True) @@ -342,8 +344,8 @@ class ApiRequest: "prompt_name": prompt_name, } - print(f"received input message:") - pprint(data) + # print(f"received input message:") + # pprint(data) response = self.post("/chat/agent_chat", json=data, stream=True) return self._httpx_stream2generator(response) @@ -377,8 +379,8 @@ class ApiRequest: "prompt_name": prompt_name, } - print(f"received input message:") - pprint(data) + # print(f"received input message:") + # pprint(data) response = self.post( "/chat/knowledge_base_chat", @@ -452,8 +454,8 @@ class ApiRequest: "prompt_name": prompt_name, } - print(f"received input message:") - pprint(data) + # print(f"received input message:") + # pprint(data) response = self.post( "/chat/file_chat", @@ -491,8 +493,8 @@ class ApiRequest: "split_result": split_result, } - print(f"received input message:") - pprint(data) + # print(f"received input message:") + # pprint(data) response = self.post( "/chat/search_engine_chat",
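chat_chat also gains a history_len field that is forwarded as-is to the /chat/chat API next to the raw history. A minimal caller sketch, assuming ApiRequest is constructed with the API address the way the web UI does; model name and history_len are example values only:

    from server.utils import api_address
    from webui_pages.utils import ApiRequest

    api = ApiRequest(base_url=api_address())
    for d in api.chat_chat("你好", conversation_id=None, history_len=5, model="chatglm3-6b"):
        # chat_chat streams parsed JSON dicts; answer chunks carry "text" and "message_id"
        print(d.get("text", ""), end="")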