diff --git a/README.md b/README.md
index 2883636e..67aa6ef8 100644
--- a/README.md
+++ b/README.md
@@ -148,7 +148,7 @@ $ python startup.py -a
[](https://t.me/+RjliQ3jnJ1YyN2E9)
### 项目交流群
-
+
🎉 Langchain-Chatchat 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
diff --git a/configs/__init__.py b/configs/__init__.py
index a4bf7665..0c862507 100644
--- a/configs/__init__.py
+++ b/configs/__init__.py
@@ -5,4 +5,4 @@ from .server_config import *
from .prompt_config import *
-VERSION = "v0.2.8-preview"
+VERSION = "v0.2.9-preview"
diff --git a/configs/kb_config.py.example b/configs/kb_config.py.example
index 9c727b77..c9ff4318 100644
--- a/configs/kb_config.py.example
+++ b/configs/kb_config.py.example
@@ -106,7 +106,7 @@ kbs_config = {
# TextSplitter配置项,如果你不明白其中的含义,就不要修改。
text_splitter_dict = {
"ChineseRecursiveTextSplitter": {
- "source": "huggingface", ## 选择tiktoken则使用openai的方法
+ "source": "huggingface", # 选择tiktoken则使用openai的方法
"tokenizer_name_or_path": "",
},
"SpacyTextSplitter": {
diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index e08b5a00..19084142 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -15,9 +15,11 @@ EMBEDDING_DEVICE = "auto"
EMBEDDING_KEYWORD_FILE = "keywords.txt"
EMBEDDING_MODEL_OUTPUT_PATH = "output"
-# 要运行的 LLM 名称,可以包括本地模型和在线模型。
-# 第一个将作为 API 和 WEBUI 的默认模型
-LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"]
+# 要运行的 LLM 名称,可以包括本地模型和在线模型。列表中本地模型将在启动项目时全部加载。
+# 列表中第一个模型将作为 API 和 WEBUI 的默认模型。
+# 在这里,我们使用目前主流的两个离线模型,其中,chatglm3-6b 为默认加载模型。
+# 如果你的显存不足,可使用 Qwen-1_8B-Chat, 该模型 FP16 仅需 3.8G显存。
+LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"] # "Qwen-1_8B-Chat",
# AgentLM模型的名称 (可以不指定,指定之后就锁定进入Agent之后的Chain的模型,不指定就是LLM_MODELS[0])
Agent_MODEL = None
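
A minimal sketch of the configuration described in the comments above, assuming the Qwen/Qwen-1_8B-Chat weights are available: the first entry of LLM_MODELS becomes the default model for the API and WebUI.

```python
# Sketch: swap the default local model to Qwen-1_8B-Chat when GPU memory is limited.
LLM_MODELS = ["Qwen-1_8B-Chat", "zhipu-api", "openai-api"]  # first entry is the default
```
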
@@ -112,10 +114,10 @@ ONLINE_LLM_MODEL = {
"api_key": "",
"provider": "AzureWorker",
},
-
+
# 昆仑万维天工 API https://model-platform.tiangong.cn/
"tiangong-api": {
- "version":"SkyChat-MegaVerse",
+ "version": "SkyChat-MegaVerse",
"api_key": "",
"secret_key": "",
"provider": "TianGongWorker",
@@ -163,6 +165,25 @@ MODEL_PATH = {
"chatglm3-6b": "THUDM/chatglm3-6b",
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
+ "chatglm3-6b-base": "THUDM/chatglm3-6b-base",
+
+ "Qwen-1_8B": "Qwen/Qwen-1_8B",
+ "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
+ "Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8",
+ "Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4",
+
+ "Qwen-7B": "Qwen/Qwen-7B",
+ "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
+
+ "Qwen-14B": "Qwen/Qwen-14B",
+ "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
+ "Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8",
+ "Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4",
+
+ "Qwen-72B": "Qwen/Qwen-72B",
+ "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
+ "Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8",
+ "Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4",
"baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat",
"baichuan2-7b": "baichuan-inc/Baichuan2-7B-Chat",
@@ -204,18 +225,11 @@ MODEL_PATH = {
"opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
- "Qwen-7B": "Qwen/Qwen-7B",
- "Qwen-14B": "Qwen/Qwen-14B",
- "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
- "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
- "Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8", # 确保已经安装了auto-gptq optimum flash-attn
- "Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4", # 确保已经安装了auto-gptq optimum flash-attn
-
"agentlm-7b": "THUDM/agentlm-7b",
"agentlm-13b": "THUDM/agentlm-13b",
"agentlm-70b": "THUDM/agentlm-70b",
- "Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat", # 更多01-ai模型尚未进行测试。如果需要使用,请自行测试。
+ "Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat",
},
}
@@ -242,11 +256,11 @@ VLLM_MODEL_DICT = {
"BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k",
# 注意:bloom系列的tokenizer与model是分离的,因此虽然vllm支持,但与fschat框架不兼容
- # "bloom":"bigscience/bloom",
- # "bloomz":"bigscience/bloomz",
- # "bloomz-560m":"bigscience/bloomz-560m",
- # "bloomz-7b1":"bigscience/bloomz-7b1",
- # "bloomz-1b7":"bigscience/bloomz-1b7",
+ # "bloom": "bigscience/bloom",
+ # "bloomz": "bigscience/bloomz",
+ # "bloomz-560m": "bigscience/bloomz-560m",
+ # "bloomz-7b1": "bigscience/bloomz-7b1",
+ # "bloomz-1b7": "bigscience/bloomz-1b7",
"internlm-7b": "internlm/internlm-7b",
"internlm-chat-7b": "internlm/internlm-chat-7b",
@@ -273,10 +287,23 @@ VLLM_MODEL_DICT = {
"opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
+ "Qwen-1_8B": "Qwen/Qwen-1_8B",
+ "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
+ "Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8",
+ "Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4",
+
"Qwen-7B": "Qwen/Qwen-7B",
- "Qwen-14B": "Qwen/Qwen-14B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
+
+ "Qwen-14B": "Qwen/Qwen-14B",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
+ "Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8",
+ "Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4",
+
+ "Qwen-72B": "Qwen/Qwen-72B",
+ "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
+ "Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8",
+ "Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4",
"agentlm-7b": "THUDM/agentlm-7b",
"agentlm-13b": "THUDM/agentlm-13b",
diff --git a/configs/prompt_config.py.example b/configs/prompt_config.py.example
index dd0dc2a0..6fb6996c 100644
--- a/configs/prompt_config.py.example
+++ b/configs/prompt_config.py.example
@@ -1,7 +1,6 @@
# prompt模板使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号
# 本配置文件支持热加载,修改prompt模板后无需重启服务。
-
# LLM对话支持的变量:
# - input: 用户输入内容
@@ -17,125 +16,112 @@
# - input: 用户输入内容
# - agent_scratchpad: Agent的思维记录
-PROMPT_TEMPLATES = {}
+PROMPT_TEMPLATES = {
+ "llm_chat": {
+ "default":
+ '{{ input }}',
-PROMPT_TEMPLATES["llm_chat"] = {
-"default": "{{ input }}",
-"with_history":
-"""The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.
+ "with_history":
+ 'The following is a friendly conversation between a human and an AI. '
+ 'The AI is talkative and provides lots of specific details from its context. '
+ 'If the AI does not know the answer to a question, it truthfully says it does not know.\n\n'
+ 'Current conversation:\n'
+ '{history}\n'
+ 'Human: {input}\n'
+ 'AI:',
-Current conversation:
-{history}
-Human: {input}
-AI:""",
-"py":
-"""
-你是一个聪明的代码助手,请你给我写出简单的py代码。 \n
-{{ input }}
-""",
-}
-
-PROMPT_TEMPLATES["knowledge_base_chat"] = {
-"default":
-"""
-<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
-<已知信息>{{ context }}</已知信息>、
-<问题>{{ question }}</问题>
-""",
-"text":
-"""
-<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>
-<已知信息>{{ context }}</已知信息>、
-<问题>{{ question }}</问题>
-""",
-"Empty": # 搜不到知识库的时候使用
-"""
-请你回答我的问题:
-{{ question }}
-\n
-""",
-}
-PROMPT_TEMPLATES["search_engine_chat"] = {
-"default":
-"""
-<指令>这是我搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 </指令>
-<已知信息>{{ context }}</已知信息>
-<问题>{{ question }}</问题>
-""",
-"search":
-"""
-<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>
-<已知信息>{{ context }}</已知信息>、
-<问题>{{ question }}</问题>
-""",
-}
-PROMPT_TEMPLATES["agent_chat"] = {
-"default":
-"""
-Answer the following questions as best you can. If it is in order, you can use some tools appropriately.You have access to the following tools:
-
-{tools}
-
-Please note that the "知识库查询工具" is information about the "西交利物浦大学" ,and if a question is asked about it, you must answer with the knowledge base,
-Please note that the "天气查询工具" can only be used once since Question begin.
-
-Use the following format:
-Question: the input question you must answer1
-Thought: you should always think about what to do and what tools to use.
-Action: the action to take, should be one of [{tool_names}]
-Action Input: the input to the action
-Observation: the result of the action
-... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
-Thought: I now know the final answer
-Final Answer: the final answer to the original input question
-Begin!
-
-history: {history}
-
-Question: {input}
-
-Thought: {agent_scratchpad}
-""",
-
-"ChatGLM3":
-"""
-You can answer using the tools, or answer directly using your knowledge without using the tools.Respond to the human as helpfully and accurately as possible.
-You have access to the following tools:
-{tools}
-Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
-Valid "action" values: "Final Answer" or [{tool_names}]
-Provide only ONE action per $JSON_BLOB, as shown:
-
-```
-{{{{
- "action": $TOOL_NAME,
- "action_input": $INPUT
-}}}}
-```
-
-Follow this format:
-
-Question: input question to answer
-Thought: consider previous and subsequent steps
-Action:
-```
-$JSON_BLOB
-```
-Observation: action result
-... (repeat Thought/Action/Observation N times)
-Thought: I know what to respond
-Action:
-```
-{{{{
- "action": "Final Answer",
- "action_input": "Final response to human"
-}}}}
-Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
-
-history: {history}
-
-Question: {input}
-
-Thought: {agent_scratchpad}
-""",
+ "py":
+ '你是一个聪明的代码助手,请你给我写出简单的py代码。 \n'
+ '{{ input }}',
+ },
+
+
+ "knowledge_base_chat": {
+ "default":
+ '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,'
+            '不允许在答案中添加编造成分,答案请使用中文。 </指令>\n'
+            '<已知信息>{{ context }}</已知信息>\n'
+            '<问题>{{ question }}</问题>\n',
+
+ "text":
+            '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>\n'
+            '<已知信息>{{ context }}</已知信息>\n'
+            '<问题>{{ question }}</问题>\n',
+
+ "empty": # 搜不到知识库的时候使用
+ '请你回答我的问题:\n'
+ '{{ question }}\n\n',
+ },
+
+
+ "search_engine_chat": {
+ "default":
+            '<指令>这是我搜索到的互联网信息,请你根据这些信息进行提取并有条理、简洁地回答问题。'
+            '如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 </指令>\n'
+            '<已知信息>{{ context }}</已知信息>\n'
+            '<问题>{{ question }}</问题>\n',
+
+ "search":
+            '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>\n'
+            '<已知信息>{{ context }}</已知信息>\n'
+            '<问题>{{ question }}</问题>\n',
+ },
+
+
+ "agent_chat": {
+ "default":
+ 'Answer the following questions as best you can. If it is in order, you can use some tools appropriately. '
+ 'You have access to the following tools:\n\n'
+ '{tools}\n\n'
+ 'Use the following format:\n'
+ 'Question: the input question you must answer1\n'
+ 'Thought: you should always think about what to do and what tools to use.\n'
+ 'Action: the action to take, should be one of [{tool_names}]\n'
+ 'Action Input: the input to the action\n'
+ 'Observation: the result of the action\n'
+ '... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n'
+ 'Thought: I now know the final answer\n'
+ 'Final Answer: the final answer to the original input question\n'
+ 'Begin!\n\n'
+ 'history: {history}\n\n'
+ 'Question: {input}\n\n'
+ 'Thought: {agent_scratchpad}\n',
+
+ "ChatGLM3":
+ 'You can answer using the tools, or answer directly using your knowledge without using the tools. '
+ 'Respond to the human as helpfully and accurately as possible.\n'
+ 'You have access to the following tools:\n'
+ '{tools}\n'
+ 'Use a json blob to specify a tool by providing an action key (tool name) '
+ 'and an action_input key (tool input).\n'
+            'Valid "action" values: "Final Answer" or [{tool_names}]\n'
+ 'Provide only ONE action per $JSON_BLOB, as shown:\n\n'
+ '```\n'
+ '{{{{\n'
+ ' "action": $TOOL_NAME,\n'
+ ' "action_input": $INPUT\n'
+ '}}}}\n'
+ '```\n\n'
+ 'Follow this format:\n\n'
+ 'Question: input question to answer\n'
+ 'Thought: consider previous and subsequent steps\n'
+ 'Action:\n'
+ '```\n'
+ '$JSON_BLOB\n'
+ '```\n'
+ 'Observation: action result\n'
+ '... (repeat Thought/Action/Observation N times)\n'
+ 'Thought: I know what to respond\n'
+ 'Action:\n'
+ '```\n'
+ '{{{{\n'
+ ' "action": "Final Answer",\n'
+ ' "action_input": "Final response to human"\n'
+ '}}}}\n'
+ 'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. '
+ 'Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.\n'
+ 'history: {history}\n\n'
+ 'Question: {input}\n\n'
+ 'Thought: {agent_scratchpad}',
+ }
}
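
Since the file header notes that these templates use Jinja2 syntax (double braces instead of f-string single braces), here is a minimal rendering sketch; the sample values are placeholders and jinja2 is assumed to be installed (it ships with langchain).

```python
from jinja2 import Template

template = (
    '<指令>根据已知信息,简洁和专业的来回答问题。</指令>\n'
    '<已知信息>{{ context }}</已知信息>\n'
    '<问题>{{ question }}</问题>\n'
)
# Double-brace variables are filled in at render time, not at string-definition time.
print(Template(template).render(context="检索到的文档内容", question="用户的问题"))
```
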
diff --git a/configs/server_config.py.example b/configs/server_config.py.example
index 2e27c400..7fa0c412 100644
--- a/configs/server_config.py.example
+++ b/configs/server_config.py.example
@@ -66,7 +66,7 @@ FSCHAT_MODEL_WORKERS = {
# "no_register": False,
# "embed_in_truncate": False,
- # 以下为vllm_woker配置参数,注意使用vllm必须有gpu,仅在Linux测试通过
+ # 以下为vllm_worker配置参数,注意使用vllm必须有gpu,仅在Linux测试通过
# tokenizer = model_path # 如果tokenizer与model_path不一致在此处添加
# 'tokenizer_mode':'auto',
@@ -93,14 +93,14 @@ FSCHAT_MODEL_WORKERS = {
},
# 可以如下示例方式更改默认配置
- # "Qwen-7B-Chat": { # 使用default中的IP和端口
+ # "Qwen-1_8B-Chat": { # 使用default中的IP和端口
# "device": "cpu",
# },
- "chatglm3-6b": { # 使用default中的IP和端口
+ "chatglm3-6b": { # 使用default中的IP和端口
"device": "cuda",
},
- #以下配置可以不用修改,在model_config中设置启动的模型
+ # 以下配置可以不用修改,在model_config中设置启动的模型
"zhipu-api": {
"port": 21001,
},
diff --git a/document_loaders/myimgloader.py b/document_loaders/myimgloader.py
index 86481924..e09c6172 100644
--- a/document_loaders/myimgloader.py
+++ b/document_loaders/myimgloader.py
@@ -1,13 +1,13 @@
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader
+from document_loaders.ocr import get_ocr
class RapidOCRLoader(UnstructuredFileLoader):
def _get_elements(self) -> List:
def img2text(filepath):
- from rapidocr_onnxruntime import RapidOCR
resp = ""
- ocr = RapidOCR()
+ ocr = get_ocr()
result, _ = ocr(filepath)
if result:
ocr_result = [line[1] for line in result]
diff --git a/document_loaders/mypdfloader.py b/document_loaders/mypdfloader.py
index 6cb77267..51778b89 100644
--- a/document_loaders/mypdfloader.py
+++ b/document_loaders/mypdfloader.py
@@ -1,5 +1,6 @@
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader
+from document_loaders.ocr import get_ocr
import tqdm
@@ -7,9 +8,8 @@ class RapidOCRPDFLoader(UnstructuredFileLoader):
def _get_elements(self) -> List:
def pdf2text(filepath):
import fitz # pyMuPDF里面的fitz包,不要与pip install fitz混淆
- from rapidocr_onnxruntime import RapidOCR
import numpy as np
- ocr = RapidOCR()
+ ocr = get_ocr()
doc = fitz.open(filepath)
resp = ""
diff --git a/document_loaders/ocr.py b/document_loaders/ocr.py
new file mode 100644
index 00000000..2b66dd35
--- /dev/null
+++ b/document_loaders/ocr.py
@@ -0,0 +1,18 @@
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+ try:
+ from rapidocr_paddle import RapidOCR
+ except ImportError:
+ from rapidocr_onnxruntime import RapidOCR
+
+
+def get_ocr(use_cuda: bool = True) -> "RapidOCR":
+ try:
+ from rapidocr_paddle import RapidOCR
+ ocr = RapidOCR(det_use_cuda=use_cuda, cls_use_cuda=use_cuda, rec_use_cuda=use_cuda)
+ except ImportError:
+ from rapidocr_onnxruntime import RapidOCR
+ ocr = RapidOCR()
+ return ocr
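
A short usage sketch for the new helper, assuming rapidocr_onnxruntime is installed (rapidocr_paddle[gpu] is optional); the image path is a placeholder.

```python
from document_loaders.ocr import get_ocr

ocr = get_ocr(use_cuda=False)        # rapidocr_paddle with GPU if present, else rapidocr_onnxruntime
result, _ = ocr("sample_page.png")   # placeholder image path; second return value is timing info
if result:
    print("\n".join(line[1] for line in result))  # each line is (box, text, score)
```
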
diff --git a/img/qr_code_71.jpg b/img/qr_code_71.jpg
deleted file mode 100644
index 78828f7b..00000000
Binary files a/img/qr_code_71.jpg and /dev/null differ
diff --git a/img/qr_code_72.jpg b/img/qr_code_72.jpg
deleted file mode 100644
index 10a504b0..00000000
Binary files a/img/qr_code_72.jpg and /dev/null differ
diff --git a/img/qr_code_73.jpg b/img/qr_code_73.jpg
deleted file mode 100644
index 3e5ec45a..00000000
Binary files a/img/qr_code_73.jpg and /dev/null differ
diff --git a/img/qr_code_74.jpg b/img/qr_code_74.jpg
deleted file mode 100644
index 1f665f65..00000000
Binary files a/img/qr_code_74.jpg and /dev/null differ
diff --git a/img/qr_code_76.jpg b/img/qr_code_76.jpg
new file mode 100644
index 00000000..d952dc19
Binary files /dev/null and b/img/qr_code_76.jpg differ
diff --git a/knowledge_base/samples/content/wiki b/knowledge_base/samples/content/wiki
index f789e5dd..9a3fa7a7 160000
--- a/knowledge_base/samples/content/wiki
+++ b/knowledge_base/samples/content/wiki
@@ -1 +1 @@
-Subproject commit f789e5dde10f91136012f3470c020c8d34572436
+Subproject commit 9a3fa7a77f8748748b1c656fe8919ad5c4c63e3f
diff --git a/requirements.txt b/requirements.txt
index 036489e8..b5cd996d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,63 +1,75 @@
# API requirements
-langchain>=0.0.334
+langchain==0.0.344
langchain-experimental>=0.0.42
-fschat[model_worker]>=0.2.33
+pydantic==1.10.13
+fschat>=0.2.33
xformers>=0.0.22.post7
-openai~=0.28.1
+openai>=1.3.7
sentence_transformers
transformers>=4.35.2
-torch==2.1.0 ##on win, install the cuda version manually if you want use gpu
-torchvision #on win, install the cuda version manually if you want use gpu
-torchaudio #on win, install the cuda version manually if you want use gpu
+torch==2.1.0  # on Windows, install the CUDA version manually from https://pytorch.org/
+torchvision  # on Windows, install the CUDA version manually from https://pytorch.org/
+torchaudio  # on Windows, install the CUDA version manually from https://pytorch.org/
fastapi>=0.104
nltk>=3.8.1
-uvicorn~=0.23.1
+uvicorn>=0.24.0.post1
starlette~=0.27.0
-pydantic<2
-unstructured[all-docs]>=0.11.0
+unstructured[all-docs]==0.11.0
python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19
-faiss-cpu
-accelerate
-spacy
+faiss-cpu # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus
+accelerate>=0.24.1
+spacy>=3.7.2
PyMuPDF
rapidocr_onnxruntime
-
-requests
-pathlib
-pytest
-numexpr
-strsimpy
-markdownify
-tiktoken
-tqdm
+requests>=2.31.0
+pathlib>=1.0.1
+pytest>=7.4.3
+numexpr>=2.8.7
+strsimpy>=0.2.1
+markdownify>=0.11.6
+tiktoken>=0.5.1
+tqdm>=4.66.1
websockets
numpy~=1.24.4
pandas~=2.0.3
-einops
+einops>=0.7.0
transformers_stream_generator==0.0.4
-
vllm==0.2.2; sys_platform == "linux"
-# online api libs dependencies
+# optional document loaders
+# rapidocr_paddle[gpu]  # GPU acceleration for OCR of PDF and image files
+# jq  # for .json and .jsonl files; on Windows, `conda install jq` is recommended
+# html2text # for .enex files
+# beautifulsoup4 # for .mhtml files
+# pysrt # for .srt files
+
+# Online api libs dependencies
# zhipuai>=1.0.7
# dashscope>=1.10.0
# qianfan>=0.2.0
# volcengine>=1.0.106
-
-# uncomment libs if you want to use corresponding vector store
-# pymilvus==2.1.3 # requires milvus==2.1.3
+# pymilvus>=2.3.3
# psycopg2
-# pgvector
+# pgvector>=0.2.4
+
+# Agent and Search Tools
+
+arxiv>=2.0.0
+youtube-search>=2.1.2
+duckduckgo-search>=3.9.9
+metaphor-python>=0.1.23
# WebUI requirements
-streamlit~=1.28.2 # # on win, make sure write its path in environment variable
+streamlit>=1.29.0  # on Windows, make sure the Streamlit install location is added to the PATH environment variable
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.2.3
streamlit-chatbox>=1.1.11
+streamlit-modal>=0.1.0
streamlit-aggrid>=0.3.4.post3
-httpx[brotli,http2,socks]~=0.24.1
-watchdog
+httpx[brotli,http2,socks]>=0.25.2
+watchdog>=3.0.0
+
diff --git a/requirements_api.txt b/requirements_api.txt
index 85b23072..ec1005f9 100644
--- a/requirements_api.txt
+++ b/requirements_api.txt
@@ -1,52 +1,65 @@
# API requirements
-langchain>=0.0.334
+langchain==0.0.344
langchain-experimental>=0.0.42
-fschat[model_worker]>=0.2.33
+pydantic==1.10.13
+fschat>=0.2.33
xformers>=0.0.22.post7
-openai~=0.28.1
+openai>=1.3.7
sentence_transformers
transformers>=4.35.2
-torch==2.1.0
-torchvision
-torchaudio
+torch==2.1.0  # on Windows, install the CUDA version manually from https://pytorch.org/
+torchvision  # on Windows, install the CUDA version manually from https://pytorch.org/
+torchaudio  # on Windows, install the CUDA version manually from https://pytorch.org/
fastapi>=0.104
nltk>=3.8.1
-uvicorn~=0.23.1
+uvicorn>=0.24.0.post1
starlette~=0.27.0
-pydantic<2
-unstructured[all-docs]>=0.11.0
+unstructured[all-docs]==0.11.0
python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19
faiss-cpu
accelerate>=0.24.1
-spacy
+spacy>=3.7.2
PyMuPDF
rapidocr_onnxruntime
-
-requests
-pathlib
-pytest
-numexpr
-strsimpy
-markdownify
-tiktoken
-tqdm
+requests>=2.31.0
+pathlib>=1.0.1
+pytest>=7.4.3
+numexpr>=2.8.7
+strsimpy>=0.2.1
+markdownify>=0.11.6
+tiktoken>=0.5.1
+tqdm>=4.66.1
websockets
numpy~=1.24.4
pandas~=2.0.3
-einops
-transformers_stream_generator>=0.0.4
+einops>=0.7.0
+transformers_stream_generator==0.0.4
+vllm==0.2.2; sys_platform == "linux"
+httpx[brotli,http2,socks]>=0.25.2
-vllm>=0.2.0; sys_platform == "linux"
-# online api libs
-zhipuai
-dashscope>=1.10.0 # qwen
-qianfan
-# volcengine>=1.0.106 # fangzhou
+# optional document loaders
+# rapidocr_paddle[gpu]  # GPU acceleration for OCR of PDF and image files
+# jq  # for .json and .jsonl files; on Windows, `conda install jq` is recommended
+# html2text # for .enex files
+# beautifulsoup4 # for .mhtml files
+# pysrt # for .srt files
-# uncomment libs if you want to use corresponding vector store
-# pymilvus==2.1.3 # requires milvus==2.1.3
+# Online api libs dependencies
+
+# zhipuai>=1.0.7
+# dashscope>=1.10.0
+# qianfan>=0.2.0
+# volcengine>=1.0.106
+# pymilvus>=2.3.3
# psycopg2
-# pgvector
+# pgvector>=0.2.4
+
+# Agent and Search Tools
+
+arxiv>=2.0.0
+youtube-search>=2.1.2
+duckduckgo-search>=3.9.9
+metaphor-python>=0.1.23
\ No newline at end of file
diff --git a/requirements_lite.txt b/requirements_lite.txt
index 1cfcb81b..ad01376c 100644
--- a/requirements_lite.txt
+++ b/requirements_lite.txt
@@ -1,62 +1,70 @@
-langchain>=0.0.334
-fschat>=0.2.32
-openai
-# sentence_transformers
-# transformers>=4.35.0
-# torch>=2.0.1
-# torchvision
-# torchaudio
+langchain==0.0.344
+pydantic==1.10.13
+fschat>=0.2.33
+openai>=1.3.7
fastapi>=0.104.1
python-multipart
nltk~=3.8.1
-uvicorn~=0.23.1
+uvicorn>=0.24.0.post1
starlette~=0.27.0
-pydantic~=1.10.11
-unstructured[docx,csv]>=0.10.4 # add pdf if need
+unstructured[docx,csv]==0.11.0 # add pdf if need
python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19
+numexpr>=2.8.7
+strsimpy>=0.2.1
+
faiss-cpu
-# accelerate
-# spacy
+# accelerate>=0.24.1
+# spacy>=3.7.2
# PyMuPDF==1.22.5 # install if need pdf
# rapidocr_onnxruntime>=1.3.2 # install if need pdf
+# optional document loaders
+# rapidocr_paddle[gpu]  # GPU acceleration for OCR of PDF and image files
+# jq  # for .json and .jsonl files; on Windows, `conda install jq` is recommended
+# html2text # for .enex files
+# beautifulsoup4 # for .mhtml files
+# pysrt # for .srt files
+
requests
pathlib
pytest
# scikit-learn
# numexpr
-# vllm==0.1.7; sys_platform == "linux"
+# vllm==0.2.2; sys_platform == "linux"
# online api libs
-zhipuai
-dashscope>=1.10.0 # qwen
-# qianfan
+
+zhipuai>=1.0.7 # zhipu
+# dashscope>=1.10.0 # qwen
# volcengine>=1.0.106 # fangzhou
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
# psycopg2
-# pgvector
+# pgvector>=0.2.4
numpy~=1.24.4
pandas~=2.0.3
-streamlit~=1.28.1
+streamlit>=1.29.0
streamlit-option-menu>=0.3.6
-streamlit-antd-components>=0.1.11
-streamlit-chatbox==1.1.11
+streamlit-antd-components>=0.2.3
+streamlit-chatbox>=1.1.11
+streamlit-modal>=0.1.0
streamlit-aggrid>=0.3.4.post3
-httpx~=0.24.1
-watchdog
-tqdm
+httpx[brotli,http2,socks]>=0.25.2
+watchdog>=3.0.0
+tqdm>=4.66.1
websockets
+einops>=0.7.0
+
# tiktoken
-einops
-# scipy
+# scipy>=1.11.4
# transformers_stream_generator==0.0.4
-# search engine libs
-duckduckgo-search
-metaphor-python
-strsimpy
-markdownify
+# Agent and Search Tools
+
+arxiv>=2.0.0
+youtube-search>=2.1.2
+duckduckgo-search>=3.9.9
+metaphor-python>=0.1.23
\ No newline at end of file
diff --git a/requirements_webui.txt b/requirements_webui.txt
index 0fe3dc8e..3c16c32e 100644
--- a/requirements_webui.txt
+++ b/requirements_webui.txt
@@ -1,10 +1,10 @@
# WebUI requirements
-streamlit~=1.28.2
+streamlit>=1.29.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.2.3
streamlit-chatbox>=1.1.11
+streamlit-modal>=0.1.0
streamlit-aggrid>=0.3.4.post3
-httpx[brotli,http2,socks]~=0.24.1
-watchdog
-
+httpx[brotli,http2,socks]>=0.25.2
+watchdog>=3.0.0
diff --git a/server/agent/tools/search_knowledgebase_complex.py b/server/agent/tools/search_knowledgebase_complex.py
index 0b7884bd..af4d9116 100644
--- a/server/agent/tools/search_knowledgebase_complex.py
+++ b/server/agent/tools/search_knowledgebase_complex.py
@@ -170,7 +170,6 @@ class LLMKnowledgeChain(LLMChain):
queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
except:
queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
- print(queries)
run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", verbose=self.verbose)
output = self._evaluate_expression(queries)
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
diff --git a/server/api.py b/server/api.py
index 5fc80c46..7444d4bf 100644
--- a/server/api.py
+++ b/server/api.py
@@ -81,6 +81,8 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
# 知识库相关接口
mount_knowledge_routes(app)
+ # 摘要相关接口
+ mount_filename_summary_routes(app)
# LLM模型相关接口
app.post("/llm_model/list_running_models",
@@ -230,6 +232,26 @@ def mount_knowledge_routes(app: FastAPI):
)(upload_temp_docs)
+def mount_filename_summary_routes(app: FastAPI):
+ from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store,
+ summary_doc_ids_to_vector_store)
+
+ app.post("/knowledge_base/kb_summary_api/summary_file_to_vector_store",
+ tags=["Knowledge kb_summary_api Management"],
+ summary="单个知识库根据文件名称摘要"
+ )(summary_file_to_vector_store)
+ app.post("/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store",
+ tags=["Knowledge kb_summary_api Management"],
+ summary="单个知识库根据doc_ids摘要",
+ response_model=BaseResponse,
+ )(summary_doc_ids_to_vector_store)
+ app.post("/knowledge_base/kb_summary_api/recreate_summary_vector_store",
+ tags=["Knowledge kb_summary_api Management"],
+ summary="重建单个知识库文件摘要"
+ )(recreate_summary_vector_store)
+
+
+
def run_api(host, port, **kwargs):
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
uvicorn.run(app,
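
A hedged client sketch for one of the newly mounted summary routes; the host and port are the project defaults and are assumptions here, and `samples`/`test.pdf` are placeholders. The route streams JSON objects as text/event-stream.

```python
import requests

resp = requests.post(
    "http://127.0.0.1:7861/knowledge_base/kb_summary_api/summary_file_to_vector_store",
    json={"knowledge_base_name": "samples", "file_name": "test.pdf"},
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line.decode("utf-8"))  # one JSON progress object per line
```
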
diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py
index 8ae38db6..51d4d1f6 100644
--- a/server/chat/agent_chat.py
+++ b/server/chat/agent_chat.py
@@ -43,7 +43,11 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
model_name: str = LLM_MODELS[0],
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
+ nonlocal max_tokens
callback = CustomAsyncIteratorCallbackHandler()
+ if isinstance(max_tokens, int) and max_tokens <= 0:
+ max_tokens = None
+
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
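
The same `nonlocal max_tokens` guard is added to every chat endpoint in this change; a standalone sketch of the intent (the helper name is illustrative, not from the repository):

```python
from typing import Optional

def normalize_max_tokens(max_tokens: Optional[int]) -> Optional[int]:
    # Non-positive values mean "no explicit limit": return None so the model's own maximum applies.
    if isinstance(max_tokens, int) and max_tokens <= 0:
        return None
    return max_tokens

assert normalize_max_tokens(0) is None
assert normalize_max_tokens(2048) == 2048
```
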
diff --git a/server/chat/chat.py b/server/chat/chat.py
index acf3ec0c..bac82f8e 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -19,6 +19,7 @@ from server.callback_handler.conversation_callback_handler import ConversationCa
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
conversation_id: str = Body("", description="对话框ID"),
+ history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
history: Union[int, List[History]] = Body([],
description="历史对话,设为一个整数可以从数据库中读取历史消息",
examples=[[
@@ -34,18 +35,20 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
async def chat_iterator() -> AsyncIterable[str]:
- nonlocal history
+ nonlocal history, max_tokens
callback = AsyncIteratorCallbackHandler()
callbacks = [callback]
memory = None
- message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
# 负责保存llm response到message db
+ message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
chat_type="llm_chat",
query=query)
callbacks.append(conversation_callback)
+ if isinstance(max_tokens, int) and max_tokens <= 0:
+ max_tokens = None
model = get_ChatOpenAI(
model_name=model_name,
@@ -54,18 +57,24 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
callbacks=callbacks,
)
- if not conversation_id:
+ if history: # 优先使用前端传入的历史消息
history = [History.from_data(h) for h in history]
prompt_template = get_prompt_template("llm_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
- else:
+ elif conversation_id and history_len > 0: # 前端要求从数据库取历史消息
# 使用memory 时必须 prompt 必须含有memory.memory_key 对应的变量
prompt = get_prompt_template("llm_chat", "with_history")
chat_prompt = PromptTemplate.from_template(prompt)
# 根据conversation_id 获取message 列表进而拼凑 memory
- memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=model)
+ memory = ConversationBufferDBMemory(conversation_id=conversation_id,
+ llm=model,
+ message_limit=history_len)
+ else:
+ prompt_template = get_prompt_template("llm_chat", prompt_name)
+ input_msg = History(role="user", content=prompt_template).to_msg_template(False)
+ chat_prompt = ChatPromptTemplate.from_messages([input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
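
A client-side sketch of the new precedence in `/chat/chat`: an explicit `history` list wins; otherwise `conversation_id` plus a positive `history_len` pulls messages from the database; otherwise only the current query is used. Host, port and the conversation id are assumptions.

```python
import requests

payload = {
    "query": "你好",
    "conversation_id": "demo-conversation-id",  # placeholder id
    "history_len": 3,     # read the last 3 messages for this conversation from the DB
    "history": [],        # left empty so the DB history is used instead
    "stream": False,
}
resp = requests.post("http://127.0.0.1:7861/chat/chat", json=payload)
print(resp.text)
```
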
diff --git a/server/chat/completion.py b/server/chat/completion.py
index ee5e2d12..6e5827dd 100644
--- a/server/chat/completion.py
+++ b/server/chat/completion.py
@@ -6,7 +6,8 @@ from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, Optional
import asyncio
-from langchain.prompts.chat import PromptTemplate
+from langchain.prompts import PromptTemplate
+
from server.utils import get_prompt_template
@@ -27,7 +28,11 @@ async def completion(query: str = Body(..., description="用户输入", examples
prompt_name: str = prompt_name,
echo: bool = echo,
) -> AsyncIterable[str]:
+ nonlocal max_tokens
callback = AsyncIteratorCallbackHandler()
+ if isinstance(max_tokens, int) and max_tokens <= 0:
+ max_tokens = None
+
model = get_OpenAI(
model_name=model_name,
temperature=temperature,
diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py
index ea3475a0..a58cb29b 100644
--- a/server/chat/file_chat.py
+++ b/server/chat/file_chat.py
@@ -113,7 +113,11 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
history = [History.from_data(h) for h in history]
async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
+ nonlocal max_tokens
callback = AsyncIteratorCallbackHandler()
+ if isinstance(max_tokens, int) and max_tokens <= 0:
+ max_tokens = None
+
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
@@ -121,7 +125,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
callbacks=[callback],
)
embed_func = EmbeddingsFunAdapter()
- embeddings = embed_func.embed_query(query)
+ embeddings = await embed_func.aembed_query(query)
with memo_faiss_pool.acquire(knowledge_id) as vs:
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
docs = [x[0] for x in docs]
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 0ea99a6e..a3ab68b1 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -1,5 +1,6 @@
from fastapi import Body, Request
from fastapi.responses import StreamingResponse
+from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
@@ -10,55 +11,74 @@ import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
-from server.knowledge_base.utils import get_doc_path
import json
-from pathlib import Path
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
- knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
- top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
- score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=2),
- history: List[History] = Body([],
- description="历史对话",
- examples=[[
- {"role": "user",
- "content": "我们来玩成语接龙,我先来,生龙活虎"},
- {"role": "assistant",
- "content": "虎头虎脑"}]]
- ),
- stream: bool = Body(False, description="流式输出"),
- model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
- temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
- max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
- prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
- request: Request = None,
- ):
+ knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
+ top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
+ score_threshold: float = Body(
+ SCORE_THRESHOLD,
+ description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右",
+ ge=0,
+ le=2
+ ),
+ history: List[History] = Body(
+ [],
+ description="历史对话",
+ examples=[[
+ {"role": "user",
+ "content": "我们来玩成语接龙,我先来,生龙活虎"},
+ {"role": "assistant",
+ "content": "虎头虎脑"}]]
+ ),
+ stream: bool = Body(False, description="流式输出"),
+ model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
+ temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+ max_tokens: Optional[int] = Body(
+ None,
+ description="限制LLM生成Token数量,默认None代表模型最大值"
+ ),
+ prompt_name: str = Body(
+ "default",
+ description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
+ ),
+ request: Request = None,
+ ):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
history = [History.from_data(h) for h in history]
- async def knowledge_base_chat_iterator(query: str,
- top_k: int,
- history: Optional[List[History]],
- model_name: str = LLM_MODELS[0],
- prompt_name: str = prompt_name,
- ) -> AsyncIterable[str]:
+ async def knowledge_base_chat_iterator(
+ query: str,
+ top_k: int,
+ history: Optional[List[History]],
+ model_name: str = LLM_MODELS[0],
+ prompt_name: str = prompt_name,
+ ) -> AsyncIterable[str]:
+ nonlocal max_tokens
callback = AsyncIteratorCallbackHandler()
+ if isinstance(max_tokens, int) and max_tokens <= 0:
+ max_tokens = None
+
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=[callback],
)
- docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
+ docs = await run_in_threadpool(search_docs,
+ query=query,
+ knowledge_base_name=knowledge_base_name,
+ top_k=top_k,
+ score_threshold=score_threshold)
context = "\n".join([doc.page_content for doc in docs])
- if len(docs) == 0: ## 如果没有找到相关文档,使用Empty模板
- prompt_template = get_prompt_template("knowledge_base_chat", "Empty")
+ if len(docs) == 0: # 如果没有找到相关文档,使用empty模板
+ prompt_template = get_prompt_template("knowledge_base_chat", "empty")
else:
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
@@ -76,14 +96,14 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
source_documents = []
for inum, doc in enumerate(docs):
filename = doc.metadata.get("source")
- parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
+ parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
base_url = request.base_url
url = f"{base_url}knowledge_base/download_doc?" + parameters
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
source_documents.append(text)
- if len(source_documents) == 0: # 没有找到相关文档
- source_documents.append(f"""未找到相关文档,该回答为大模型自身能力解答!""")
+ if len(source_documents) == 0: # 没有找到相关文档
+            source_documents.append("未找到相关文档,该回答为大模型自身能力解答!")
if stream:
async for token in callback.aiter():
@@ -104,4 +124,4 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
history=history,
model_name=model_name,
prompt_name=prompt_name),
- media_type="text/event-stream")
\ No newline at end of file
+ media_type="text/event-stream")
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index 8325b4d9..5b77e993 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -147,7 +147,11 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
model_name: str = LLM_MODELS[0],
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
+ nonlocal max_tokens
callback = AsyncIteratorCallbackHandler()
+ if isinstance(max_tokens, int) and max_tokens <= 0:
+ max_tokens = None
+
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
diff --git a/server/db/models/knowledge_metadata_model.py b/server/db/models/knowledge_metadata_model.py
new file mode 100644
index 00000000..03f42009
--- /dev/null
+++ b/server/db/models/knowledge_metadata_model.py
@@ -0,0 +1,28 @@
+from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func
+
+from server.db.base import Base
+
+
+class SummaryChunkModel(Base):
+ """
+ chunk summary模型,用于存储file_doc中每个doc_id的chunk 片段,
+ 数据来源:
+ 用户输入: 用户上传文件,可填写文件的描述,生成的file_doc中的doc_id,存入summary_chunk中
+ 程序自动切分 对file_doc表meta_data字段信息中存储的页码信息,按每页的页码切分,自定义prompt生成总结文本,将对应页码关联的doc_id存入summary_chunk中
+ 后续任务:
+ 矢量库构建: 对数据库表summary_chunk中summary_context创建索引,构建矢量库,meta_data为矢量库的元数据(doc_ids)
+        语义关联: 通过用户输入的描述、自动切分的总结文本,计算语义相似度
+
+ """
+ __tablename__ = 'summary_chunk'
+ id = Column(Integer, primary_key=True, autoincrement=True, comment='ID')
+ kb_name = Column(String(50), comment='知识库名称')
+ summary_context = Column(String(255), comment='总结文本')
+ summary_id = Column(String(255), comment='总结矢量id')
+ doc_ids = Column(String(1024), comment="向量库id关联列表")
+ meta_data = Column(JSON, default={})
+
+ def __repr__(self):
+        return (f"<SummaryChunk(id='{self.id}', kb_name='{self.kb_name}', summary_context='{self.summary_context}', "
+                f"doc_ids='{self.doc_ids}', metadata='{self.meta_data}')>")
diff --git a/server/db/repository/knowledge_metadata_repository.py b/server/db/repository/knowledge_metadata_repository.py
new file mode 100644
index 00000000..4158e703
--- /dev/null
+++ b/server/db/repository/knowledge_metadata_repository.py
@@ -0,0 +1,66 @@
+from server.db.models.knowledge_metadata_model import SummaryChunkModel
+from server.db.session import with_session
+from typing import List, Dict
+
+
+@with_session
+def list_summary_from_db(session,
+ kb_name: str,
+ metadata: Dict = {},
+ ) -> List[Dict]:
+ '''
+ 列出某知识库chunk summary。
+ 返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...]
+ '''
+ docs = session.query(SummaryChunkModel).filter_by(kb_name=kb_name)
+
+ for k, v in metadata.items():
+ docs = docs.filter(SummaryChunkModel.meta_data[k].as_string() == str(v))
+
+ return [{"id": x.id,
+ "summary_context": x.summary_context,
+ "summary_id": x.summary_id,
+ "doc_ids": x.doc_ids,
+             "metadata": x.meta_data} for x in docs.all()]
+
+
+@with_session
+def delete_summary_from_db(session,
+ kb_name: str
+ ) -> List[Dict]:
+ '''
+    删除知识库chunk summary,并返回被删除的chunk summary。
+ 返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...]
+ '''
+ docs = list_summary_from_db(kb_name=kb_name)
+ query = session.query(SummaryChunkModel).filter_by(kb_name=kb_name)
+ query.delete()
+ session.commit()
+ return docs
+
+
+@with_session
+def add_summary_to_db(session,
+ kb_name: str,
+ summary_infos: List[Dict]):
+ '''
+ 将总结信息添加到数据库。
+ summary_infos形式:[{"summary_context": str, "doc_ids": str}, ...]
+ '''
+ for summary in summary_infos:
+ obj = SummaryChunkModel(
+ kb_name=kb_name,
+ summary_context=summary["summary_context"],
+ summary_id=summary["summary_id"],
+ doc_ids=summary["doc_ids"],
+ meta_data=summary["metadata"],
+ )
+ session.add(obj)
+
+ session.commit()
+ return True
+
+
+@with_session
+def count_summary_from_db(session, kb_name: str) -> int:
+ return session.query(SummaryChunkModel).filter_by(kb_name=kb_name).count()
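
A small sketch of how these repository helpers fit together (the `@with_session` decorator injects the session argument); the kb name, ids and texts are placeholders.

```python
from server.db.repository.knowledge_metadata_repository import (
    add_summary_to_db, list_summary_from_db, count_summary_from_db,
)

add_summary_to_db(kb_name="samples", summary_infos=[{
    "summary_context": "这一页内容的总结……",
    "summary_id": "faiss-id-1",      # id returned by the summary vector store
    "doc_ids": "uuid-1,uuid-2",      # comma-separated doc ids from file_doc
    "metadata": {"page": 1},
}])
print(count_summary_from_db(kb_name="samples"))
print(list_summary_from_db(kb_name="samples"))
```
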
diff --git a/server/embeddings_api.py b/server/embeddings_api.py
index 80cd289f..93555a35 100644
--- a/server/embeddings_api.py
+++ b/server/embeddings_api.py
@@ -3,6 +3,7 @@ from configs import EMBEDDING_MODEL, logger
from server.model_workers.base import ApiEmbeddingsParams
from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models
from fastapi import Body
+from fastapi.concurrency import run_in_threadpool
from typing import Dict, List
@@ -39,6 +40,32 @@ def embed_texts(
logger.error(e)
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
+
+async def aembed_texts(
+ texts: List[str],
+ embed_model: str = EMBEDDING_MODEL,
+ to_query: bool = False,
+) -> BaseResponse:
+ '''
+ 对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
+ '''
+ try:
+ if embed_model in list_embed_models(): # 使用本地Embeddings模型
+ from server.utils import load_local_embeddings
+
+ embeddings = load_local_embeddings(model=embed_model)
+ return BaseResponse(data=await embeddings.aembed_documents(texts))
+
+ if embed_model in list_online_embed_models(): # 使用在线API
+ return await run_in_threadpool(embed_texts,
+ texts=texts,
+ embed_model=embed_model,
+ to_query=to_query)
+ except Exception as e:
+ logger.error(e)
+ return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
+
+
def embed_texts_endpoint(
texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型,除了本地部署的Embedding模型,也支持在线API({online_embed_models})提供的嵌入服务。"),
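
A usage sketch for the new async helper; it must run inside an event loop, and the texts are placeholders.

```python
import asyncio
from server.embeddings_api import aembed_texts

async def demo():
    resp = await aembed_texts(["你好", "世界"], to_query=False)
    if resp.code == 200:
        print(len(resp.data), "vectors, dim", len(resp.data[0]))
    else:
        print("error:", resp.msg)

asyncio.run(demo())
```
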
diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py
index 6823536c..68b54943 100644
--- a/server/knowledge_base/kb_cache/faiss_cache.py
+++ b/server/knowledge_base/kb_cache/faiss_cache.py
@@ -45,7 +45,7 @@ class _FaissPool(CachePool):
# create an empty vector store
embeddings = EmbeddingsFunAdapter(embed_model)
doc = Document(page_content="init", metadata={})
- vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
+        vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True, distance_strategy="METRIC_INNER_PRODUCT")
ids = list(vector_store.docstore._dict.keys())
vector_store.delete(ids)
return vector_store
@@ -82,7 +82,7 @@ class KBFaissPool(_FaissPool):
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
- vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
+            vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True, distance_strategy="METRIC_INNER_PRODUCT")
elif create:
# create an empty vector store
if not os.path.exists(vs_path):
diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py
index 0af8de5e..1d86d306 100644
--- a/server/knowledge_base/kb_service/base.py
+++ b/server/knowledge_base/kb_service/base.py
@@ -26,8 +26,8 @@ from server.knowledge_base.utils import (
from typing import List, Union, Dict, Optional
-from server.embeddings_api import embed_texts
-from server.embeddings_api import embed_documents
+from server.embeddings_api import embed_texts, aembed_texts, embed_documents
+from server.knowledge_base.model.kb_document_model import DocumentWithVSId
def normalize(embeddings: List[List[float]]) -> np.ndarray:
@@ -183,12 +183,22 @@ class KBService(ABC):
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
return []
- def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]:
+ def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[DocumentWithVSId]:
'''
通过file_name或metadata检索Document
'''
doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
- docs = self.get_doc_by_ids([x["id"] for x in doc_infos])
+ docs = []
+ for x in doc_infos:
+ doc_info_s = self.get_doc_by_ids([x["id"]])
+ if doc_info_s is not None and doc_info_s != []:
+ # 处理非空的情况
+ doc_with_id = DocumentWithVSId(**doc_info_s[0].dict(), id=x["id"])
+ docs.append(doc_with_id)
+ else:
+ # 处理空的情况
+ # 可以选择跳过当前循环迭代或执行其他操作
+ pass
return docs
@abstractmethod
@@ -394,12 +404,16 @@ class EmbeddingsFunAdapter(Embeddings):
normalized_query_embed = normalize(query_embed_2d)
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
- # TODO: 暂不支持异步
- # async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
- # return normalize(await self.embeddings.aembed_documents(texts))
+ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+ embeddings = (await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)).data
+ return normalize(embeddings).tolist()
- # async def aembed_query(self, text: str) -> List[float]:
- # return normalize(await self.embeddings.aembed_query(text))
+ async def aembed_query(self, text: str) -> List[float]:
+ embeddings = (await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)).data
+ query_embed = embeddings[0]
+ query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
+ normalized_query_embed = normalize(query_embed_2d)
+ return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
def score_threshold_process(score_threshold, k, docs):
diff --git a/server/knowledge_base/kb_summary/__init__.py b/server/knowledge_base/kb_summary/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/server/knowledge_base/kb_summary/base.py b/server/knowledge_base/kb_summary/base.py
new file mode 100644
index 00000000..00dcea6f
--- /dev/null
+++ b/server/knowledge_base/kb_summary/base.py
@@ -0,0 +1,79 @@
+from typing import List
+
+from configs import (
+ EMBEDDING_MODEL,
+ KB_ROOT_PATH)
+
+from abc import ABC, abstractmethod
+from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
+import os
+import shutil
+from server.db.repository.knowledge_metadata_repository import add_summary_to_db, delete_summary_from_db
+
+from langchain.docstore.document import Document
+
+
+# TODO 暂不考虑文件更新,需要重新删除相关文档,再重新添加
+class KBSummaryService(ABC):
+ kb_name: str
+ embed_model: str
+ vs_path: str
+ kb_path: str
+
+ def __init__(self,
+ knowledge_base_name: str,
+ embed_model: str = EMBEDDING_MODEL
+ ):
+ self.kb_name = knowledge_base_name
+ self.embed_model = embed_model
+
+ self.kb_path = self.get_kb_path()
+ self.vs_path = self.get_vs_path()
+
+ if not os.path.exists(self.vs_path):
+ os.makedirs(self.vs_path)
+
+
+ def get_vs_path(self):
+ return os.path.join(self.get_kb_path(), "summary_vector_store")
+
+ def get_kb_path(self):
+ return os.path.join(KB_ROOT_PATH, self.kb_name)
+
+ def load_vector_store(self) -> ThreadSafeFaiss:
+ return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
+ vector_name="summary_vector_store",
+ embed_model=self.embed_model,
+ create=True)
+
+ def add_kb_summary(self, summary_combine_docs: List[Document]):
+ with self.load_vector_store().acquire() as vs:
+ ids = vs.add_documents(documents=summary_combine_docs)
+ vs.save_local(self.vs_path)
+
+ summary_infos = [{"summary_context": doc.page_content,
+ "summary_id": id,
+ "doc_ids": doc.metadata.get('doc_ids'),
+ "metadata": doc.metadata} for id, doc in zip(ids, summary_combine_docs)]
+ status = add_summary_to_db(kb_name=self.kb_name, summary_infos=summary_infos)
+ return status
+
+ def create_kb_summary(self):
+ """
+ 创建知识库chunk summary
+ :return:
+ """
+
+ if not os.path.exists(self.vs_path):
+ os.makedirs(self.vs_path)
+
+ def drop_kb_summary(self):
+ """
+ 删除知识库chunk summary
+ :param kb_name:
+ :return:
+ """
+ with kb_faiss_pool.atomic:
+ kb_faiss_pool.pop(self.kb_name)
+ shutil.rmtree(self.vs_path)
+ delete_summary_from_db(kb_name=self.kb_name)
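
A minimal sketch of the summary-store workflow this class provides; the kb name, text and doc ids are placeholders.

```python
from langchain.docstore.document import Document
from server.knowledge_base.kb_summary.base import KBSummaryService

kb_summary = KBSummaryService("samples")      # uses EMBEDDING_MODEL by default
kb_summary.create_kb_summary()
doc = Document(page_content="该文件的整体摘要……",
               metadata={"doc_ids": "uuid-1,uuid-2"})  # ids of the chunks the summary covers
print("stored:", kb_summary.add_kb_summary(summary_combine_docs=[doc]))
```
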
diff --git a/server/knowledge_base/kb_summary/summary_chunk.py b/server/knowledge_base/kb_summary/summary_chunk.py
new file mode 100644
index 00000000..0b88f233
--- /dev/null
+++ b/server/knowledge_base/kb_summary/summary_chunk.py
@@ -0,0 +1,247 @@
+from typing import List, Optional
+
+from langchain.schema.language_model import BaseLanguageModel
+
+from server.knowledge_base.model.kb_document_model import DocumentWithVSId
+from configs import (logger)
+from langchain.chains import StuffDocumentsChain, LLMChain
+from langchain.prompts import PromptTemplate
+
+from langchain.docstore.document import Document
+from langchain.output_parsers.regex import RegexParser
+from langchain.chains.combine_documents.map_reduce import ReduceDocumentsChain, MapReduceDocumentsChain
+
+import sys
+import asyncio
+
+
+class SummaryAdapter:
+ _OVERLAP_SIZE: int
+ token_max: int
+ _separator: str = "\n\n"
+ chain: MapReduceDocumentsChain
+
+ def __init__(self, overlap_size: int, token_max: int,
+ chain: MapReduceDocumentsChain):
+ self._OVERLAP_SIZE = overlap_size
+ self.chain = chain
+ self.token_max = token_max
+
+ @classmethod
+ def form_summary(cls,
+ llm: BaseLanguageModel,
+ reduce_llm: BaseLanguageModel,
+ overlap_size: int,
+ token_max: int = 1300):
+ """
+ 获取实例
+ :param reduce_llm: 用于合并摘要的llm
+ :param llm: 用于生成摘要的llm
+ :param overlap_size: 重叠部分大小
+ :param token_max: 最大的chunk数量,每个chunk长度小于token_max长度,第一次生成摘要时,大于token_max长度的摘要会报错
+ :return:
+ """
+
+ # This controls how each document will be formatted. Specifically,
+ document_prompt = PromptTemplate(
+ input_variables=["page_content"],
+ template="{page_content}"
+ )
+
+ # The prompt here should take as an input variable the
+ # `document_variable_name`
+ prompt_template = (
+ "根据文本执行任务。以下任务信息"
+ "{task_briefing}"
+ "文本内容如下: "
+ "\r\n"
+ "{context}"
+ )
+ prompt = PromptTemplate(
+ template=prompt_template,
+ input_variables=["task_briefing", "context"]
+ )
+ llm_chain = LLMChain(llm=llm, prompt=prompt)
+ # We now define how to combine these summaries
+ reduce_prompt = PromptTemplate.from_template(
+ "Combine these summaries: {context}"
+ )
+ reduce_llm_chain = LLMChain(llm=reduce_llm, prompt=reduce_prompt)
+
+ document_variable_name = "context"
+ combine_documents_chain = StuffDocumentsChain(
+ llm_chain=reduce_llm_chain,
+ document_prompt=document_prompt,
+ document_variable_name=document_variable_name
+ )
+ reduce_documents_chain = ReduceDocumentsChain(
+ token_max=token_max,
+ combine_documents_chain=combine_documents_chain,
+ )
+ chain = MapReduceDocumentsChain(
+ llm_chain=llm_chain,
+ document_variable_name=document_variable_name,
+ reduce_documents_chain=reduce_documents_chain,
+ # 返回中间步骤
+ return_intermediate_steps=True
+ )
+ return cls(overlap_size=overlap_size,
+ chain=chain,
+ token_max=token_max)
+
+ def summarize(self,
+ file_description: str,
+ docs: List[DocumentWithVSId] = []
+ ) -> List[Document]:
+
+ if sys.version_info < (3, 10):
+ loop = asyncio.get_event_loop()
+ else:
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ loop = asyncio.new_event_loop()
+
+ asyncio.set_event_loop(loop)
+ # 同步调用协程代码
+ return loop.run_until_complete(self.asummarize(file_description=file_description,
+ docs=docs))
+
+ async def asummarize(self,
+ file_description: str,
+ docs: List[DocumentWithVSId] = []) -> List[Document]:
+
+ logger.info("start summary")
+ # TODO 暂不处理文档中涉及语义重复、上下文缺失、document was longer than the context length 的问题
+ # merge_docs = self._drop_overlap(docs)
+ # # 将merge_docs中的句子合并成一个文档
+ # text = self._join_docs(merge_docs)
+        # 根据段落与句子的分隔符,将文档分成chunk,每个chunk长度小于token_max长度
+
+ """
+ 这个过程分成两个部分:
+ 1. 对每个文档进行处理,得到每个文档的摘要
+ map_results = self.llm_chain.apply(
+ # FYI - this is parallelized and so it is fast.
+ [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
+ callbacks=callbacks,
+ )
+ 2. 对每个文档的摘要进行合并,得到最终的摘要,return_intermediate_steps=True,返回中间步骤
+ result, extra_return_dict = self.reduce_documents_chain.combine_docs(
+ result_docs, token_max=token_max, callbacks=callbacks, **kwargs
+ )
+ """
+ summary_combine, summary_intermediate_steps = self.chain.combine_docs(docs=docs,
+ task_briefing="描述不同方法之间的接近度和相似性,"
+ "以帮助读者理解它们之间的关系。")
+ print(summary_combine)
+ print(summary_intermediate_steps)
+
+ # if len(summary_combine) == 0:
+ # # 为空重新生成,数量减半
+ # result_docs = [
+ # Document(page_content=question_result_key, metadata=docs[i].metadata)
+ # # This uses metadata from the docs, and the textual results from `results`
+ # for i, question_result_key in enumerate(
+ # summary_intermediate_steps["intermediate_steps"][
+ # :len(summary_intermediate_steps["intermediate_steps"]) // 2
+ # ])
+ # ]
+ # summary_combine, summary_intermediate_steps = self.chain.reduce_documents_chain.combine_docs(
+ # result_docs, token_max=self.token_max
+ # )
+ logger.info("end summary")
+ doc_ids = ",".join([doc.id for doc in docs])
+ _metadata = {
+ "file_description": file_description,
+ "summary_intermediate_steps": summary_intermediate_steps,
+ "doc_ids": doc_ids
+ }
+ summary_combine_doc = Document(page_content=summary_combine, metadata=_metadata)
+
+ return [summary_combine_doc]
+
+ def _drop_overlap(self, docs: List[DocumentWithVSId]) -> List[str]:
+ """
+ # 将文档中page_content句子叠加的部分去掉
+ :param docs:
+ :param separator:
+ :return:
+ """
+ merge_docs = []
+
+ pre_doc = None
+ for doc in docs:
+ # 第一个文档直接添加
+ if len(merge_docs) == 0:
+ pre_doc = doc.page_content
+ merge_docs.append(doc.page_content)
+ continue
+
+ # 列表中上一个结尾与下一个开头重叠的部分,删除下一个开头重叠的部分
+ # 迭代递减pre_doc的长度,每次迭代删除前面的字符,
+            # 查询重叠部分,直到pre_doc的长度小于 self._OVERLAP_SIZE // 2 - 2 * len(self._separator)
+ for i in range(len(pre_doc), self._OVERLAP_SIZE // 2 - 2 * len(self._separator), -1):
+ # 每次迭代删除前面的字符
+ pre_doc = pre_doc[1:]
+ if doc.page_content[:len(pre_doc)] == pre_doc:
+ # 删除下一个开头重叠的部分
+ merge_docs.append(doc.page_content[len(pre_doc):])
+ break
+
+ pre_doc = doc.page_content
+
+ return merge_docs
+
+ def _join_docs(self, docs: List[str]) -> Optional[str]:
+ text = self._separator.join(docs)
+ text = text.strip()
+ if text == "":
+ return None
+ else:
+ return text
+
+
+if __name__ == '__main__':
+
+ docs = [
+
+ '梦者有特别的作用,也就是说梦是在预卜未来。因此,梦内容的',
+
+ '梦内容的多彩多姿以及对梦者本身所遗留的特殊印象,使他们很难想象',
+
+ '使他们很难想象出一套系统划一的观念,而需要以其个别的价值与可靠性作各',
+ '值与可靠性作各种不同的分化与聚合。因此,古代哲学家们对梦的评价也就完全'
+ ]
+ _OVERLAP_SIZE = 1
+ separator: str = "\n\n"
+ merge_docs = []
+ # 将文档中page_content句子叠加的部分去掉,
+ # 列表中上一个结尾与下一个开头重叠的部分,删除下一个开头重叠的部分
+ pre_doc = None
+ for doc in docs:
+ # 第一个文档直接添加
+ if len(merge_docs) == 0:
+ pre_doc = doc
+ merge_docs.append(doc)
+ continue
+
+ # 列表中上一个结尾与下一个开头重叠的部分,删除下一个开头重叠的部分
+ # 迭代递减pre_doc的长度,每次迭代删除前面的字符,
+        # 查询重叠部分,直到pre_doc的长度小于 _OVERLAP_SIZE - 2 * len(separator)
+ for i in range(len(pre_doc), _OVERLAP_SIZE - 2 * len(separator), -1):
+ # 每次迭代删除前面的字符
+ pre_doc = pre_doc[1:]
+ if doc[:len(pre_doc)] == pre_doc:
+ # 删除下一个开头重叠的部分
+ page_content = doc[len(pre_doc):]
+ merge_docs.append(page_content)
+
+ pre_doc = doc
+ break
+
+ # 将merge_docs中的句子合并成一个文档
+ text = separator.join(merge_docs)
+ text = text.strip()
+
+ print(text)
diff --git a/server/knowledge_base/kb_summary_api.py b/server/knowledge_base/kb_summary_api.py
new file mode 100644
index 00000000..aac4de78
--- /dev/null
+++ b/server/knowledge_base/kb_summary_api.py
@@ -0,0 +1,220 @@
+from fastapi import Body
+from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
+ OVERLAP_SIZE,
+ logger, log_verbose, )
+from server.knowledge_base.utils import (list_files_from_folder)
+from fastapi.responses import StreamingResponse
+import json
+from server.knowledge_base.kb_service.base import KBServiceFactory
+from typing import List, Optional
+from server.knowledge_base.kb_summary.base import KBSummaryService
+from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
+from server.utils import wrap_done, get_ChatOpenAI, BaseResponse
+from configs import LLM_MODELS, TEMPERATURE
+from server.knowledge_base.model.kb_document_model import DocumentWithVSId
+
+def recreate_summary_vector_store(
+ knowledge_base_name: str = Body(..., examples=["samples"]),
+ allow_empty_kb: bool = Body(True),
+ vs_type: str = Body(DEFAULT_VS_TYPE),
+ embed_model: str = Body(EMBEDDING_MODEL),
+ file_description: str = Body(''),
+ model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
+ temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+ max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
+):
+ """
+ 重建单个知识库文件摘要
+ :param max_tokens:
+ :param model_name:
+ :param temperature:
+ :param file_description:
+ :param knowledge_base_name:
+ :param allow_empty_kb:
+ :param vs_type:
+ :param embed_model:
+ :return:
+ """
+
+ def output():
+
+ kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
+ if not kb.exists() and not allow_empty_kb:
+            yield json.dumps({"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}, ensure_ascii=False)
+ else:
+ # 重新创建知识库
+ kb_summary = KBSummaryService(knowledge_base_name, embed_model)
+ kb_summary.drop_kb_summary()
+ kb_summary.create_kb_summary()
+
+ llm = get_ChatOpenAI(
+ model_name=model_name,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ )
+ reduce_llm = get_ChatOpenAI(
+ model_name=model_name,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ )
+ # 文本摘要适配器
+ summary = SummaryAdapter.form_summary(llm=llm,
+ reduce_llm=reduce_llm,
+ overlap_size=OVERLAP_SIZE)
+ files = list_files_from_folder(knowledge_base_name)
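+            # file names are relative to the knowledge base's content folder
+            # (list_files_from_folder returns relative paths)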
+
+ for i, file_name in enumerate(files):
+
+ doc_infos = kb.list_docs(file_name=file_name)
+ docs = summary.summarize(file_description=file_description,
+ docs=doc_infos)
+
+ status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
+ if status_kb_summary:
+ logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成")
+ yield json.dumps({
+ "code": 200,
+ "msg": f"({i + 1} / {len(files)}): {file_name}",
+ "total": len(files),
+ "finished": i + 1,
+ "doc": file_name,
+ }, ensure_ascii=False)
+ else:
+
+ msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
+ logger.error(msg)
+ yield json.dumps({
+ "code": 500,
+ "msg": msg,
+ })
+
+ return StreamingResponse(output(), media_type="text/event-stream")
+
+
+def summary_file_to_vector_store(
+ knowledge_base_name: str = Body(..., examples=["samples"]),
+ file_name: str = Body(..., examples=["test.pdf"]),
+ allow_empty_kb: bool = Body(True),
+ vs_type: str = Body(DEFAULT_VS_TYPE),
+ embed_model: str = Body(EMBEDDING_MODEL),
+ file_description: str = Body(''),
+ model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
+ temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+ max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
+):
+ """
+ 单个知识库根据文件名称摘要
+ :param model_name:
+ :param max_tokens:
+ :param temperature:
+ :param file_description:
+ :param file_name:
+ :param knowledge_base_name:
+ :param allow_empty_kb:
+ :param vs_type:
+ :param embed_model:
+ :return:
+ """
+
+ def output():
+ kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
+ if not kb.exists() and not allow_empty_kb:
+            yield json.dumps({"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}, ensure_ascii=False)
+ else:
+            # 创建知识库文件摘要向量库(若尚不存在)
+ kb_summary = KBSummaryService(knowledge_base_name, embed_model)
+ kb_summary.create_kb_summary()
+
+ llm = get_ChatOpenAI(
+ model_name=model_name,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ )
+ reduce_llm = get_ChatOpenAI(
+ model_name=model_name,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ )
+ # 文本摘要适配器
+ summary = SummaryAdapter.form_summary(llm=llm,
+ reduce_llm=reduce_llm,
+ overlap_size=OVERLAP_SIZE)
+
+ doc_infos = kb.list_docs(file_name=file_name)
+ docs = summary.summarize(file_description=file_description,
+ docs=doc_infos)
+
+ status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
+ if status_kb_summary:
+ logger.info(f" {file_name} 总结完成")
+ yield json.dumps({
+ "code": 200,
+ "msg": f"{file_name} 总结完成",
+ "doc": file_name,
+ }, ensure_ascii=False)
+ else:
+
+ msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
+ logger.error(msg)
+ yield json.dumps({
+ "code": 500,
+ "msg": msg,
+ })
+
+ return StreamingResponse(output(), media_type="text/event-stream")
+
+
+def summary_doc_ids_to_vector_store(
+ knowledge_base_name: str = Body(..., examples=["samples"]),
+        doc_ids: List[str] = Body([], examples=[["uuid"]]),
+ vs_type: str = Body(DEFAULT_VS_TYPE),
+ embed_model: str = Body(EMBEDDING_MODEL),
+ file_description: str = Body(''),
+ model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
+ temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+ max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
+) -> BaseResponse:
+ """
+ 单个知识库根据doc_ids摘要
+ :param knowledge_base_name:
+ :param doc_ids:
+ :param model_name:
+ :param max_tokens:
+ :param temperature:
+ :param file_description:
+ :param vs_type:
+ :param embed_model:
+ :return:
+ """
+ kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
+ if not kb.exists():
+ return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data={})
+ else:
+ llm = get_ChatOpenAI(
+ model_name=model_name,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ )
+ reduce_llm = get_ChatOpenAI(
+ model_name=model_name,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ )
+ # 文本摘要适配器
+ summary = SummaryAdapter.form_summary(llm=llm,
+ reduce_llm=reduce_llm,
+ overlap_size=OVERLAP_SIZE)
+
+ doc_infos = kb.get_doc_by_ids(ids=doc_ids)
+ # doc_infos转换成DocumentWithVSId包装的对象
+ doc_info_with_ids = [DocumentWithVSId(**doc.dict(), id=with_id) for with_id, doc in zip(doc_ids, doc_infos)]
+
+ docs = summary.summarize(file_description=file_description,
+ docs=doc_info_with_ids)
+
+ # 将docs转换成dict
+    resp_summarize = [doc.dict() for doc in docs]
+
+ return BaseResponse(code=200, msg="总结完成", data={"summarize": resp_summarize})
diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py
index b3601476..bde6e1f8 100644
--- a/server/knowledge_base/migrate.py
+++ b/server/knowledge_base/migrate.py
@@ -12,6 +12,8 @@ from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.models.conversation_model import ConversationModel
from server.db.models.message_model import MessageModel
from server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported
+from server.db.repository.knowledge_metadata_repository import add_summary_to_db
+
from server.db.base import Base, engine
from server.db.session import session_scope
import os
diff --git a/server/knowledge_base/model/kb_document_model.py b/server/knowledge_base/model/kb_document_model.py
new file mode 100644
index 00000000..a5d2c6ab
--- /dev/null
+++ b/server/knowledge_base/model/kb_document_model.py
@@ -0,0 +1,10 @@
+
+from langchain.docstore.document import Document
+
+
+class DocumentWithVSId(Document):
+ """
+ 矢量化后的文档
+ """
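+    # id of the document in the vector store, filled in when docs are returned together with their ids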
+ id: str = None
+
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index d5587ea5..41c3ad5a 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -71,7 +71,7 @@ def list_files_from_folder(kb_name: str):
for target_entry in target_it:
process_entry(target_entry)
elif entry.is_file():
- result.append(entry.path)
+ result.append(os.path.relpath(entry.path, doc_path))
elif entry.is_dir():
with os.scandir(entry.path) as it:
for sub_entry in it:
@@ -85,6 +85,7 @@ def list_files_from_folder(kb_name: str):
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
+ "MHTMLLoader": ['.mhtml'],
"UnstructuredMarkdownLoader": ['.md'],
"JSONLoader": [".json"],
"JSONLinesLoader": [".jsonl"],
@@ -106,6 +107,7 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
"UnstructuredWordDocumentLoader": ['.docx', 'doc'],
"UnstructuredXMLLoader": ['.xml'],
"UnstructuredPowerPointLoader": ['.ppt', '.pptx'],
+ "EverNoteLoader": ['.enex'],
"UnstructuredFileLoader": ['.txt'],
}
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
@@ -311,6 +313,9 @@ class KnowledgeFile:
else:
docs = text_splitter.split_documents(docs)
+ if not docs:
+ return []
+
print(f"文档切分示例:{docs[0]}")
if zh_title_enhance:
docs = func_zh_title_enhance(docs)
diff --git a/server/model_workers/azure.py b/server/model_workers/azure.py
index c6b4cbba..70959325 100644
--- a/server/model_workers/azure.py
+++ b/server/model_workers/azure.py
@@ -1,4 +1,5 @@
import sys
+import os
from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
@@ -19,16 +20,16 @@ class AzureWorker(ApiModelWorker):
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
- kwargs.setdefault("context_len", 8000) #TODO 16K模型需要改成16384
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
+
data = dict(
messages=params.messages,
temperature=params.temperature,
- max_tokens=params.max_tokens,
+ max_tokens=params.max_tokens if params.max_tokens else None,
stream=True,
)
url = ("https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}"
@@ -47,6 +48,7 @@ class AzureWorker(ApiModelWorker):
with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=data) as response:
+ print(data)
for line in response.iter_lines():
if not line.strip() or "[DONE]" in line:
continue
@@ -60,6 +62,7 @@ class AzureWorker(ApiModelWorker):
"error_code": 0,
"text": text
}
+ print(text)
else:
self.logger.error(f"请求 Azure API 时发生错误:{resp}")
@@ -91,4 +94,4 @@ if __name__ == "__main__":
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
- uvicorn.run(app, port=21008)
+ uvicorn.run(app, port=21008)
\ No newline at end of file
diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py
index 45e9a5f9..947782a8 100644
--- a/server/model_workers/minimax.py
+++ b/server/model_workers/minimax.py
@@ -119,7 +119,9 @@ class MiniMaxWorker(ApiModelWorker):
with get_httpx_client() as client:
result = []
i = 0
- for texts in params.texts[i:i+10]:
+ batch_size = 10
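+            # embed params.texts in batches of batch_size; the old loop only looked at the first 10 texts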
+ while i < len(params.texts):
+ texts = params.texts[i:i+batch_size]
data["texts"] = texts
r = client.post(url, headers=headers, json=data).json()
if embeddings := r.get("vectors"):
@@ -137,7 +139,7 @@ class MiniMaxWorker(ApiModelWorker):
}
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
return data
- i += 10
+ i += batch_size
return {"code": 200, "data": embeddings}
def get_embeddings(self, params):
diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py
index 95e200ac..2bcce94e 100644
--- a/server/model_workers/qianfan.py
+++ b/server/model_workers/qianfan.py
@@ -188,9 +188,11 @@ class QianFanWorker(ApiModelWorker):
with get_httpx_client() as client:
result = []
i = 0
- for texts in params.texts[i:i+10]:
+ batch_size = 10
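+            # same batching fix as in minimax.py: walk through all texts, batch_size at a time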
+ while i < len(params.texts):
+ texts = params.texts[i:i+batch_size]
resp = client.post(url, json={"input": texts}).json()
- if "error_cdoe" in resp:
+ if "error_code" in resp:
data = {
"code": resp["error_code"],
"msg": resp["error_msg"],
@@ -206,7 +208,7 @@ class QianFanWorker(ApiModelWorker):
else:
embeddings = [x["embedding"] for x in resp.get("data", [])]
result += embeddings
- i += 10
+ i += batch_size
return {"code": 200, "data": result}
# TODO: qianfan支持续写模型
diff --git a/server/utils.py b/server/utils.py
index c51c6e86..21b1baf0 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -40,11 +40,10 @@ def get_ChatOpenAI(
verbose: bool = True,
**kwargs: Any,
) -> ChatOpenAI:
- ## 以下模型是Langchain原生支持的模型,这些模型不会走Fschat封装
- config_models = list_config_llm_models()
-
- ## 非Langchain原生支持的模型,走Fschat封装
config = get_model_worker_config(model_name)
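+    # when the generic "openai-api" entry is selected, substitute the concrete model name from its worker config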
+ if model_name == "openai-api":
+ model_name = config.get("model_name")
+
model = ChatOpenAI(
streaming=streaming,
verbose=verbose,
@@ -57,10 +56,8 @@ def get_ChatOpenAI(
openai_proxy=config.get("openai_proxy"),
**kwargs
)
-
return model
-
def get_OpenAI(
model_name: str,
temperature: float,
@@ -71,67 +68,22 @@ def get_OpenAI(
verbose: bool = True,
**kwargs: Any,
) -> OpenAI:
- ## 以下模型是Langchain原生支持的模型,这些模型不会走Fschat封装
- config_models = list_config_llm_models()
- if model_name in config_models.get("langchain", {}):
- config = config_models["langchain"][model_name]
- if model_name == "Azure-OpenAI":
- model = AzureOpenAI(
- streaming=streaming,
- verbose=verbose,
- callbacks=callbacks,
- deployment_name=config.get("deployment_name"),
- model_version=config.get("model_version"),
- openai_api_type=config.get("openai_api_type"),
- openai_api_base=config.get("api_base_url"),
- openai_api_version=config.get("api_version"),
- openai_api_key=config.get("api_key"),
- openai_proxy=config.get("openai_proxy"),
- temperature=temperature,
- max_tokens=max_tokens,
- echo=echo,
- )
-
- elif model_name == "OpenAI":
- model = OpenAI(
- streaming=streaming,
- verbose=verbose,
- callbacks=callbacks,
- model_name=config.get("model_name"),
- openai_api_base=config.get("api_base_url"),
- openai_api_key=config.get("api_key"),
- openai_proxy=config.get("openai_proxy"),
- temperature=temperature,
- max_tokens=max_tokens,
- echo=echo,
- )
- elif model_name == "Anthropic":
- model = Anthropic(
- streaming=streaming,
- verbose=verbose,
- callbacks=callbacks,
- model_name=config.get("model_name"),
- anthropic_api_key=config.get("api_key"),
- echo=echo,
- )
- ## TODO 支持其他的Langchain原生支持的模型
- else:
- ## 非Langchain原生支持的模型,走Fschat封装
- config = get_model_worker_config(model_name)
- model = OpenAI(
- streaming=streaming,
- verbose=verbose,
- callbacks=callbacks,
- openai_api_key=config.get("api_key", "EMPTY"),
- openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
- model_name=model_name,
- temperature=temperature,
- max_tokens=max_tokens,
- openai_proxy=config.get("openai_proxy"),
- echo=echo,
- **kwargs
- )
-
+ config = get_model_worker_config(model_name)
+ if model_name == "openai-api":
+ model_name = config.get("model_name")
+ model = OpenAI(
+ streaming=streaming,
+ verbose=verbose,
+ callbacks=callbacks,
+ openai_api_key=config.get("api_key", "EMPTY"),
+ openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
+ model_name=model_name,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ openai_proxy=config.get("openai_proxy"),
+ echo=echo,
+ **kwargs
+ )
return model
@@ -630,7 +582,7 @@ def get_httpx_client(
for host in os.environ.get("no_proxy", "").split(","):
if host := host.strip():
# default_proxies.update({host: None}) # Origin code
- default_proxies.update({'all://' + host: None}) # PR 1838 fix, if not add 'all://', httpx will raise error
+ default_proxies.update({'all://' + host: None}) # PR 1838 fix, if not add 'all://', httpx will raise error
# merge default proxies with user provided proxies
if isinstance(proxies, str):
@@ -714,7 +666,7 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]:
from configs.basic_config import BASE_TEMP_DIR
import tempfile
- if id is not None: # 如果指定的临时目录已存在,直接返回
+ if id is not None: # 如果指定的临时目录已存在,直接返回
path = os.path.join(BASE_TEMP_DIR, id)
if os.path.isdir(path):
return path, id
diff --git a/startup.py b/startup.py
index e610cda1..3bb508a0 100644
--- a/startup.py
+++ b/startup.py
@@ -36,6 +36,7 @@ from server.utils import (fschat_controller_address, fschat_model_worker_address
fschat_openai_api_address, set_httpx_config, get_httpx_client,
get_model_worker_config, get_all_model_worker_configs,
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
+from server.knowledge_base.migrate import create_tables
import argparse
from typing import Tuple, List, Dict
from configs import VERSION
@@ -866,6 +867,8 @@ async def start_main_server():
logger.info("Process status: %s", p)
if __name__ == "__main__":
+ # 确保数据库表被创建
+ create_tables()
if sys.version_info < (3, 10):
loop = asyncio.get_event_loop()
diff --git a/tests/api/test_kb_summary_api.py b/tests/api/test_kb_summary_api.py
new file mode 100644
index 00000000..d59c2036
--- /dev/null
+++ b/tests/api/test_kb_summary_api.py
@@ -0,0 +1,44 @@
+import requests
+import json
+import sys
+from pathlib import Path
+
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+from server.utils import api_address
+
+api_base_url = api_address()
+
+kb = "samples"
+file_name = "/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/knowledge_base/samples/content/llm/大模型技术栈-实战与应用.md"
+doc_ids = [
+ "357d580f-fdf7-495c-b58b-595a398284e8",
+ "c7338773-2e83-4671-b237-1ad20335b0f0",
+ "6da613d1-327d-466f-8c1a-b32e6f461f47"
+]
+
+
+def test_summary_file_to_vector_store(api="/knowledge_base/kb_summary_api/summary_file_to_vector_store"):
+ url = api_base_url + api
+ print("\n文件摘要:")
+ r = requests.post(url, json={"knowledge_base_name": kb,
+ "file_name": file_name
+ }, stream=True)
+ for chunk in r.iter_content(None):
+ data = json.loads(chunk)
+ assert isinstance(data, dict)
+ assert data["code"] == 200
+ print(data["msg"])
+
+
+def test_summary_doc_ids_to_vector_store(api="/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store"):
+ url = api_base_url + api
+ print("\n文件摘要:")
+ r = requests.post(url, json={"knowledge_base_name": kb,
+ "doc_ids": doc_ids
+ }, stream=True)
+ for chunk in r.iter_content(None):
+ data = json.loads(chunk)
+ assert isinstance(data, dict)
+ assert data["code"] == 200
+ print(data)
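+
+# Note: these tests assume a running API server and that the file and doc ids above
+# already exist in the "samples" knowledge base; adjust them to your own data before running.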
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 3310fb3d..064c8a42 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -1,8 +1,11 @@
import streamlit as st
from webui_pages.utils import *
from streamlit_chatbox import *
+from streamlit_modal import Modal
from datetime import datetime
import os
+import re
+import time
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
from server.knowledge_base.utils import LOADER_DICT
@@ -47,6 +50,53 @@ def upload_temp_docs(files, _api: ApiRequest) -> str:
return _api.upload_temp_docs(files).get("data", {}).get("id")
+def parse_command(text: str, modal: Modal) -> bool:
+ '''
+ 检查用户是否输入了自定义命令,当前支持:
+ /new {session_name}。如果未提供名称,默认为“会话X”
+ /del {session_name}。如果未提供名称,在会话数量>1的情况下,删除当前会话。
+ /clear {session_name}。如果未提供名称,默认清除当前会话
+ /help。查看命令帮助
+ 返回值:输入的是命令返回True,否则返回False
+ '''
+ if m := re.match(r"/([^\s]+)\s*(.*)", text):
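+        # group 1 is the command name, group 2 the optional argument (e.g. a session name)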
+ cmd, name = m.groups()
+ name = name.strip()
+ conv_names = chat_box.get_chat_names()
+ if cmd == "help":
+ modal.open()
+ elif cmd == "new":
+ if not name:
+ i = 1
+ while True:
+ name = f"会话{i}"
+ if name not in conv_names:
+ break
+ i += 1
+ if name in st.session_state["conversation_ids"]:
+ st.error(f"该会话名称 “{name}” 已存在")
+ time.sleep(1)
+ else:
+ st.session_state["conversation_ids"][name] = uuid.uuid4().hex
+ st.session_state["cur_conv_name"] = name
+ elif cmd == "del":
+ name = name or st.session_state.get("cur_conv_name")
+ if len(conv_names) == 1:
+ st.error("这是最后一个会话,无法删除")
+ time.sleep(1)
+ elif not name or name not in st.session_state["conversation_ids"]:
+ st.error(f"无效的会话名称:“{name}”")
+ time.sleep(1)
+ else:
+ st.session_state["conversation_ids"].pop(name, None)
+ chat_box.del_chat_name(name)
+ st.session_state["cur_conv_name"] = ""
+ elif cmd == "clear":
+ chat_box.reset_history(name=name or None)
+ return True
+ return False
+
+
def dialogue_page(api: ApiRequest, is_lite: bool = False):
st.session_state.setdefault("conversation_ids", {})
st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex)
@@ -60,25 +110,15 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
)
chat_box.init_session()
+ # 弹出自定义命令帮助信息
+ modal = Modal("自定义命令", key="cmd_help", max_width="500")
+ if modal.is_open():
+ with modal.container():
+ cmds = [x for x in parse_command.__doc__.split("\n") if x.strip().startswith("/")]
+ st.write("\n\n".join(cmds))
+
with st.sidebar:
# 多会话
- cols = st.columns([3, 1])
- conv_name = cols[0].text_input("会话名称")
- with cols[1]:
- if st.button("添加"):
- if not conv_name or conv_name in st.session_state["conversation_ids"]:
- st.error("请指定有效的会话名称")
- else:
- st.session_state["conversation_ids"][conv_name] = uuid.uuid4().hex
- st.session_state["cur_conv_name"] = conv_name
- st.session_state["conv_name"] = ""
- if st.button("删除"):
- if not conv_name or conv_name not in st.session_state["conversation_ids"]:
- st.error("请指定有效的会话名称")
- else:
- st.session_state["conversation_ids"].pop(conv_name, None)
- st.session_state["cur_conv_name"] = ""
-
conv_names = list(st.session_state["conversation_ids"].keys())
index = 0
if st.session_state.get("cur_conv_name") in conv_names:
@@ -236,7 +276,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
# Display chat messages from history on app rerun
chat_box.output_messages()
- chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter "
+ chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "
def on_feedback(
feedback,
@@ -256,139 +296,142 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
}
if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
- history = get_messages_history(history_len)
- chat_box.user_say(prompt)
- if dialogue_mode == "LLM 对话":
- chat_box.ai_say("正在思考...")
- text = ""
- message_id = ""
- r = api.chat_chat(prompt,
- history=history,
- conversation_id=conversation_id,
- model=llm_model,
- prompt_name=prompt_template_name,
- temperature=temperature)
- for t in r:
- if error_msg := check_error_msg(t): # check whether error occured
- st.error(error_msg)
- break
- text += t.get("text", "")
- chat_box.update_msg(text)
- message_id = t.get("message_id", "")
+ if parse_command(text=prompt, modal=modal): # 用户输入自定义命令
+ st.rerun()
+ else:
+ history = get_messages_history(history_len)
+ chat_box.user_say(prompt)
+ if dialogue_mode == "LLM 对话":
+ chat_box.ai_say("正在思考...")
+ text = ""
+ message_id = ""
+ r = api.chat_chat(prompt,
+ history=history,
+ conversation_id=conversation_id,
+ model=llm_model,
+ prompt_name=prompt_template_name,
+ temperature=temperature)
+ for t in r:
+ if error_msg := check_error_msg(t): # check whether error occured
+ st.error(error_msg)
+ break
+ text += t.get("text", "")
+ chat_box.update_msg(text)
+ message_id = t.get("message_id", "")
- metadata = {
- "message_id": message_id,
- }
- chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标
- chat_box.show_feedback(**feedback_kwargs,
- key=message_id,
- on_submit=on_feedback,
- kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
+ metadata = {
+ "message_id": message_id,
+ }
+ chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标
+ chat_box.show_feedback(**feedback_kwargs,
+ key=message_id,
+ on_submit=on_feedback,
+ kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
- elif dialogue_mode == "自定义Agent问答":
- if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
+ elif dialogue_mode == "自定义Agent问答":
+ if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
+ chat_box.ai_say([
+ f"正在思考... \n\n 该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!\n\n\n",
+ Markdown("...", in_expander=True, title="思考过程", state="complete"),
+
+ ])
+ else:
+ chat_box.ai_say([
+ f"正在思考...",
+ Markdown("...", in_expander=True, title="思考过程", state="complete"),
+
+ ])
+ text = ""
+ ans = ""
+ for d in api.agent_chat(prompt,
+ history=history,
+ model=llm_model,
+ prompt_name=prompt_template_name,
+ temperature=temperature,
+ ):
+ try:
+ d = json.loads(d)
+ except:
+ pass
+ if error_msg := check_error_msg(d): # check whether error occured
+ st.error(error_msg)
+ if chunk := d.get("answer"):
+ text += chunk
+ chat_box.update_msg(text, element_index=1)
+ if chunk := d.get("final_answer"):
+ ans += chunk
+ chat_box.update_msg(ans, element_index=0)
+ if chunk := d.get("tools"):
+ text += "\n\n".join(d.get("tools", []))
+ chat_box.update_msg(text, element_index=1)
+ chat_box.update_msg(ans, element_index=0, streaming=False)
+ chat_box.update_msg(text, element_index=1, streaming=False)
+ elif dialogue_mode == "知识库问答":
chat_box.ai_say([
- f"正在思考... \n\n 该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!\n\n\n",
- Markdown("...", in_expander=True, title="思考过程", state="complete"),
-
+ f"正在查询知识库 `{selected_kb}` ...",
+ Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
])
- else:
+ text = ""
+ for d in api.knowledge_base_chat(prompt,
+ knowledge_base_name=selected_kb,
+ top_k=kb_top_k,
+ score_threshold=score_threshold,
+ history=history,
+ model=llm_model,
+ prompt_name=prompt_template_name,
+ temperature=temperature):
+ if error_msg := check_error_msg(d): # check whether error occured
+ st.error(error_msg)
+ elif chunk := d.get("answer"):
+ text += chunk
+ chat_box.update_msg(text, element_index=0)
+ chat_box.update_msg(text, element_index=0, streaming=False)
+ chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
+ elif dialogue_mode == "文件对话":
+ if st.session_state["file_chat_id"] is None:
+ st.error("请先上传文件再进行对话")
+ st.stop()
chat_box.ai_say([
- f"正在思考...",
- Markdown("...", in_expander=True, title="思考过程", state="complete"),
-
+ f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
+ Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
])
- text = ""
- ans = ""
- for d in api.agent_chat(prompt,
- history=history,
- model=llm_model,
- prompt_name=prompt_template_name,
- temperature=temperature,
- ):
- try:
- d = json.loads(d)
- except:
- pass
- if error_msg := check_error_msg(d): # check whether error occured
- st.error(error_msg)
- if chunk := d.get("answer"):
- text += chunk
- chat_box.update_msg(text, element_index=1)
- if chunk := d.get("final_answer"):
- ans += chunk
- chat_box.update_msg(ans, element_index=0)
- if chunk := d.get("tools"):
- text += "\n\n".join(d.get("tools", []))
- chat_box.update_msg(text, element_index=1)
- chat_box.update_msg(ans, element_index=0, streaming=False)
- chat_box.update_msg(text, element_index=1, streaming=False)
- elif dialogue_mode == "知识库问答":
- chat_box.ai_say([
- f"正在查询知识库 `{selected_kb}` ...",
- Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
- ])
- text = ""
- for d in api.knowledge_base_chat(prompt,
- knowledge_base_name=selected_kb,
- top_k=kb_top_k,
- score_threshold=score_threshold,
- history=history,
- model=llm_model,
- prompt_name=prompt_template_name,
- temperature=temperature):
- if error_msg := check_error_msg(d): # check whether error occured
- st.error(error_msg)
- elif chunk := d.get("answer"):
- text += chunk
- chat_box.update_msg(text, element_index=0)
- chat_box.update_msg(text, element_index=0, streaming=False)
- chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
- elif dialogue_mode == "文件对话":
- if st.session_state["file_chat_id"] is None:
- st.error("请先上传文件再进行对话")
- st.stop()
- chat_box.ai_say([
- f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
- Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
- ])
- text = ""
- for d in api.file_chat(prompt,
- knowledge_id=st.session_state["file_chat_id"],
- top_k=kb_top_k,
- score_threshold=score_threshold,
- history=history,
- model=llm_model,
- prompt_name=prompt_template_name,
- temperature=temperature):
- if error_msg := check_error_msg(d): # check whether error occured
- st.error(error_msg)
- elif chunk := d.get("answer"):
- text += chunk
- chat_box.update_msg(text, element_index=0)
- chat_box.update_msg(text, element_index=0, streaming=False)
- chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
- elif dialogue_mode == "搜索引擎问答":
- chat_box.ai_say([
- f"正在执行 `{search_engine}` 搜索...",
- Markdown("...", in_expander=True, title="网络搜索结果", state="complete"),
- ])
- text = ""
- for d in api.search_engine_chat(prompt,
- search_engine_name=search_engine,
- top_k=se_top_k,
- history=history,
- model=llm_model,
- prompt_name=prompt_template_name,
- temperature=temperature,
- split_result=se_top_k > 1):
- if error_msg := check_error_msg(d): # check whether error occured
- st.error(error_msg)
- elif chunk := d.get("answer"):
- text += chunk
- chat_box.update_msg(text, element_index=0)
- chat_box.update_msg(text, element_index=0, streaming=False)
- chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
+ text = ""
+ for d in api.file_chat(prompt,
+ knowledge_id=st.session_state["file_chat_id"],
+ top_k=kb_top_k,
+ score_threshold=score_threshold,
+ history=history,
+ model=llm_model,
+ prompt_name=prompt_template_name,
+ temperature=temperature):
+ if error_msg := check_error_msg(d): # check whether error occured
+ st.error(error_msg)
+ elif chunk := d.get("answer"):
+ text += chunk
+ chat_box.update_msg(text, element_index=0)
+ chat_box.update_msg(text, element_index=0, streaming=False)
+ chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
+ elif dialogue_mode == "搜索引擎问答":
+ chat_box.ai_say([
+ f"正在执行 `{search_engine}` 搜索...",
+ Markdown("...", in_expander=True, title="网络搜索结果", state="complete"),
+ ])
+ text = ""
+ for d in api.search_engine_chat(prompt,
+ search_engine_name=search_engine,
+ top_k=se_top_k,
+ history=history,
+ model=llm_model,
+ prompt_name=prompt_template_name,
+ temperature=temperature,
+ split_result=se_top_k > 1):
+ if error_msg := check_error_msg(d): # check whether error occured
+ st.error(error_msg)
+ elif chunk := d.get("answer"):
+ text += chunk
+ chat_box.update_msg(text, element_index=0)
+ chat_box.update_msg(text, element_index=0, streaming=False)
+ chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
if st.session_state.get("need_rerun"):
st.session_state["need_rerun"] = False
diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py
index 57d6202d..53484917 100644
--- a/webui_pages/knowledge_base/knowledge_base.py
+++ b/webui_pages/knowledge_base/knowledge_base.py
@@ -34,14 +34,19 @@ def config_aggrid(
use_checkbox=use_checkbox,
# pre_selected_rows=st.session_state.get("selected_rows", [0]),
)
+ gb.configure_pagination(
+ enabled=True,
+ paginationAutoPageSize=False,
+ paginationPageSize=10
+ )
return gb
def file_exists(kb: str, selected_rows: List) -> Tuple[str, str]:
- '''
+ """
check whether a doc file exists in local knowledge base folder.
return the file's name and path if it exists.
- '''
+ """
if selected_rows:
file_name = selected_rows[0]["file_name"]
file_path = get_file_path(kb, file_name)
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 64401447..75e14199 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -85,7 +85,7 @@ class ApiRequest:
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
while retry > 0:
try:
- print(kwargs)
+ # print(kwargs)
if stream:
return self.client.stream("POST", url, data=data, json=json, **kwargs)
else:
@@ -134,7 +134,7 @@ class ApiRequest:
if as_json:
try:
data = json.loads(chunk)
- pprint(data, depth=1)
+ # pprint(data, depth=1)
yield data
except Exception as e:
msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
@@ -166,7 +166,7 @@ class ApiRequest:
if as_json:
try:
data = json.loads(chunk)
- pprint(data, depth=1)
+ # pprint(data, depth=1)
yield data
except Exception as e:
msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
@@ -276,8 +276,8 @@ class ApiRequest:
"max_tokens": max_tokens,
}
- print(f"received input message:")
- pprint(data)
+ # print(f"received input message:")
+ # pprint(data)
response = self.post(
"/chat/fastchat",
@@ -288,23 +288,25 @@ class ApiRequest:
return self._httpx_stream2generator(response)
def chat_chat(
- self,
- query: str,
- conversation_id: str = None,
- history: List[Dict] = [],
- stream: bool = True,
- model: str = LLM_MODELS[0],
- temperature: float = TEMPERATURE,
- max_tokens: int = None,
- prompt_name: str = "default",
- **kwargs,
+ self,
+ query: str,
+ conversation_id: str = None,
+ history_len: int = -1,
+ history: List[Dict] = [],
+ stream: bool = True,
+ model: str = LLM_MODELS[0],
+ temperature: float = TEMPERATURE,
+ max_tokens: int = None,
+ prompt_name: str = "default",
+ **kwargs,
):
'''
- 对应api.py/chat/chat接口 #TODO: 考虑是否返回json
+ 对应api.py/chat/chat接口
'''
data = {
"query": query,
"conversation_id": conversation_id,
+ "history_len": history_len,
"history": history,
"stream": stream,
"model_name": model,
@@ -313,8 +315,8 @@ class ApiRequest:
"prompt_name": prompt_name,
}
- print(f"received input message:")
- pprint(data)
+ # print(f"received input message:")
+ # pprint(data)
response = self.post("/chat/chat", json=data, stream=True, **kwargs)
return self._httpx_stream2generator(response, as_json=True)
@@ -342,8 +344,8 @@ class ApiRequest:
"prompt_name": prompt_name,
}
- print(f"received input message:")
- pprint(data)
+ # print(f"received input message:")
+ # pprint(data)
response = self.post("/chat/agent_chat", json=data, stream=True)
return self._httpx_stream2generator(response)
@@ -377,8 +379,8 @@ class ApiRequest:
"prompt_name": prompt_name,
}
- print(f"received input message:")
- pprint(data)
+ # print(f"received input message:")
+ # pprint(data)
response = self.post(
"/chat/knowledge_base_chat",
@@ -452,8 +454,8 @@ class ApiRequest:
"prompt_name": prompt_name,
}
- print(f"received input message:")
- pprint(data)
+ # print(f"received input message:")
+ # pprint(data)
response = self.post(
"/chat/file_chat",
@@ -491,8 +493,8 @@ class ApiRequest:
"split_result": split_result,
}
- print(f"received input message:")
- pprint(data)
+ # print(f"received input message:")
+ # pprint(data)
response = self.post(
"/chat/search_engine_chat",