Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-01-28 17:53:33 +08:00)

Merge branch 'dev' of https://github.com/chatchat-space/Langchain-Chatchat into dev

Commit 4c2fda7200
@ -148,7 +148,7 @@ $ python startup.py -a
[](https://t.me/+RjliQ3jnJ1YyN2E9)

### 项目交流群
<img src="img/qr_code_74.jpg" alt="二维码" width="300" />
<img src="img/qr_code_76.jpg" alt="二维码" width="300" />

🎉 Langchain-Chatchat 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
@ -5,4 +5,4 @@ from .server_config import *
from .prompt_config import *

VERSION = "v0.2.8-preview"
VERSION = "v0.2.9-preview"
@ -106,7 +106,7 @@ kbs_config = {
# TextSplitter配置项,如果你不明白其中的含义,就不要修改。
text_splitter_dict = {
    "ChineseRecursiveTextSplitter": {
        "source": "huggingface",  ## 选择tiktoken则使用openai的方法
        "source": "huggingface",  # 选择tiktoken则使用openai的方法
        "tokenizer_name_or_path": "",
    },
    "SpacyTextSplitter": {
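As a rough illustration of how a `text_splitter_dict` entry with `"source": "huggingface"` is typically consumed, here is a minimal sketch; the tokenizer name, the chunk sizes and the use of the generic RecursiveCharacterTextSplitter (rather than the project's ChineseRecursiveTextSplitter) are assumptions for illustration only.

```python
# Minimal sketch: build a splitter whose length function comes from a
# huggingface tokenizer, as the "source": "huggingface" option implies.
# Tokenizer name and chunk sizes are illustrative assumptions.
from transformers import AutoTokenizer
from langchain.text_splitter import RecursiveCharacterTextSplitter

cfg = {"source": "huggingface", "tokenizer_name_or_path": "gpt2"}
tokenizer = AutoTokenizer.from_pretrained(cfg["tokenizer_name_or_path"])
splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
    tokenizer, chunk_size=250, chunk_overlap=50)
chunks = splitter.split_text("Langchain-Chatchat 是基于本地知识库的问答应用。" * 30)
print(len(chunks), chunks[0][:40])
```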
@ -15,9 +15,11 @@ EMBEDDING_DEVICE = "auto"
EMBEDDING_KEYWORD_FILE = "keywords.txt"
EMBEDDING_MODEL_OUTPUT_PATH = "output"

# 要运行的 LLM 名称,可以包括本地模型和在线模型。
# 第一个将作为 API 和 WEBUI 的默认模型
LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"]
# 要运行的 LLM 名称,可以包括本地模型和在线模型。列表中本地模型将在启动项目时全部加载。
# 列表中第一个模型将作为 API 和 WEBUI 的默认模型。
# 在这里,我们使用目前主流的两个离线模型,其中,chatglm3-6b 为默认加载模型。
# 如果你的显存不足,可使用 Qwen-1_8B-Chat, 该模型 FP16 仅需 3.8G显存。
LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"]  # "Qwen-1_8B-Chat",

# AgentLM模型的名称 (可以不指定,指定之后就锁定进入Agent之后的Chain的模型,不指定就是LLM_MODELS[0])
Agent_MODEL = None
@ -112,10 +114,10 @@ ONLINE_LLM_MODEL = {
|
||||
"api_key": "",
|
||||
"provider": "AzureWorker",
|
||||
},
|
||||
|
||||
|
||||
# 昆仑万维天工 API https://model-platform.tiangong.cn/
|
||||
"tiangong-api": {
|
||||
"version":"SkyChat-MegaVerse",
|
||||
"version": "SkyChat-MegaVerse",
|
||||
"api_key": "",
|
||||
"secret_key": "",
|
||||
"provider": "TianGongWorker",
|
||||
@ -163,6 +165,25 @@ MODEL_PATH = {
|
||||
|
||||
"chatglm3-6b": "THUDM/chatglm3-6b",
|
||||
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
|
||||
"chatglm3-6b-base": "THUDM/chatglm3-6b-base",
|
||||
|
||||
"Qwen-1_8B": "Qwen/Qwen-1_8B",
|
||||
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
|
||||
"Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8",
|
||||
"Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4",
|
||||
|
||||
"Qwen-7B": "Qwen/Qwen-7B",
|
||||
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
|
||||
|
||||
"Qwen-14B": "Qwen/Qwen-14B",
|
||||
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
|
||||
"Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8",
|
||||
"Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4",
|
||||
|
||||
"Qwen-72B": "Qwen/Qwen-72B",
|
||||
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
|
||||
"Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8",
|
||||
"Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4",
|
||||
|
||||
"baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat",
|
||||
"baichuan2-7b": "baichuan-inc/Baichuan2-7B-Chat",
|
||||
@ -204,18 +225,11 @@ MODEL_PATH = {
|
||||
"opt-66b": "facebook/opt-66b",
|
||||
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
|
||||
|
||||
"Qwen-7B": "Qwen/Qwen-7B",
|
||||
"Qwen-14B": "Qwen/Qwen-14B",
|
||||
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
|
||||
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
|
||||
"Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8", # 确保已经安装了auto-gptq optimum flash-attn
|
||||
"Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4", # 确保已经安装了auto-gptq optimum flash-attn
|
||||
|
||||
"agentlm-7b": "THUDM/agentlm-7b",
|
||||
"agentlm-13b": "THUDM/agentlm-13b",
|
||||
"agentlm-70b": "THUDM/agentlm-70b",
|
||||
|
||||
"Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat", # 更多01-ai模型尚未进行测试。如果需要使用,请自行测试。
|
||||
"Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat",
|
||||
},
|
||||
}
|
||||
|
||||
@ -242,11 +256,11 @@ VLLM_MODEL_DICT = {
|
||||
"BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k",
|
||||
|
||||
# 注意:bloom系列的tokenizer与model是分离的,因此虽然vllm支持,但与fschat框架不兼容
|
||||
# "bloom":"bigscience/bloom",
|
||||
# "bloomz":"bigscience/bloomz",
|
||||
# "bloomz-560m":"bigscience/bloomz-560m",
|
||||
# "bloomz-7b1":"bigscience/bloomz-7b1",
|
||||
# "bloomz-1b7":"bigscience/bloomz-1b7",
|
||||
# "bloom": "bigscience/bloom",
|
||||
# "bloomz": "bigscience/bloomz",
|
||||
# "bloomz-560m": "bigscience/bloomz-560m",
|
||||
# "bloomz-7b1": "bigscience/bloomz-7b1",
|
||||
# "bloomz-1b7": "bigscience/bloomz-1b7",
|
||||
|
||||
"internlm-7b": "internlm/internlm-7b",
|
||||
"internlm-chat-7b": "internlm/internlm-chat-7b",
|
||||
@ -273,10 +287,23 @@ VLLM_MODEL_DICT = {
|
||||
"opt-66b": "facebook/opt-66b",
|
||||
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
|
||||
|
||||
"Qwen-1_8B": "Qwen/Qwen-1_8B",
|
||||
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
|
||||
"Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8",
|
||||
"Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4",
|
||||
|
||||
"Qwen-7B": "Qwen/Qwen-7B",
|
||||
"Qwen-14B": "Qwen/Qwen-14B",
|
||||
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
|
||||
|
||||
"Qwen-14B": "Qwen/Qwen-14B",
|
||||
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
|
||||
"Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8",
|
||||
"Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4",
|
||||
|
||||
"Qwen-72B": "Qwen/Qwen-72B",
|
||||
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
|
||||
"Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8",
|
||||
"Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4",
|
||||
|
||||
"agentlm-7b": "THUDM/agentlm-7b",
|
||||
"agentlm-13b": "THUDM/agentlm-13b",
|
||||
|
||||
@ -1,7 +1,6 @@
# prompt模板使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号
# 本配置文件支持热加载,修改prompt模板后无需重启服务。


# LLM对话支持的变量:
# - input: 用户输入内容
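A quick, hedged illustration of the Jinja2 double-brace convention described above; rendering through langchain's PromptTemplate with template_format="jinja2" is an assumption about usage, and the template text simply mirrors the "llm_chat"/"default" entry.

```python
# Double braces are Jinja2 placeholders, not f-string fields.
from langchain.prompts import PromptTemplate

template = "{{ input }}"   # same shape as PROMPT_TEMPLATES["llm_chat"]["default"]
prompt = PromptTemplate.from_template(template, template_format="jinja2")
print(prompt.format(input="你好"))   # -> 你好
```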
@ -17,125 +16,112 @@
|
||||
# - input: 用户输入内容
|
||||
# - agent_scratchpad: Agent的思维记录
|
||||
|
||||
PROMPT_TEMPLATES = {}
|
||||
PROMPT_TEMPLATES = {
|
||||
"llm_chat": {
|
||||
"default":
|
||||
'{{ input }}',
|
||||
|
||||
PROMPT_TEMPLATES["llm_chat"] = {
|
||||
"default": "{{ input }}",
|
||||
"with_history":
|
||||
"""The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.
|
||||
"with_history":
|
||||
'The following is a friendly conversation between a human and an AI. '
|
||||
'The AI is talkative and provides lots of specific details from its context. '
|
||||
'If the AI does not know the answer to a question, it truthfully says it does not know.\n\n'
|
||||
'Current conversation:\n'
|
||||
'{history}\n'
|
||||
'Human: {input}\n'
|
||||
'AI:',
|
||||
|
||||
Current conversation:
|
||||
{history}
|
||||
Human: {input}
|
||||
AI:""",
|
||||
"py":
|
||||
"""
|
||||
你是一个聪明的代码助手,请你给我写出简单的py代码。 \n
|
||||
{{ input }}
|
||||
""",
|
||||
}
|
||||
|
||||
PROMPT_TEMPLATES["knowledge_base_chat"] = {
|
||||
"default":
|
||||
"""
|
||||
<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
|
||||
<已知信息>{{ context }}</已知信息>、
|
||||
<问题>{{ question }}</问题>
|
||||
""",
|
||||
"text":
|
||||
"""
|
||||
<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>
|
||||
<已知信息>{{ context }}</已知信息>、
|
||||
<问题>{{ question }}</问题>
|
||||
""",
|
||||
"Empty": # 搜不到知识库的时候使用
|
||||
"""
|
||||
请你回答我的问题:
|
||||
{{ question }}
|
||||
\n
|
||||
""",
|
||||
}
|
||||
PROMPT_TEMPLATES["search_engine_chat"] = {
|
||||
"default":
|
||||
"""
|
||||
<指令>这是我搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 </指令>
|
||||
<已知信息>{{ context }}</已知信息>
|
||||
<问题>{{ question }}</问题>
|
||||
""",
|
||||
"search":
|
||||
"""
|
||||
<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>
|
||||
<已知信息>{{ context }}</已知信息>、
|
||||
<问题>{{ question }}</问题>
|
||||
""",
|
||||
}
|
||||
PROMPT_TEMPLATES["agent_chat"] = {
|
||||
"default":
|
||||
"""
|
||||
Answer the following questions as best you can. If it is in order, you can use some tools appropriately.You have access to the following tools:
|
||||
|
||||
{tools}
|
||||
|
||||
Please note that the "知识库查询工具" is information about the "西交利物浦大学" ,and if a question is asked about it, you must answer with the knowledge base,
|
||||
Please note that the "天气查询工具" can only be used once since Question begin.
|
||||
|
||||
Use the following format:
|
||||
Question: the input question you must answer1
|
||||
Thought: you should always think about what to do and what tools to use.
|
||||
Action: the action to take, should be one of [{tool_names}]
|
||||
Action Input: the input to the action
|
||||
Observation: the result of the action
|
||||
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
|
||||
Thought: I now know the final answer
|
||||
Final Answer: the final answer to the original input question
|
||||
Begin!
|
||||
|
||||
history: {history}
|
||||
|
||||
Question: {input}
|
||||
|
||||
Thought: {agent_scratchpad}
|
||||
""",
|
||||
|
||||
"ChatGLM3":
|
||||
"""
|
||||
You can answer using the tools, or answer directly using your knowledge without using the tools.Respond to the human as helpfully and accurately as possible.
|
||||
You have access to the following tools:
|
||||
{tools}
|
||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
Valid "action" values: "Final Answer" or [{tool_names}]
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
```
|
||||
{{{{
|
||||
"action": $TOOL_NAME,
|
||||
"action_input": $INPUT
|
||||
}}}}
|
||||
```
|
||||
|
||||
Follow this format:
|
||||
|
||||
Question: input question to answer
|
||||
Thought: consider previous and subsequent steps
|
||||
Action:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
Observation: action result
|
||||
... (repeat Thought/Action/Observation N times)
|
||||
Thought: I know what to respond
|
||||
Action:
|
||||
```
|
||||
{{{{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
}}}}
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
|
||||
history: {history}
|
||||
|
||||
Question: {input}
|
||||
|
||||
Thought: {agent_scratchpad}
|
||||
""",
|
||||
"py":
|
||||
'你是一个聪明的代码助手,请你给我写出简单的py代码。 \n'
|
||||
'{{ input }}',
|
||||
},
|
||||
|
||||
|
||||
"knowledge_base_chat": {
|
||||
"default":
|
||||
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,'
|
||||
'不允许在答案中添加编造成分,答案请使用中文。 </指令>\n'
|
||||
'<已知信息>{{ context }}</已知信息>\n'
|
||||
'<问题>{{ question }}</问题>\n',
|
||||
|
||||
"text":
|
||||
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>\n'
|
||||
'<已知信息>{{ context }}</已知信息>\n'
|
||||
'<问题>{{ question }}</问题>\n',
|
||||
|
||||
"empty": # 搜不到知识库的时候使用
|
||||
'请你回答我的问题:\n'
|
||||
'{{ question }}\n\n',
|
||||
},
|
||||
|
||||
|
||||
"search_engine_chat": {
|
||||
"default":
|
||||
'<指令>这是我搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。'
|
||||
'如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 </指令>\n'
|
||||
'<已知信息>{{ context }}</已知信息>\n'
|
||||
'<问题>{{ question }}</问题>\n',
|
||||
|
||||
"search":
|
||||
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>\n'
|
||||
'<已知信息>{{ context }}</已知信息>\n'
|
||||
'<问题>{{ question }}</问题>\n',
|
||||
},
|
||||
|
||||
|
||||
"agent_chat": {
|
||||
"default":
|
||||
'Answer the following questions as best you can. If it is in order, you can use some tools appropriately. '
|
||||
'You have access to the following tools:\n\n'
|
||||
'{tools}\n\n'
|
||||
'Use the following format:\n'
|
||||
'Question: the input question you must answer1\n'
|
||||
'Thought: you should always think about what to do and what tools to use.\n'
|
||||
'Action: the action to take, should be one of [{tool_names}]\n'
|
||||
'Action Input: the input to the action\n'
|
||||
'Observation: the result of the action\n'
|
||||
'... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n'
|
||||
'Thought: I now know the final answer\n'
|
||||
'Final Answer: the final answer to the original input question\n'
|
||||
'Begin!\n\n'
|
||||
'history: {history}\n\n'
|
||||
'Question: {input}\n\n'
|
||||
'Thought: {agent_scratchpad}\n',
|
||||
|
||||
"ChatGLM3":
|
||||
'You can answer using the tools, or answer directly using your knowledge without using the tools. '
|
||||
'Respond to the human as helpfully and accurately as possible.\n'
|
||||
'You have access to the following tools:\n'
|
||||
'{tools}\n'
|
||||
'Use a json blob to specify a tool by providing an action key (tool name) '
|
||||
'and an action_input key (tool input).\n'
|
||||
'Valid "action" values: "Final Answer" or [{tool_names}]'
|
||||
'Provide only ONE action per $JSON_BLOB, as shown:\n\n'
|
||||
'```\n'
|
||||
'{{{{\n'
|
||||
' "action": $TOOL_NAME,\n'
|
||||
' "action_input": $INPUT\n'
|
||||
'}}}}\n'
|
||||
'```\n\n'
|
||||
'Follow this format:\n\n'
|
||||
'Question: input question to answer\n'
|
||||
'Thought: consider previous and subsequent steps\n'
|
||||
'Action:\n'
|
||||
'```\n'
|
||||
'$JSON_BLOB\n'
|
||||
'```\n'
|
||||
'Observation: action result\n'
|
||||
'... (repeat Thought/Action/Observation N times)\n'
|
||||
'Thought: I know what to respond\n'
|
||||
'Action:\n'
|
||||
'```\n'
|
||||
'{{{{\n'
|
||||
' "action": "Final Answer",\n'
|
||||
' "action_input": "Final response to human"\n'
|
||||
'}}}}\n'
|
||||
'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. '
|
||||
'Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.\n'
|
||||
'history: {history}\n\n'
|
||||
'Question: {input}\n\n'
|
||||
'Thought: {agent_scratchpad}',
|
||||
}
|
||||
}
|
||||
|
||||
@ -66,7 +66,7 @@ FSCHAT_MODEL_WORKERS = {
|
||||
# "no_register": False,
|
||||
# "embed_in_truncate": False,
|
||||
|
||||
# 以下为vllm_woker配置参数,注意使用vllm必须有gpu,仅在Linux测试通过
|
||||
# 以下为vllm_worker配置参数,注意使用vllm必须有gpu,仅在Linux测试通过
|
||||
|
||||
# tokenizer = model_path # 如果tokenizer与model_path不一致在此处添加
|
||||
# 'tokenizer_mode':'auto',
|
||||
@ -93,14 +93,14 @@ FSCHAT_MODEL_WORKERS = {
|
||||
|
||||
},
|
||||
# 可以如下示例方式更改默认配置
|
||||
# "Qwen-7B-Chat": { # 使用default中的IP和端口
|
||||
# "Qwen-1_8B-Chat": { # 使用default中的IP和端口
|
||||
# "device": "cpu",
|
||||
# },
|
||||
"chatglm3-6b": { # 使用default中的IP和端口
|
||||
"chatglm3-6b": { # 使用default中的IP和端口
|
||||
"device": "cuda",
|
||||
},
|
||||
|
||||
#以下配置可以不用修改,在model_config中设置启动的模型
|
||||
# 以下配置可以不用修改,在model_config中设置启动的模型
|
||||
"zhipu-api": {
|
||||
"port": 21001,
|
||||
},
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
from typing import List
|
||||
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||
from document_loaders.ocr import get_ocr
|
||||
|
||||
|
||||
class RapidOCRLoader(UnstructuredFileLoader):
|
||||
def _get_elements(self) -> List:
|
||||
def img2text(filepath):
|
||||
from rapidocr_onnxruntime import RapidOCR
|
||||
resp = ""
|
||||
ocr = RapidOCR()
|
||||
ocr = get_ocr()
|
||||
result, _ = ocr(filepath)
|
||||
if result:
|
||||
ocr_result = [line[1] for line in result]
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from typing import List
|
||||
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||
from document_loaders.ocr import get_ocr
|
||||
import tqdm
|
||||
|
||||
|
||||
@ -7,9 +8,8 @@ class RapidOCRPDFLoader(UnstructuredFileLoader):
|
||||
def _get_elements(self) -> List:
|
||||
def pdf2text(filepath):
|
||||
import fitz # pyMuPDF里面的fitz包,不要与pip install fitz混淆
|
||||
from rapidocr_onnxruntime import RapidOCR
|
||||
import numpy as np
|
||||
ocr = RapidOCR()
|
||||
ocr = get_ocr()
|
||||
doc = fitz.open(filepath)
|
||||
resp = ""
|
||||
|
||||
|
||||
18 document_loaders/ocr.py Normal file
@ -0,0 +1,18 @@
from typing import TYPE_CHECKING


if TYPE_CHECKING:
    try:
        from rapidocr_paddle import RapidOCR
    except ImportError:
        from rapidocr_onnxruntime import RapidOCR


def get_ocr(use_cuda: bool = True) -> "RapidOCR":
    try:
        from rapidocr_paddle import RapidOCR
        ocr = RapidOCR(det_use_cuda=use_cuda, cls_use_cuda=use_cuda, rec_use_cuda=use_cuda)
    except ImportError:
        from rapidocr_onnxruntime import RapidOCR
        ocr = RapidOCR()
    return ocr
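A small usage sketch for the new helper above, mirroring how the image and PDF loaders call it; the sample file path is an assumption, and rapidocr_paddle is optional since the helper falls back to rapidocr_onnxruntime when it is not installed.

```python
from document_loaders.ocr import get_ocr

def img2text(filepath: str) -> str:
    ocr = get_ocr()             # rapidocr_paddle with CUDA if available, else rapidocr_onnxruntime
    result, _ = ocr(filepath)   # RapidOCR returns (list of [box, text, score], elapse)
    return "\n".join(line[1] for line in result) if result else ""

print(img2text("img/sample.png"))  # hypothetical sample image
```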
Binary file not shown. (Before: 267 KiB)
Binary file not shown. (Before: 198 KiB)
Binary file not shown. (Before: 270 KiB)
Binary file not shown. (Before: 174 KiB)
BIN img/qr_code_76.jpg Normal file
Binary file not shown. (After: 182 KiB)
@ -1 +1 @@
Subproject commit f789e5dde10f91136012f3470c020c8d34572436
Subproject commit 9a3fa7a77f8748748b1c656fe8919ad5c4c63e3f
@ -1,63 +1,75 @@
|
||||
# API requirements
|
||||
|
||||
langchain>=0.0.334
|
||||
langchain==0.0.344
|
||||
langchain-experimental>=0.0.42
|
||||
fschat[model_worker]>=0.2.33
|
||||
pydantic==1.10.13
|
||||
fschat>=0.2.33
|
||||
xformers>=0.0.22.post7
|
||||
openai~=0.28.1
|
||||
openai>=1.3.7
|
||||
sentence_transformers
|
||||
transformers>=4.35.2
|
||||
torch==2.1.0 ##on win, install the cuda version manually if you want use gpu
|
||||
torchvision #on win, install the cuda version manually if you want use gpu
|
||||
torchaudio #on win, install the cuda version manually if you want use gpu
|
||||
torch==2.1.0 ##on Windows system, install the cuda version manually from https://pytorch.org/
|
||||
torchvision #on Windows system, install the cuda version manually from https://pytorch.org/
|
||||
torchaudio #on Windows system, install the cuda version manually from https://pytorch.org/
|
||||
fastapi>=0.104
|
||||
nltk>=3.8.1
|
||||
uvicorn~=0.23.1
|
||||
uvicorn>=0.24.0.post1
|
||||
starlette~=0.27.0
|
||||
pydantic<2
|
||||
unstructured[all-docs]>=0.11.0
|
||||
unstructured[all-docs]==0.11.0
|
||||
python-magic-bin; sys_platform == 'win32'
|
||||
SQLAlchemy==2.0.19
|
||||
faiss-cpu
|
||||
accelerate
|
||||
spacy
|
||||
faiss-cpu # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus
|
||||
accelerate>=0.24.1
|
||||
spacy>=3.7.2
|
||||
PyMuPDF
|
||||
rapidocr_onnxruntime
|
||||
|
||||
requests
|
||||
pathlib
|
||||
pytest
|
||||
numexpr
|
||||
strsimpy
|
||||
markdownify
|
||||
tiktoken
|
||||
tqdm
|
||||
requests>=2.31.0
|
||||
pathlib>=1.0.1
|
||||
pytest>=7.4.3
|
||||
numexpr>=2.8.7
|
||||
strsimpy>=0.2.1
|
||||
markdownify>=0.11.6
|
||||
tiktoken>=0.5.1
|
||||
tqdm>=4.66.1
|
||||
websockets
|
||||
numpy~=1.24.4
|
||||
pandas~=2.0.3
|
||||
einops
|
||||
einops>=0.7.0
|
||||
transformers_stream_generator==0.0.4
|
||||
|
||||
vllm==0.2.2; sys_platform == "linux"
|
||||
|
||||
# online api libs dependencies
|
||||
# optional document loaders
|
||||
# rapidocr_paddle[gpu] # gpu accelleration for ocr of pdf and image files
|
||||
# jq # for .json and .jsonl files. suggest `conda install jq` on windows
|
||||
# html2text # for .enex files
|
||||
# beautifulsoup4 # for .mhtml files
|
||||
# pysrt # for .srt files
|
||||
|
||||
# Online api libs dependencies
|
||||
|
||||
# zhipuai>=1.0.7
|
||||
# dashscope>=1.10.0
|
||||
# qianfan>=0.2.0
|
||||
# volcengine>=1.0.106
|
||||
|
||||
# uncomment libs if you want to use corresponding vector store
|
||||
# pymilvus==2.1.3 # requires milvus==2.1.3
|
||||
# pymilvus>=2.3.3
|
||||
# psycopg2
|
||||
# pgvector
|
||||
# pgvector>=0.2.4
|
||||
|
||||
# Agent and Search Tools
|
||||
|
||||
arxiv>=2.0.0
|
||||
youtube-search>=2.1.2
|
||||
duckduckgo-search>=3.9.9
|
||||
metaphor-python>=0.1.23
|
||||
|
||||
# WebUI requirements
|
||||
|
||||
streamlit~=1.28.2 # # on win, make sure write its path in environment variable
|
||||
streamlit>=1.29.0 # do remember to add streamlit to environment variables if you use windows
|
||||
streamlit-option-menu>=0.3.6
|
||||
streamlit-antd-components>=0.2.3
|
||||
streamlit-chatbox>=1.1.11
|
||||
streamlit-modal>=0.1.0
|
||||
streamlit-aggrid>=0.3.4.post3
|
||||
httpx[brotli,http2,socks]~=0.24.1
|
||||
watchdog
|
||||
httpx[brotli,http2,socks]>=0.25.2
|
||||
watchdog>=3.0.0
|
||||
|
||||
|
||||
@ -1,52 +1,65 @@
|
||||
# API requirements
|
||||
|
||||
langchain>=0.0.334
|
||||
langchain==0.0.344
|
||||
langchain-experimental>=0.0.42
|
||||
fschat[model_worker]>=0.2.33
|
||||
pydantic==1.10.13
|
||||
fschat>=0.2.33
|
||||
xformers>=0.0.22.post7
|
||||
openai~=0.28.1
|
||||
openai>=1.3.7
|
||||
sentence_transformers
|
||||
transformers>=4.35.2
|
||||
torch==2.1.0
|
||||
torchvision
|
||||
torchaudio
|
||||
torch==2.1.0 ##on Windows system, install the cuda version manually from https://pytorch.org/
|
||||
torchvision #on Windows system, install the cuda version manually from https://pytorch.org/
|
||||
torchaudio #on Windows system, install the cuda version manually from https://pytorch.org/
|
||||
fastapi>=0.104
|
||||
nltk>=3.8.1
|
||||
uvicorn~=0.23.1
|
||||
uvicorn>=0.24.0.post1
|
||||
starlette~=0.27.0
|
||||
pydantic<2
|
||||
unstructured[all-docs]>=0.11.0
|
||||
unstructured[all-docs]==0.11.0
|
||||
python-magic-bin; sys_platform == 'win32'
|
||||
SQLAlchemy==2.0.19
|
||||
faiss-cpu
|
||||
accelerate>=0.24.1
|
||||
spacy
|
||||
spacy>=3.7.2
|
||||
PyMuPDF
|
||||
rapidocr_onnxruntime
|
||||
|
||||
requests
|
||||
pathlib
|
||||
pytest
|
||||
numexpr
|
||||
strsimpy
|
||||
markdownify
|
||||
tiktoken
|
||||
tqdm
|
||||
requests>=2.31.0
|
||||
pathlib>=1.0.1
|
||||
pytest>=7.4.3
|
||||
numexpr>=2.8.7
|
||||
strsimpy>=0.2.1
|
||||
markdownify>=0.11.6
|
||||
tiktoken>=0.5.1
|
||||
tqdm>=4.66.1
|
||||
websockets
|
||||
numpy~=1.24.4
|
||||
pandas~=2.0.3
|
||||
einops
|
||||
transformers_stream_generator>=0.0.4
|
||||
einops>=0.7.0
|
||||
transformers_stream_generator==0.0.4
|
||||
vllm==0.2.2; sys_platform == "linux"
|
||||
httpx[brotli,http2,socks]>=0.25.2
|
||||
|
||||
vllm>=0.2.0; sys_platform == "linux"
|
||||
|
||||
# online api libs
|
||||
zhipuai
|
||||
dashscope>=1.10.0 # qwen
|
||||
qianfan
|
||||
# volcengine>=1.0.106 # fangzhou
|
||||
# optional document loaders
|
||||
# rapidocr_paddle[gpu] # gpu accelleration for ocr of pdf and image files
|
||||
# jq # for .json and .jsonl files. suggest `conda install jq` on windows
|
||||
# html2text # for .enex files
|
||||
# beautifulsoup4 # for .mhtml files
|
||||
# pysrt # for .srt files
|
||||
|
||||
# uncomment libs if you want to use corresponding vector store
|
||||
# pymilvus==2.1.3 # requires milvus==2.1.3
|
||||
# Online api libs dependencies
|
||||
|
||||
# zhipuai>=1.0.7
|
||||
# dashscope>=1.10.0
|
||||
# qianfan>=0.2.0
|
||||
# volcengine>=1.0.106
|
||||
# pymilvus>=2.3.3
|
||||
# psycopg2
|
||||
# pgvector
|
||||
# pgvector>=0.2.4
|
||||
|
||||
# Agent and Search Tools
|
||||
|
||||
arxiv>=2.0.0
|
||||
youtube-search>=2.1.2
|
||||
duckduckgo-search>=3.9.9
|
||||
metaphor-python>=0.1.23
|
||||
@ -1,62 +1,70 @@
|
||||
langchain>=0.0.334
|
||||
fschat>=0.2.32
|
||||
openai
|
||||
# sentence_transformers
|
||||
# transformers>=4.35.0
|
||||
# torch>=2.0.1
|
||||
# torchvision
|
||||
# torchaudio
|
||||
langchain==0.0.344
|
||||
pydantic==1.10.13
|
||||
fschat>=0.2.33
|
||||
openai>=1.3.7
|
||||
fastapi>=0.104.1
|
||||
python-multipart
|
||||
nltk~=3.8.1
|
||||
uvicorn~=0.23.1
|
||||
uvicorn>=0.24.0.post1
|
||||
starlette~=0.27.0
|
||||
pydantic~=1.10.11
|
||||
unstructured[docx,csv]>=0.10.4 # add pdf if need
|
||||
unstructured[docx,csv]==0.11.0 # add pdf if need
|
||||
python-magic-bin; sys_platform == 'win32'
|
||||
SQLAlchemy==2.0.19
|
||||
numexpr>=2.8.7
|
||||
strsimpy>=0.2.1
|
||||
|
||||
faiss-cpu
|
||||
# accelerate
|
||||
# spacy
|
||||
# accelerate>=0.24.1
|
||||
# spacy>=3.7.2
|
||||
# PyMuPDF==1.22.5 # install if need pdf
|
||||
# rapidocr_onnxruntime>=1.3.2 # install if need pdf
|
||||
|
||||
# optional document loaders
|
||||
# rapidocr_paddle[gpu] # gpu accelleration for ocr of pdf and image files
|
||||
# jq # for .json and .jsonl files. suggest `conda install jq` on windows
|
||||
# html2text # for .enex files
|
||||
# beautifulsoup4 # for .mhtml files
|
||||
# pysrt # for .srt files
|
||||
|
||||
requests
|
||||
pathlib
|
||||
pytest
|
||||
# scikit-learn
|
||||
# numexpr
|
||||
# vllm==0.1.7; sys_platform == "linux"
|
||||
# vllm==0.2.2; sys_platform == "linux"
|
||||
|
||||
# online api libs
|
||||
zhipuai
|
||||
dashscope>=1.10.0 # qwen
|
||||
# qianfan
|
||||
|
||||
zhipuai>=1.0.7 # zhipu
|
||||
# dashscope>=1.10.0 # qwen
|
||||
# volcengine>=1.0.106 # fangzhou
|
||||
|
||||
# uncomment libs if you want to use corresponding vector store
|
||||
# pymilvus==2.1.3 # requires milvus==2.1.3
|
||||
# psycopg2
|
||||
# pgvector
|
||||
# pgvector>=0.2.4
|
||||
|
||||
numpy~=1.24.4
|
||||
pandas~=2.0.3
|
||||
streamlit~=1.28.1
|
||||
streamlit>=1.29.0
|
||||
streamlit-option-menu>=0.3.6
|
||||
streamlit-antd-components>=0.1.11
|
||||
streamlit-chatbox==1.1.11
|
||||
streamlit-antd-components>=0.2.3
|
||||
streamlit-chatbox>=1.1.11
|
||||
streamlit-modal>=0.1.0
|
||||
streamlit-aggrid>=0.3.4.post3
|
||||
httpx~=0.24.1
|
||||
watchdog
|
||||
tqdm
|
||||
httpx[brotli,http2,socks]>=0.25.2
|
||||
watchdog>=3.0.0
|
||||
tqdm>=4.66.1
|
||||
websockets
|
||||
einops>=0.7.0
|
||||
|
||||
# tiktoken
|
||||
einops
|
||||
# scipy
|
||||
# scipy>=1.11.4
|
||||
# transformers_stream_generator==0.0.4
|
||||
|
||||
# search engine libs
|
||||
duckduckgo-search
|
||||
metaphor-python
|
||||
strsimpy
|
||||
markdownify
|
||||
# Agent and Search Tools
|
||||
|
||||
arxiv>=2.0.0
|
||||
youtube-search>=2.1.2
|
||||
duckduckgo-search>=3.9.9
|
||||
metaphor-python>=0.1.23
|
||||
@ -1,10 +1,10 @@
|
||||
# WebUI requirements
|
||||
|
||||
streamlit~=1.28.2
|
||||
streamlit>=1.29.0
|
||||
streamlit-option-menu>=0.3.6
|
||||
streamlit-antd-components>=0.2.3
|
||||
streamlit-chatbox>=1.1.11
|
||||
streamlit-modal>=0.1.0
|
||||
streamlit-aggrid>=0.3.4.post3
|
||||
httpx[brotli,http2,socks]~=0.24.1
|
||||
watchdog
|
||||
|
||||
httpx[brotli,http2,socks]>=0.25.2
|
||||
watchdog>=3.0.0
|
||||
|
||||
@ -170,7 +170,6 @@ class LLMKnowledgeChain(LLMChain):
|
||||
queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
|
||||
except:
|
||||
queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
|
||||
print(queries)
|
||||
run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", verbose=self.verbose)
|
||||
output = self._evaluate_expression(queries)
|
||||
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
|
||||
@ -81,6 +81,8 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
|
||||
|
||||
# 知识库相关接口
|
||||
mount_knowledge_routes(app)
|
||||
# 摘要相关接口
|
||||
mount_filename_summary_routes(app)
|
||||
|
||||
# LLM模型相关接口
|
||||
app.post("/llm_model/list_running_models",
|
||||
@ -230,6 +232,26 @@ def mount_knowledge_routes(app: FastAPI):
             )(upload_temp_docs)


def mount_filename_summary_routes(app: FastAPI):
    from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store,
                                                      summary_doc_ids_to_vector_store)

    app.post("/knowledge_base/kb_summary_api/summary_file_to_vector_store",
             tags=["Knowledge kb_summary_api Management"],
             summary="单个知识库根据文件名称摘要"
             )(summary_file_to_vector_store)
    app.post("/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store",
             tags=["Knowledge kb_summary_api Management"],
             summary="单个知识库根据doc_ids摘要",
             response_model=BaseResponse,
             )(summary_doc_ids_to_vector_store)
    app.post("/knowledge_base/kb_summary_api/recreate_summary_vector_store",
             tags=["Knowledge kb_summary_api Management"],
             summary="重建单个知识库文件摘要"
             )(recreate_summary_vector_store)
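For orientation, a hedged client-side sketch of calling one of the newly mounted summary endpoints; the host/port, knowledge base name and file name are assumptions, and the exact request schema is defined by kb_summary_api.

```python
import requests

resp = requests.post(
    "http://127.0.0.1:7861/knowledge_base/kb_summary_api/summary_file_to_vector_store",
    json={"knowledge_base_name": "samples", "file_name": "test.pdf"},  # assumed parameters
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if line:
        print(line)   # the endpoint streams progress messages
```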
def run_api(host, port, **kwargs):
|
||||
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
|
||||
uvicorn.run(app,
|
||||
|
||||
@ -43,7 +43,11 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
|
||||
model_name: str = LLM_MODELS[0],
|
||||
prompt_name: str = prompt_name,
|
||||
) -> AsyncIterable[str]:
|
||||
nonlocal max_tokens
|
||||
callback = CustomAsyncIteratorCallbackHandler()
|
||||
if isinstance(max_tokens, int) and max_tokens <= 0:
|
||||
max_tokens = None
|
||||
|
||||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
|
||||
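The nonlocal max_tokens handling added here (and repeated in the other chat iterators below) follows one small rule, sketched for clarity; the helper name is illustrative, the real code inlines the check.

```python
def normalize_max_tokens(max_tokens):
    # A non-positive value from the request means "no explicit limit":
    # pass None so the backend falls back to the model's own maximum.
    if isinstance(max_tokens, int) and max_tokens <= 0:
        return None
    return max_tokens

assert normalize_max_tokens(0) is None
assert normalize_max_tokens(2048) == 2048
```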
@ -19,6 +19,7 @@ from server.callback_handler.conversation_callback_handler import ConversationCa
|
||||
|
||||
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
|
||||
conversation_id: str = Body("", description="对话框ID"),
|
||||
history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
|
||||
history: Union[int, List[History]] = Body([],
|
||||
description="历史对话,设为一个整数可以从数据库中读取历史消息",
|
||||
examples=[[
|
||||
@ -34,18 +35,20 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
|
||||
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
):
|
||||
async def chat_iterator() -> AsyncIterable[str]:
|
||||
nonlocal history
|
||||
nonlocal history, max_tokens
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
callbacks = [callback]
|
||||
memory = None
|
||||
|
||||
message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
|
||||
# 负责保存llm response到message db
|
||||
message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
|
||||
conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
|
||||
chat_type="llm_chat",
|
||||
query=query)
|
||||
callbacks.append(conversation_callback)
|
||||
|
||||
if isinstance(max_tokens, int) and max_tokens <= 0:
|
||||
max_tokens = None
|
||||
|
||||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
@ -54,18 +57,24 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
if not conversation_id:
|
||||
if history: # 优先使用前端传入的历史消息
|
||||
history = [History.from_data(h) for h in history]
|
||||
prompt_template = get_prompt_template("llm_chat", prompt_name)
|
||||
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_template() for i in history] + [input_msg])
|
||||
else:
|
||||
elif conversation_id and history_len > 0: # 前端要求从数据库取历史消息
|
||||
# 使用memory 时必须 prompt 必须含有memory.memory_key 对应的变量
|
||||
prompt = get_prompt_template("llm_chat", "with_history")
|
||||
chat_prompt = PromptTemplate.from_template(prompt)
|
||||
# 根据conversation_id 获取message 列表进而拼凑 memory
|
||||
memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=model)
|
||||
memory = ConversationBufferDBMemory(conversation_id=conversation_id,
|
||||
llm=model,
|
||||
message_limit=history_len)
|
||||
else:
|
||||
prompt_template = get_prompt_template("llm_chat", prompt_name)
|
||||
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages([input_msg])
|
||||
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
|
||||
|
||||
|
||||
@ -6,7 +6,8 @@ from langchain.chains import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from typing import AsyncIterable, Optional
|
||||
import asyncio
|
||||
from langchain.prompts.chat import PromptTemplate
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
from server.utils import get_prompt_template
|
||||
|
||||
|
||||
@ -27,7 +28,11 @@ async def completion(query: str = Body(..., description="用户输入", examples
|
||||
prompt_name: str = prompt_name,
|
||||
echo: bool = echo,
|
||||
) -> AsyncIterable[str]:
|
||||
nonlocal max_tokens
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
if isinstance(max_tokens, int) and max_tokens <= 0:
|
||||
max_tokens = None
|
||||
|
||||
model = get_OpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
|
||||
@ -113,7 +113,11 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
|
||||
history = [History.from_data(h) for h in history]
|
||||
|
||||
async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
|
||||
nonlocal max_tokens
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
if isinstance(max_tokens, int) and max_tokens <= 0:
|
||||
max_tokens = None
|
||||
|
||||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
@ -121,7 +125,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
|
||||
callbacks=[callback],
|
||||
)
|
||||
embed_func = EmbeddingsFunAdapter()
|
||||
embeddings = embed_func.embed_query(query)
|
||||
embeddings = await embed_func.aembed_query(query)
|
||||
with memo_faiss_pool.acquire(knowledge_id) as vs:
|
||||
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
|
||||
docs = [x[0] for x in docs]
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from fastapi import Body, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
|
||||
from server.utils import wrap_done, get_ChatOpenAI
|
||||
from server.utils import BaseResponse, get_prompt_template
|
||||
@ -10,55 +11,74 @@ import asyncio
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from server.chat.utils import History
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.knowledge_base.utils import get_doc_path
|
||||
import json
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlencode
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
|
||||
|
||||
async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
|
||||
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=2),
|
||||
history: List[History] = Body([],
|
||||
description="历史对话",
|
||||
examples=[[
|
||||
{"role": "user",
|
||||
"content": "我们来玩成语接龙,我先来,生龙活虎"},
|
||||
{"role": "assistant",
|
||||
"content": "虎头虎脑"}]]
|
||||
),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
request: Request = None,
|
||||
):
|
||||
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||
score_threshold: float = Body(
|
||||
SCORE_THRESHOLD,
|
||||
description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右",
|
||||
ge=0,
|
||||
le=2
|
||||
),
|
||||
history: List[History] = Body(
|
||||
[],
|
||||
description="历史对话",
|
||||
examples=[[
|
||||
{"role": "user",
|
||||
"content": "我们来玩成语接龙,我先来,生龙活虎"},
|
||||
{"role": "assistant",
|
||||
"content": "虎头虎脑"}]]
|
||||
),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(
|
||||
None,
|
||||
description="限制LLM生成Token数量,默认None代表模型最大值"
|
||||
),
|
||||
prompt_name: str = Body(
|
||||
"default",
|
||||
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
|
||||
),
|
||||
request: Request = None,
|
||||
):
|
||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
history = [History.from_data(h) for h in history]
|
||||
|
||||
async def knowledge_base_chat_iterator(query: str,
|
||||
top_k: int,
|
||||
history: Optional[List[History]],
|
||||
model_name: str = LLM_MODELS[0],
|
||||
prompt_name: str = prompt_name,
|
||||
) -> AsyncIterable[str]:
|
||||
async def knowledge_base_chat_iterator(
|
||||
query: str,
|
||||
top_k: int,
|
||||
history: Optional[List[History]],
|
||||
model_name: str = LLM_MODELS[0],
|
||||
prompt_name: str = prompt_name,
|
||||
) -> AsyncIterable[str]:
|
||||
nonlocal max_tokens
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
if isinstance(max_tokens, int) and max_tokens <= 0:
|
||||
max_tokens = None
|
||||
|
||||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[callback],
|
||||
)
|
||||
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
|
||||
docs = await run_in_threadpool(search_docs,
|
||||
query=query,
|
||||
knowledge_base_name=knowledge_base_name,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold)
|
||||
context = "\n".join([doc.page_content for doc in docs])
|
||||
if len(docs) == 0: ## 如果没有找到相关文档,使用Empty模板
|
||||
prompt_template = get_prompt_template("knowledge_base_chat", "Empty")
|
||||
if len(docs) == 0: # 如果没有找到相关文档,使用empty模板
|
||||
prompt_template = get_prompt_template("knowledge_base_chat", "empty")
|
||||
else:
|
||||
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
|
||||
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
|
||||
@ -76,14 +96,14 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
||||
source_documents = []
|
||||
for inum, doc in enumerate(docs):
|
||||
filename = doc.metadata.get("source")
|
||||
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
|
||||
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
|
||||
base_url = request.base_url
|
||||
url = f"{base_url}knowledge_base/download_doc?" + parameters
|
||||
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
|
||||
source_documents.append(text)
|
||||
|
||||
if len(source_documents) == 0: # 没有找到相关文档
|
||||
source_documents.append(f"""<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>""")
|
||||
if len(source_documents) == 0: # 没有找到相关文档
|
||||
source_documents.append(f"<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>")
|
||||
|
||||
if stream:
|
||||
async for token in callback.aiter():
|
||||
@ -104,4 +124,4 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
||||
history=history,
|
||||
model_name=model_name,
|
||||
prompt_name=prompt_name),
|
||||
media_type="text/event-stream")
|
||||
media_type="text/event-stream")
|
||||
|
||||
@ -147,7 +147,11 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
|
||||
model_name: str = LLM_MODELS[0],
|
||||
prompt_name: str = prompt_name,
|
||||
) -> AsyncIterable[str]:
|
||||
nonlocal max_tokens
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
if isinstance(max_tokens, int) and max_tokens <= 0:
|
||||
max_tokens = None
|
||||
|
||||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
|
||||
28 server/db/models/knowledge_metadata_model.py Normal file
@ -0,0 +1,28 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func
|
||||
|
||||
from server.db.base import Base
|
||||
|
||||
|
||||
class SummaryChunkModel(Base):
|
||||
"""
|
||||
chunk summary模型,用于存储file_doc中每个doc_id的chunk 片段,
|
||||
数据来源:
|
||||
用户输入: 用户上传文件,可填写文件的描述,生成的file_doc中的doc_id,存入summary_chunk中
|
||||
程序自动切分 对file_doc表meta_data字段信息中存储的页码信息,按每页的页码切分,自定义prompt生成总结文本,将对应页码关联的doc_id存入summary_chunk中
|
||||
后续任务:
|
||||
矢量库构建: 对数据库表summary_chunk中summary_context创建索引,构建矢量库,meta_data为矢量库的元数据(doc_ids)
|
||||
语义关联: 通过用户输入的描述,自动切分的总结文本,计算
|
||||
语义相似度
|
||||
|
||||
"""
|
||||
__tablename__ = 'summary_chunk'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True, comment='ID')
|
||||
kb_name = Column(String(50), comment='知识库名称')
|
||||
summary_context = Column(String(255), comment='总结文本')
|
||||
summary_id = Column(String(255), comment='总结矢量id')
|
||||
doc_ids = Column(String(1024), comment="向量库id关联列表")
|
||||
meta_data = Column(JSON, default={})
|
||||
|
||||
def __repr__(self):
|
||||
return (f"<SummaryChunk(id='{self.id}', kb_name='{self.kb_name}', summary_context='{self.summary_context}',"
|
||||
f" doc_ids='{self.doc_ids}', metadata='{self.metadata}')>")
|
||||
66 server/db/repository/knowledge_metadata_repository.py Normal file
@ -0,0 +1,66 @@
|
||||
from server.db.models.knowledge_metadata_model import SummaryChunkModel
|
||||
from server.db.session import with_session
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
@with_session
|
||||
def list_summary_from_db(session,
|
||||
kb_name: str,
|
||||
metadata: Dict = {},
|
||||
) -> List[Dict]:
|
||||
'''
|
||||
列出某知识库chunk summary。
|
||||
返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...]
|
||||
'''
|
||||
docs = session.query(SummaryChunkModel).filter_by(kb_name=kb_name)
|
||||
|
||||
for k, v in metadata.items():
|
||||
docs = docs.filter(SummaryChunkModel.meta_data[k].as_string() == str(v))
|
||||
|
||||
return [{"id": x.id,
|
||||
"summary_context": x.summary_context,
|
||||
"summary_id": x.summary_id,
|
||||
"doc_ids": x.doc_ids,
|
||||
"metadata": x.metadata} for x in docs.all()]
|
||||
|
||||
|
||||
@with_session
|
||||
def delete_summary_from_db(session,
|
||||
kb_name: str
|
||||
) -> List[Dict]:
|
||||
'''
|
||||
删除知识库chunk summary,并返回被删除的Dchunk summary。
|
||||
返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...]
|
||||
'''
|
||||
docs = list_summary_from_db(kb_name=kb_name)
|
||||
query = session.query(SummaryChunkModel).filter_by(kb_name=kb_name)
|
||||
query.delete()
|
||||
session.commit()
|
||||
return docs
|
||||
|
||||
|
||||
@with_session
|
||||
def add_summary_to_db(session,
|
||||
kb_name: str,
|
||||
summary_infos: List[Dict]):
|
||||
'''
|
||||
将总结信息添加到数据库。
|
||||
summary_infos形式:[{"summary_context": str, "doc_ids": str}, ...]
|
||||
'''
|
||||
for summary in summary_infos:
|
||||
obj = SummaryChunkModel(
|
||||
kb_name=kb_name,
|
||||
summary_context=summary["summary_context"],
|
||||
summary_id=summary["summary_id"],
|
||||
doc_ids=summary["doc_ids"],
|
||||
meta_data=summary["metadata"],
|
||||
)
|
||||
session.add(obj)
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
|
||||
@with_session
|
||||
def count_summary_from_db(session, kb_name: str) -> int:
|
||||
return session.query(SummaryChunkModel).filter_by(kb_name=kb_name).count()
|
||||
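A hedged usage sketch for the new summary_chunk repository helpers above; the knowledge-base name and payload values are made up for illustration, and the @with_session decorator supplies the database session.

```python
from server.db.repository.knowledge_metadata_repository import (
    add_summary_to_db, list_summary_from_db, count_summary_from_db,
)

summary_infos = [{
    "summary_context": "本文档概述了梦的解析的基本观点。",   # assumed example text
    "summary_id": "vs-0001",
    "doc_ids": "doc-1,doc-2",
    "metadata": {"file_description": "示例文件"},
}]
add_summary_to_db(kb_name="samples", summary_infos=summary_infos)
print(count_summary_from_db(kb_name="samples"))
print(list_summary_from_db(kb_name="samples", metadata={"file_description": "示例文件"}))
```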
@ -3,6 +3,7 @@ from configs import EMBEDDING_MODEL, logger
|
||||
from server.model_workers.base import ApiEmbeddingsParams
|
||||
from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models
|
||||
from fastapi import Body
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
@ -39,6 +40,32 @@ def embed_texts(
|
||||
logger.error(e)
|
||||
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
|
||||
|
||||
|
||||
async def aembed_texts(
|
||||
texts: List[str],
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
to_query: bool = False,
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
|
||||
'''
|
||||
try:
|
||||
if embed_model in list_embed_models(): # 使用本地Embeddings模型
|
||||
from server.utils import load_local_embeddings
|
||||
|
||||
embeddings = load_local_embeddings(model=embed_model)
|
||||
return BaseResponse(data=await embeddings.aembed_documents(texts))
|
||||
|
||||
if embed_model in list_online_embed_models(): # 使用在线API
|
||||
return await run_in_threadpool(embed_texts,
|
||||
texts=texts,
|
||||
embed_model=embed_model,
|
||||
to_query=to_query)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
|
||||
|
||||
|
||||
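A minimal async sketch of the new aembed_texts helper above; the embedding model name is an assumption and must match a locally configured or online embedding model.

```python
import asyncio
from server.embeddings_api import aembed_texts

async def main():
    resp = await aembed_texts(["你好", "hello world"], embed_model="bge-large-zh", to_query=False)
    vectors = resp.data                 # BaseResponse(data=List[List[float]]) on success
    print(len(vectors), len(vectors[0]))

asyncio.run(main())
```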
def embed_texts_endpoint(
|
||||
texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
|
||||
embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型,除了本地部署的Embedding模型,也支持在线API({online_embed_models})提供的嵌入服务。"),
|
||||
|
||||
@ -45,7 +45,7 @@ class _FaissPool(CachePool):
|
||||
# create an empty vector store
|
||||
embeddings = EmbeddingsFunAdapter(embed_model)
|
||||
doc = Document(page_content="init", metadata={})
|
||||
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
|
||||
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
|
||||
ids = list(vector_store.docstore._dict.keys())
|
||||
vector_store.delete(ids)
|
||||
return vector_store
|
||||
@ -82,7 +82,7 @@ class KBFaissPool(_FaissPool):
|
||||
|
||||
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
|
||||
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
|
||||
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
|
||||
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
|
||||
elif create:
|
||||
# create an empty vector store
|
||||
if not os.path.exists(vs_path):
|
||||
|
||||
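The distance_strategy change above pairs L2-normalized vectors with an inner-product index, which makes the returned score a cosine similarity. A hedged sketch of creating such an empty store outside the pool follows; the embedding model is a stand-in.

```python
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores.faiss import FAISS

embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-zh")   # assumed embedding model
placeholder = Document(page_content="init", metadata={})
vs = FAISS.from_documents([placeholder], embeddings, normalize_L2=True,
                          distance_strategy="METRIC_INNER_PRODUCT")
vs.delete(list(vs.docstore._dict.keys()))   # drop the placeholder, leaving an empty index
```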
@ -26,8 +26,8 @@ from server.knowledge_base.utils import (
|
||||
|
||||
from typing import List, Union, Dict, Optional
|
||||
|
||||
from server.embeddings_api import embed_texts
|
||||
from server.embeddings_api import embed_documents
|
||||
from server.embeddings_api import embed_texts, aembed_texts, embed_documents
|
||||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
|
||||
|
||||
def normalize(embeddings: List[List[float]]) -> np.ndarray:
|
||||
@ -183,12 +183,22 @@ class KBService(ABC):
|
||||
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||
return []
|
||||
|
||||
def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]:
|
||||
def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[DocumentWithVSId]:
|
||||
'''
|
||||
通过file_name或metadata检索Document
|
||||
'''
|
||||
doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
|
||||
docs = self.get_doc_by_ids([x["id"] for x in doc_infos])
|
||||
docs = []
|
||||
for x in doc_infos:
|
||||
doc_info_s = self.get_doc_by_ids([x["id"]])
|
||||
if doc_info_s is not None and doc_info_s != []:
|
||||
# 处理非空的情况
|
||||
doc_with_id = DocumentWithVSId(**doc_info_s[0].dict(), id=x["id"])
|
||||
docs.append(doc_with_id)
|
||||
else:
|
||||
# 处理空的情况
|
||||
# 可以选择跳过当前循环迭代或执行其他操作
|
||||
pass
|
||||
return docs
|
||||
|
||||
@abstractmethod
|
||||
@ -394,12 +404,16 @@ class EmbeddingsFunAdapter(Embeddings):
|
||||
normalized_query_embed = normalize(query_embed_2d)
|
||||
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
|
||||
|
||||
# TODO: 暂不支持异步
|
||||
# async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
# return normalize(await self.embeddings.aembed_documents(texts))
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
embeddings = (await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)).data
|
||||
return normalize(embeddings).tolist()
|
||||
|
||||
# async def aembed_query(self, text: str) -> List[float]:
|
||||
# return normalize(await self.embeddings.aembed_query(text))
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
embeddings = (await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)).data
|
||||
query_embed = embeddings[0]
|
||||
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
|
||||
normalized_query_embed = normalize(query_embed_2d)
|
||||
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
|
||||
|
||||
|
||||
def score_threshold_process(score_threshold, k, docs):
|
||||
|
||||
0 server/knowledge_base/kb_summary/__init__.py Normal file
79 server/knowledge_base/kb_summary/base.py Normal file
@ -0,0 +1,79 @@
|
||||
from typing import List
|
||||
|
||||
from configs import (
|
||||
EMBEDDING_MODEL,
|
||||
KB_ROOT_PATH)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
|
||||
import os
|
||||
import shutil
|
||||
from server.db.repository.knowledge_metadata_repository import add_summary_to_db, delete_summary_from_db
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
|
||||
# TODO 暂不考虑文件更新,需要重新删除相关文档,再重新添加
|
||||
class KBSummaryService(ABC):
|
||||
kb_name: str
|
||||
embed_model: str
|
||||
vs_path: str
|
||||
kb_path: str
|
||||
|
||||
def __init__(self,
|
||||
knowledge_base_name: str,
|
||||
embed_model: str = EMBEDDING_MODEL
|
||||
):
|
||||
self.kb_name = knowledge_base_name
|
||||
self.embed_model = embed_model
|
||||
|
||||
self.kb_path = self.get_kb_path()
|
||||
self.vs_path = self.get_vs_path()
|
||||
|
||||
if not os.path.exists(self.vs_path):
|
||||
os.makedirs(self.vs_path)
|
||||
|
||||
|
||||
def get_vs_path(self):
|
||||
return os.path.join(self.get_kb_path(), "summary_vector_store")
|
||||
|
||||
def get_kb_path(self):
|
||||
return os.path.join(KB_ROOT_PATH, self.kb_name)
|
||||
|
||||
def load_vector_store(self) -> ThreadSafeFaiss:
|
||||
return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
|
||||
vector_name="summary_vector_store",
|
||||
embed_model=self.embed_model,
|
||||
create=True)
|
||||
|
||||
def add_kb_summary(self, summary_combine_docs: List[Document]):
|
||||
with self.load_vector_store().acquire() as vs:
|
||||
ids = vs.add_documents(documents=summary_combine_docs)
|
||||
vs.save_local(self.vs_path)
|
||||
|
||||
summary_infos = [{"summary_context": doc.page_content,
|
||||
"summary_id": id,
|
||||
"doc_ids": doc.metadata.get('doc_ids'),
|
||||
"metadata": doc.metadata} for id, doc in zip(ids, summary_combine_docs)]
|
||||
status = add_summary_to_db(kb_name=self.kb_name, summary_infos=summary_infos)
|
||||
return status
|
||||
|
||||
def create_kb_summary(self):
|
||||
"""
|
||||
创建知识库chunk summary
|
||||
:return:
|
||||
"""
|
||||
|
||||
if not os.path.exists(self.vs_path):
|
||||
os.makedirs(self.vs_path)
|
||||
|
||||
def drop_kb_summary(self):
|
||||
"""
|
||||
删除知识库chunk summary
|
||||
:param kb_name:
|
||||
:return:
|
||||
"""
|
||||
with kb_faiss_pool.atomic:
|
||||
kb_faiss_pool.pop(self.kb_name)
|
||||
shutil.rmtree(self.vs_path)
|
||||
delete_summary_from_db(kb_name=self.kb_name)
|
||||
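A hedged end-to-end sketch for the new KBSummaryService above; the knowledge base name and the summary document are illustrative, and it assumes the service can be instantiated directly as shown in the diff.

```python
from langchain.docstore.document import Document
from server.knowledge_base.kb_summary.base import KBSummaryService

kb_summary = KBSummaryService("samples")   # assumed existing knowledge base
kb_summary.create_kb_summary()             # ensures summary_vector_store/ exists
doc = Document(page_content="该知识库整体介绍了《梦的解析》的核心观点。",
               metadata={"doc_ids": "doc-1,doc-2", "file_description": "示例"})
kb_summary.add_kb_summary([doc])           # embeds, saves the FAISS index, writes summary_chunk rows
```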
247 server/knowledge_base/kb_summary/summary_chunk.py Normal file
@ -0,0 +1,247 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
from configs import (logger)
|
||||
from langchain.chains import StuffDocumentsChain, LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.output_parsers.regex import RegexParser
|
||||
from langchain.chains.combine_documents.map_reduce import ReduceDocumentsChain, MapReduceDocumentsChain
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
|
||||
|
||||
class SummaryAdapter:
|
||||
_OVERLAP_SIZE: int
|
||||
token_max: int
|
||||
_separator: str = "\n\n"
|
||||
chain: MapReduceDocumentsChain
|
||||
|
||||
def __init__(self, overlap_size: int, token_max: int,
|
||||
chain: MapReduceDocumentsChain):
|
||||
self._OVERLAP_SIZE = overlap_size
|
||||
self.chain = chain
|
||||
self.token_max = token_max
|
||||
|
||||
@classmethod
|
||||
def form_summary(cls,
|
||||
llm: BaseLanguageModel,
|
||||
reduce_llm: BaseLanguageModel,
|
||||
overlap_size: int,
|
||||
token_max: int = 1300):
|
||||
"""
|
||||
获取实例
|
||||
:param reduce_llm: 用于合并摘要的llm
|
||||
:param llm: 用于生成摘要的llm
|
||||
:param overlap_size: 重叠部分大小
|
||||
:param token_max: 最大的chunk数量,每个chunk长度小于token_max长度,第一次生成摘要时,大于token_max长度的摘要会报错
|
||||
:return:
|
||||
"""
|
||||
|
||||
# This controls how each document will be formatted. Specifically,
|
||||
document_prompt = PromptTemplate(
|
||||
            input_variables=["page_content"],
            template="{page_content}"
        )

        # The prompt here should take as an input variable the
        # `document_variable_name`
        prompt_template = (
            "根据文本执行任务。以下任务信息"
            "{task_briefing}"
            "文本内容如下: "
            "\r\n"
            "{context}"
        )
        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=["task_briefing", "context"]
        )
        llm_chain = LLMChain(llm=llm, prompt=prompt)

        # We now define how to combine these summaries
        reduce_prompt = PromptTemplate.from_template(
            "Combine these summaries: {context}"
        )
        reduce_llm_chain = LLMChain(llm=reduce_llm, prompt=reduce_prompt)

        document_variable_name = "context"
        combine_documents_chain = StuffDocumentsChain(
            llm_chain=reduce_llm_chain,
            document_prompt=document_prompt,
            document_variable_name=document_variable_name
        )
        reduce_documents_chain = ReduceDocumentsChain(
            token_max=token_max,
            combine_documents_chain=combine_documents_chain,
        )
        chain = MapReduceDocumentsChain(
            llm_chain=llm_chain,
            document_variable_name=document_variable_name,
            reduce_documents_chain=reduce_documents_chain,
            # also return the intermediate (per-document) steps
            return_intermediate_steps=True
        )
        return cls(overlap_size=overlap_size,
                   chain=chain,
                   token_max=token_max)

    def summarize(self,
                  file_description: str,
                  docs: List[DocumentWithVSId] = []
                  ) -> List[Document]:

        if sys.version_info < (3, 10):
            loop = asyncio.get_event_loop()
        else:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()

            asyncio.set_event_loop(loop)
        # run the coroutine synchronously
        return loop.run_until_complete(self.asummarize(file_description=file_description,
                                                       docs=docs))

    async def asummarize(self,
                         file_description: str,
                         docs: List[DocumentWithVSId] = []) -> List[Document]:

        logger.info("start summary")
        # TODO: semantic duplication, missing context and "document was longer than the context length"
        #       are not handled yet
        # merge_docs = self._drop_overlap(docs)
        # # join the sentences in merge_docs into a single document
        # text = self._join_docs(merge_docs)
        # split the document into chunks at paragraph/sentence separators, each chunk shorter than token_max

        """
        The process has two stages:
        1. every document is summarized on its own (the "map" step)
        map_results = self.llm_chain.apply(
            # FYI - this is parallelized and so it is fast.
            [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
            callbacks=callbacks,
        )
        2. the per-document summaries are merged into the final summary (the "reduce" step);
           with return_intermediate_steps=True the intermediate steps are returned as well
        result, extra_return_dict = self.reduce_documents_chain.combine_docs(
            result_docs, token_max=token_max, callbacks=callbacks, **kwargs
        )
        """
        summary_combine, summary_intermediate_steps = self.chain.combine_docs(docs=docs,
                                                                              task_briefing="描述不同方法之间的接近度和相似性,"
                                                                                            "以帮助读者理解它们之间的关系。")
        print(summary_combine)
        print(summary_intermediate_steps)

        # if len(summary_combine) == 0:
        #     # if the combined summary is empty, retry with half of the intermediate summaries
        #     result_docs = [
        #         Document(page_content=question_result_key, metadata=docs[i].metadata)
        #         # This uses metadata from the docs, and the textual results from `results`
        #         for i, question_result_key in enumerate(
        #             summary_intermediate_steps["intermediate_steps"][
        #             :len(summary_intermediate_steps["intermediate_steps"]) // 2
        #             ])
        #     ]
        #     summary_combine, summary_intermediate_steps = self.chain.reduce_documents_chain.combine_docs(
        #         result_docs, token_max=self.token_max
        #     )
        logger.info("end summary")
        doc_ids = ",".join([doc.id for doc in docs])
        _metadata = {
            "file_description": file_description,
            "summary_intermediate_steps": summary_intermediate_steps,
            "doc_ids": doc_ids
        }
        summary_combine_doc = Document(page_content=summary_combine, metadata=_metadata)

        return [summary_combine_doc]

    def _drop_overlap(self, docs: List[DocumentWithVSId]) -> List[str]:
        """
        Remove the part of each page_content that overlaps with the end of the previous document.
        :param docs:
        :param separator:
        :return:
        """
        merge_docs = []

        pre_doc = None
        for doc in docs:
            # the first document is added as-is
            if len(merge_docs) == 0:
                pre_doc = doc.page_content
                merge_docs.append(doc.page_content)
                continue

            # where the end of the previous item overlaps the start of the next, drop the overlapping prefix of the next.
            # shrink pre_doc one leading character per iteration while looking for an overlap,
            # until its length drops below self._OVERLAP_SIZE // 2 - 2 * len(separator)
            for i in range(len(pre_doc), self._OVERLAP_SIZE // 2 - 2 * len(self._separator), -1):
                # drop one leading character per iteration
                pre_doc = pre_doc[1:]
                if doc.page_content[:len(pre_doc)] == pre_doc:
                    # drop the overlapping prefix of the next document
                    merge_docs.append(doc.page_content[len(pre_doc):])
                    break

            pre_doc = doc.page_content

        return merge_docs

    def _join_docs(self, docs: List[str]) -> Optional[str]:
        text = self._separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text


if __name__ == '__main__':

    docs = [
        '梦者有特别的作用,也就是说梦是在预卜未来。因此,梦内容的',
        '梦内容的多彩多姿以及对梦者本身所遗留的特殊印象,使他们很难想象',
        '使他们很难想象出一套系统划一的观念,而需要以其个别的价值与可靠性作各',
        '值与可靠性作各种不同的分化与聚合。因此,古代哲学家们对梦的评价也就完全'
    ]
    _OVERLAP_SIZE = 1
    separator: str = "\n\n"
    merge_docs = []
    # remove the overlap between consecutive items:
    # where the end of the previous item overlaps the start of the next, drop the overlapping prefix of the next
    pre_doc = None
    for doc in docs:
        # the first document is added as-is
        if len(merge_docs) == 0:
            pre_doc = doc
            merge_docs.append(doc)
            continue

        # shrink pre_doc one leading character per iteration while looking for an overlap,
        # until its length drops below _OVERLAP_SIZE - 2 * len(separator)
        for i in range(len(pre_doc), _OVERLAP_SIZE - 2 * len(separator), -1):
            # drop one leading character per iteration
            pre_doc = pre_doc[1:]
            if doc[:len(pre_doc)] == pre_doc:
                # drop the overlapping prefix of the next document
                page_content = doc[len(pre_doc):]
                merge_docs.append(page_content)

                pre_doc = doc
                break

    # join the sentences in merge_docs into one document
    text = separator.join(merge_docs)
    text = text.strip()

    print(text)
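For orientation, the following is a minimal usage sketch of the adapter finished above. It is not part of the diff: it assumes the project is importable, that a model worker is already serving LLM_MODELS[0], and it reuses get_ChatOpenAI and OVERLAP_SIZE exactly as the API module below does.

# Hedged usage sketch (assumption: run inside the project with a worker serving LLM_MODELS[0]).
from configs import LLM_MODELS, TEMPERATURE, OVERLAP_SIZE
from server.utils import get_ChatOpenAI
from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
from server.knowledge_base.model.kb_document_model import DocumentWithVSId

llm = get_ChatOpenAI(model_name=LLM_MODELS[0], temperature=TEMPERATURE, max_tokens=None)
reduce_llm = get_ChatOpenAI(model_name=LLM_MODELS[0], temperature=TEMPERATURE, max_tokens=None)

# form_summary wires the map chain (one summary per chunk) to the reduce chain (merge of partial summaries).
summary = SummaryAdapter.form_summary(llm=llm, reduce_llm=reduce_llm, overlap_size=OVERLAP_SIZE)

docs = [
    DocumentWithVSId(page_content="第一段示例文本……", id="doc-1"),
    DocumentWithVSId(page_content="第二段示例文本……", id="doc-2"),
]
result_docs = summary.summarize(file_description="示例文件", docs=docs)
print(result_docs[0].page_content)          # the combined summary
print(result_docs[0].metadata["doc_ids"])   # "doc-1,doc-2"

The returned document's metadata carries the per-chunk intermediate summaries and the comma-joined doc ids, mirroring asummarize above.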
220 server/knowledge_base/kb_summary_api.py Normal file
@ -0,0 +1,220 @@
from fastapi import Body
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
                     OVERLAP_SIZE,
                     logger, log_verbose, )
from server.knowledge_base.utils import (list_files_from_folder)
from fastapi.responses import StreamingResponse
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from typing import List, Optional
from server.knowledge_base.kb_summary.base import KBSummaryService
from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
from server.utils import wrap_done, get_ChatOpenAI, BaseResponse
from configs import LLM_MODELS, TEMPERATURE
from server.knowledge_base.model.kb_document_model import DocumentWithVSId


def recreate_summary_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
):
    """
    Rebuild the summary for every file of a single knowledge base.
    :param max_tokens:
    :param model_name:
    :param temperature:
    :param file_description:
    :param knowledge_base_name:
    :param allow_empty_kb:
    :param vs_type:
    :param embed_model:
    :return:
    """

    def output():

        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
        else:
            # recreate the summary store for this knowledge base
            kb_summary = KBSummaryService(knowledge_base_name, embed_model)
            kb_summary.drop_kb_summary()
            kb_summary.create_kb_summary()

            llm = get_ChatOpenAI(
                model_name=model_name,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            reduce_llm = get_ChatOpenAI(
                model_name=model_name,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            # text summarization adapter
            summary = SummaryAdapter.form_summary(llm=llm,
                                                  reduce_llm=reduce_llm,
                                                  overlap_size=OVERLAP_SIZE)
            files = list_files_from_folder(knowledge_base_name)

            i = 0
            for i, file_name in enumerate(files):

                doc_infos = kb.list_docs(file_name=file_name)
                docs = summary.summarize(file_description=file_description,
                                         docs=doc_infos)

                status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
                if status_kb_summary:
                    logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成")
                    yield json.dumps({
                        "code": 200,
                        "msg": f"({i + 1} / {len(files)}): {file_name}",
                        "total": len(files),
                        "finished": i + 1,
                        "doc": file_name,
                    }, ensure_ascii=False)
                else:

                    msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
                    logger.error(msg)
                    yield json.dumps({
                        "code": 500,
                        "msg": msg,
                    })
                i += 1

    return StreamingResponse(output(), media_type="text/event-stream")


def summary_file_to_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        file_name: str = Body(..., examples=["test.pdf"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
):
    """
    Summarize a single file of a knowledge base, addressed by file name.
    :param model_name:
    :param max_tokens:
    :param temperature:
    :param file_description:
    :param file_name:
    :param knowledge_base_name:
    :param allow_empty_kb:
    :param vs_type:
    :param embed_model:
    :return:
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
        else:
            # recreate the summary store for this knowledge base
            kb_summary = KBSummaryService(knowledge_base_name, embed_model)
            kb_summary.create_kb_summary()

            llm = get_ChatOpenAI(
                model_name=model_name,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            reduce_llm = get_ChatOpenAI(
                model_name=model_name,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            # text summarization adapter
            summary = SummaryAdapter.form_summary(llm=llm,
                                                  reduce_llm=reduce_llm,
                                                  overlap_size=OVERLAP_SIZE)

            doc_infos = kb.list_docs(file_name=file_name)
            docs = summary.summarize(file_description=file_description,
                                     docs=doc_infos)

            status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
            if status_kb_summary:
                logger.info(f" {file_name} 总结完成")
                yield json.dumps({
                    "code": 200,
                    "msg": f"{file_name} 总结完成",
                    "doc": file_name,
                }, ensure_ascii=False)
            else:

                msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
                logger.error(msg)
                yield json.dumps({
                    "code": 500,
                    "msg": msg,
                })

    return StreamingResponse(output(), media_type="text/event-stream")


def summary_doc_ids_to_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        doc_ids: List = Body([], examples=[["uuid"]]),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
) -> BaseResponse:
    """
    Summarize the documents of a single knowledge base addressed by doc_ids.
    :param knowledge_base_name:
    :param doc_ids:
    :param model_name:
    :param max_tokens:
    :param temperature:
    :param file_description:
    :param vs_type:
    :param embed_model:
    :return:
    """
    kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
    if not kb.exists():
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data={})
    else:
        llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        reduce_llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # text summarization adapter
        summary = SummaryAdapter.form_summary(llm=llm,
                                              reduce_llm=reduce_llm,
                                              overlap_size=OVERLAP_SIZE)

        doc_infos = kb.get_doc_by_ids(ids=doc_ids)
        # wrap doc_infos into DocumentWithVSId objects
        doc_info_with_ids = [DocumentWithVSId(**doc.dict(), id=with_id) for with_id, doc in zip(doc_ids, doc_infos)]

        docs = summary.summarize(file_description=file_description,
                                 docs=doc_info_with_ids)

        # convert docs to plain dicts
        resp_summarize = [{**doc.dict()} for doc in docs]

        return BaseResponse(code=200, msg="总结完成", data={"summarize": resp_summarize})
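A hedged example of consuming the streaming endpoint above from a client. The route prefix /knowledge_base/kb_summary_api/ and the default API port 7861 are assumptions taken from the test file later in this commit and from the project defaults, not from this hunk itself.

# Hedged client sketch: one JSON object is streamed per summarized file.
import json
import requests

resp = requests.post(
    "http://127.0.0.1:7861/knowledge_base/kb_summary_api/recreate_summary_vector_store",
    json={"knowledge_base_name": "samples", "file_description": "项目示例文档"},
    stream=True,
)
for chunk in resp.iter_content(None):
    msg = json.loads(chunk)
    print(msg["code"], msg.get("msg"))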
@ -12,6 +12,8 @@ from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.models.conversation_model import ConversationModel
from server.db.models.message_model import MessageModel
from server.db.repository.knowledge_file_repository import add_file_to_db  # ensure Models are imported
from server.db.repository.knowledge_metadata_repository import add_summary_to_db

from server.db.base import Base, engine
from server.db.session import session_scope
import os
10 server/knowledge_base/model/kb_document_model.py Normal file
@ -0,0 +1,10 @@
from langchain.docstore.document import Document


class DocumentWithVSId(Document):
    """
    A vectorized document, carrying the id it has in the vector store.
    """
    id: str = None
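DocumentWithVSId is just langchain's Document plus a vector-store id field, so wrapping a retrieved chunk only needs the id alongside the usual fields. A small illustrative sketch (the id value is borrowed from the test file further down; any real id from the vector store works):

from server.knowledge_base.model.kb_document_model import DocumentWithVSId

doc = DocumentWithVSId(page_content="……文档切片内容……",
                       metadata={"source": "test.pdf"},
                       id="357d580f-fdf7-495c-b58b-595a398284e8")  # id as stored in the vector store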
@ -71,7 +71,7 @@ def list_files_from_folder(kb_name: str):
                for target_entry in target_it:
                    process_entry(target_entry)
        elif entry.is_file():
            result.append(entry.path)
            result.append(os.path.relpath(entry.path, doc_path))
        elif entry.is_dir():
            with os.scandir(entry.path) as it:
                for sub_entry in it:
@ -85,6 +85,7 @@ def list_files_from_folder(kb_name: str):


LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
               "MHTMLLoader": ['.mhtml'],
               "UnstructuredMarkdownLoader": ['.md'],
               "JSONLoader": [".json"],
               "JSONLinesLoader": [".jsonl"],
@ -106,6 +107,7 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
               "UnstructuredWordDocumentLoader": ['.docx', 'doc'],
               "UnstructuredXMLLoader": ['.xml'],
               "UnstructuredPowerPointLoader": ['.ppt', '.pptx'],
               "EverNoteLoader": ['.enex'],
               "UnstructuredFileLoader": ['.txt'],
               }
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
@ -311,6 +313,9 @@ class KnowledgeFile:
        else:
            docs = text_splitter.split_documents(docs)

        if not docs:
            return []

        print(f"文档切分示例:{docs[0]}")
        if zh_title_enhance:
            docs = func_zh_title_enhance(docs)
@ -1,4 +1,5 @@
import sys
import os
from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
@ -19,16 +20,16 @@ class AzureWorker(ApiModelWorker):
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 8000)  # TODO: 16K models need 16384 here
        super().__init__(**kwargs)
        self.version = version

    def do_chat(self, params: ApiChatParams) -> Dict:
        params.load_config(self.model_names[0])

        data = dict(
            messages=params.messages,
            temperature=params.temperature,
            max_tokens=params.max_tokens,
            max_tokens=params.max_tokens if params.max_tokens else None,
            stream=True,
        )
        url = ("https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}"
@ -47,6 +48,7 @@ class AzureWorker(ApiModelWorker):

        with get_httpx_client() as client:
            with client.stream("POST", url, headers=headers, json=data) as response:
                print(data)
                for line in response.iter_lines():
                    if not line.strip() or "[DONE]" in line:
                        continue
@ -60,6 +62,7 @@ class AzureWorker(ApiModelWorker):
                            "error_code": 0,
                            "text": text
                        }
                        print(text)
                    else:
                        self.logger.error(f"请求 Azure API 时发生错误:{resp}")

@ -91,4 +94,4 @@ if __name__ == "__main__":
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=21008)
    uvicorn.run(app, port=21008)
@ -119,7 +119,9 @@ class MiniMaxWorker(ApiModelWorker):
        with get_httpx_client() as client:
            result = []
            i = 0
            for texts in params.texts[i:i+10]:
            batch_size = 10
            while i < len(params.texts):
                texts = params.texts[i:i+batch_size]
                data["texts"] = texts
                r = client.post(url, headers=headers, json=data).json()
                if embeddings := r.get("vectors"):
@ -137,7 +139,7 @@ class MiniMaxWorker(ApiModelWorker):
                    }
                    self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
                    return data
                i += 10
                i += batch_size
            return {"code": 200, "data": embeddings}

    def get_embeddings(self, params):
@ -188,9 +188,11 @@ class QianFanWorker(ApiModelWorker):
        with get_httpx_client() as client:
            result = []
            i = 0
            for texts in params.texts[i:i+10]:
            batch_size = 10
            while i < len(params.texts):
                texts = params.texts[i:i+batch_size]
                resp = client.post(url, json={"input": texts}).json()
                if "error_cdoe" in resp:
                if "error_code" in resp:
                    data = {
                        "code": resp["error_code"],
                        "msg": resp["error_msg"],
@ -206,7 +208,7 @@ class QianFanWorker(ApiModelWorker):
                else:
                    embeddings = [x["embedding"] for x in resp.get("data", [])]
                    result += embeddings
                i += 10
                i += batch_size
            return {"code": 200, "data": result}

    # TODO: qianfan also offers text-completion models
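Both embedding workers above converge on the same batching fix: the broken `for texts in params.texts[i:i+10]` slice is replaced by an explicit while-loop over fixed-size batches (the old code also mistyped "error_cdoe", silently skipping error handling). In isolation the pattern looks like this; the names are illustrative and not part of the diff:

# Generic sketch of the batching pattern adopted by the MiniMax and QianFan workers.
def embed_in_batches(texts, embed_one_batch, batch_size=10):
    """Call embed_one_batch on successive slices of texts and concatenate the results."""
    result = []
    i = 0
    while i < len(texts):
        result += embed_one_batch(texts[i:i + batch_size])
        i += batch_size
    return result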
@ -40,11 +40,10 @@ def get_ChatOpenAI(
        verbose: bool = True,
        **kwargs: Any,
) -> ChatOpenAI:
    ## The models below are natively supported by Langchain and do not go through the Fschat wrapper
    config_models = list_config_llm_models()

    ## Models not natively supported by Langchain go through the Fschat wrapper
    config = get_model_worker_config(model_name)
    if model_name == "openai-api":
        model_name = config.get("model_name")

    model = ChatOpenAI(
        streaming=streaming,
        verbose=verbose,
@ -57,10 +56,8 @@ def get_ChatOpenAI(
        openai_proxy=config.get("openai_proxy"),
        **kwargs
    )

    return model


def get_OpenAI(
        model_name: str,
        temperature: float,
@ -71,67 +68,22 @@ def get_OpenAI(
        verbose: bool = True,
        **kwargs: Any,
) -> OpenAI:
    ## The models below are natively supported by Langchain and do not go through the Fschat wrapper
    config_models = list_config_llm_models()
    if model_name in config_models.get("langchain", {}):
        config = config_models["langchain"][model_name]
        if model_name == "Azure-OpenAI":
            model = AzureOpenAI(
                streaming=streaming,
                verbose=verbose,
                callbacks=callbacks,
                deployment_name=config.get("deployment_name"),
                model_version=config.get("model_version"),
                openai_api_type=config.get("openai_api_type"),
                openai_api_base=config.get("api_base_url"),
                openai_api_version=config.get("api_version"),
                openai_api_key=config.get("api_key"),
                openai_proxy=config.get("openai_proxy"),
                temperature=temperature,
                max_tokens=max_tokens,
                echo=echo,
            )

        elif model_name == "OpenAI":
            model = OpenAI(
                streaming=streaming,
                verbose=verbose,
                callbacks=callbacks,
                model_name=config.get("model_name"),
                openai_api_base=config.get("api_base_url"),
                openai_api_key=config.get("api_key"),
                openai_proxy=config.get("openai_proxy"),
                temperature=temperature,
                max_tokens=max_tokens,
                echo=echo,
            )
        elif model_name == "Anthropic":
            model = Anthropic(
                streaming=streaming,
                verbose=verbose,
                callbacks=callbacks,
                model_name=config.get("model_name"),
                anthropic_api_key=config.get("api_key"),
                echo=echo,
            )
    ## TODO: support other models natively supported by Langchain
    else:
        ## Models not natively supported by Langchain go through the Fschat wrapper
        config = get_model_worker_config(model_name)
        model = OpenAI(
            streaming=streaming,
            verbose=verbose,
            callbacks=callbacks,
            openai_api_key=config.get("api_key", "EMPTY"),
            openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            openai_proxy=config.get("openai_proxy"),
            echo=echo,
            **kwargs
        )

    config = get_model_worker_config(model_name)
    if model_name == "openai-api":
        model_name = config.get("model_name")
    model = OpenAI(
        streaming=streaming,
        verbose=verbose,
        callbacks=callbacks,
        openai_api_key=config.get("api_key", "EMPTY"),
        openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        openai_proxy=config.get("openai_proxy"),
        echo=echo,
        **kwargs
    )
    return model


@ -630,7 +582,7 @@ def get_httpx_client(
    for host in os.environ.get("no_proxy", "").split(","):
        if host := host.strip():
            # default_proxies.update({host: None}) # Origin code
            default_proxies.update({'all://' + host: None})  # PR 1838 fix, if not add 'all://', httpx will raise error
            default_proxies.update({'all://' + host: None})  # PR 1838 fix, if not add 'all://', httpx will raise error

    # merge default proxies with user provided proxies
    if isinstance(proxies, str):
@ -714,7 +666,7 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]:
    from configs.basic_config import BASE_TEMP_DIR
    import tempfile

    if id is not None:  # if the requested temp dir already exists, return it directly
    if id is not None:  # if the requested temp dir already exists, return it directly
        path = os.path.join(BASE_TEMP_DIR, id)
        if os.path.isdir(path):
            return path, id
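After this refactor, get_ChatOpenAI and get_OpenAI no longer special-case Langchain-native providers: both always read the worker config and, for "openai-api", substitute the real model name before talking to the fschat OpenAI-compatible endpoint. A hedged call example (it assumes the services from startup.py are already running and serving LLM_MODELS[0]):

# Hedged sketch: calling a running worker through the refactored helper.
from configs import LLM_MODELS, TEMPERATURE
from server.utils import get_ChatOpenAI

model = get_ChatOpenAI(model_name=LLM_MODELS[0], temperature=TEMPERATURE, max_tokens=None)
print(model.predict("用一句话介绍一下你自己。"))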
@ -36,6 +36,7 @@ from server.utils import (fschat_controller_address, fschat_model_worker_address
                          fschat_openai_api_address, set_httpx_config, get_httpx_client,
                          get_model_worker_config, get_all_model_worker_configs,
                          MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
from server.knowledge_base.migrate import create_tables
import argparse
from typing import Tuple, List, Dict
from configs import VERSION
@ -866,6 +867,8 @@ async def start_main_server():
    logger.info("Process status: %s", p)

if __name__ == "__main__":
    # make sure the database tables are created
    create_tables()

    if sys.version_info < (3, 10):
        loop = asyncio.get_event_loop()
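The version-dependent event-loop selection touched here is the same pattern used by SummaryAdapter.summarize earlier in this commit. As a standalone sketch of that pattern:

# Minimal sketch of the event-loop selection used in startup.py and summary_chunk.py.
import asyncio
import sys

def get_loop() -> asyncio.AbstractEventLoop:
    if sys.version_info < (3, 10):
        return asyncio.get_event_loop()
    try:
        return asyncio.get_running_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop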
44 tests/api/test_kb_summary_api.py Normal file
@ -0,0 +1,44 @@
import requests
import json
import sys
from pathlib import Path

root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.utils import api_address

api_base_url = api_address()

kb = "samples"
file_name = "/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/knowledge_base/samples/content/llm/大模型技术栈-实战与应用.md"
doc_ids = [
    "357d580f-fdf7-495c-b58b-595a398284e8",
    "c7338773-2e83-4671-b237-1ad20335b0f0",
    "6da613d1-327d-466f-8c1a-b32e6f461f47"
]


def test_summary_file_to_vector_store(api="/knowledge_base/kb_summary_api/summary_file_to_vector_store"):
    url = api_base_url + api
    print("\n文件摘要:")
    r = requests.post(url, json={"knowledge_base_name": kb,
                                 "file_name": file_name
                                 }, stream=True)
    for chunk in r.iter_content(None):
        data = json.loads(chunk)
        assert isinstance(data, dict)
        assert data["code"] == 200
        print(data["msg"])


def test_summary_doc_ids_to_vector_store(api="/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store"):
    url = api_base_url + api
    print("\n文件摘要:")
    r = requests.post(url, json={"knowledge_base_name": kb,
                                 "doc_ids": doc_ids
                                 }, stream=True)
    for chunk in r.iter_content(None):
        data = json.loads(chunk)
        assert isinstance(data, dict)
        assert data["code"] == 200
        print(data)
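These tests talk to a live server: they assume the API from startup.py is reachable at api_address(), that the "samples" knowledge base exists, and that the hard-coded file_name and doc_ids are valid on the local machine. With those in place they can be run directly, e.g. `pytest tests/api/test_kb_summary_api.py -s`.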
@ -1,8 +1,11 @@
import streamlit as st
from webui_pages.utils import *
from streamlit_chatbox import *
from streamlit_modal import Modal
from datetime import datetime
import os
import re
import time
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
from server.knowledge_base.utils import LOADER_DICT
@ -47,6 +50,53 @@ def upload_temp_docs(files, _api: ApiRequest) -> str:
    return _api.upload_temp_docs(files).get("data", {}).get("id")


def parse_command(text: str, modal: Modal) -> bool:
    '''
    检查用户是否输入了自定义命令,当前支持:
    /new {session_name}。如果未提供名称,默认为“会话X”
    /del {session_name}。如果未提供名称,在会话数量>1的情况下,删除当前会话。
    /clear {session_name}。如果未提供名称,默认清除当前会话
    /help。查看命令帮助
    返回值:输入的是命令返回True,否则返回False
    '''
    if m := re.match(r"/([^\s]+)\s*(.*)", text):
        cmd, name = m.groups()
        name = name.strip()
        conv_names = chat_box.get_chat_names()
        if cmd == "help":
            modal.open()
        elif cmd == "new":
            if not name:
                i = 1
                while True:
                    name = f"会话{i}"
                    if name not in conv_names:
                        break
                    i += 1
            if name in st.session_state["conversation_ids"]:
                st.error(f"该会话名称 “{name}” 已存在")
                time.sleep(1)
            else:
                st.session_state["conversation_ids"][name] = uuid.uuid4().hex
                st.session_state["cur_conv_name"] = name
        elif cmd == "del":
            name = name or st.session_state.get("cur_conv_name")
            if len(conv_names) == 1:
                st.error("这是最后一个会话,无法删除")
                time.sleep(1)
            elif not name or name not in st.session_state["conversation_ids"]:
                st.error(f"无效的会话名称:“{name}”")
                time.sleep(1)
            else:
                st.session_state["conversation_ids"].pop(name, None)
                chat_box.del_chat_name(name)
                st.session_state["cur_conv_name"] = ""
        elif cmd == "clear":
            chat_box.reset_history(name=name or None)
        return True
    return False


def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
st.session_state.setdefault("conversation_ids", {})
|
||||
st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex)
|
||||
@ -60,25 +110,15 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
)
|
||||
chat_box.init_session()
|
||||
|
||||
# 弹出自定义命令帮助信息
|
||||
modal = Modal("自定义命令", key="cmd_help", max_width="500")
|
||||
if modal.is_open():
|
||||
with modal.container():
|
||||
cmds = [x for x in parse_command.__doc__.split("\n") if x.strip().startswith("/")]
|
||||
st.write("\n\n".join(cmds))
|
||||
|
||||
with st.sidebar:
|
||||
# 多会话
|
||||
cols = st.columns([3, 1])
|
||||
conv_name = cols[0].text_input("会话名称")
|
||||
with cols[1]:
|
||||
if st.button("添加"):
|
||||
if not conv_name or conv_name in st.session_state["conversation_ids"]:
|
||||
st.error("请指定有效的会话名称")
|
||||
else:
|
||||
st.session_state["conversation_ids"][conv_name] = uuid.uuid4().hex
|
||||
st.session_state["cur_conv_name"] = conv_name
|
||||
st.session_state["conv_name"] = ""
|
||||
if st.button("删除"):
|
||||
if not conv_name or conv_name not in st.session_state["conversation_ids"]:
|
||||
st.error("请指定有效的会话名称")
|
||||
else:
|
||||
st.session_state["conversation_ids"].pop(conv_name, None)
|
||||
st.session_state["cur_conv_name"] = ""
|
||||
|
||||
conv_names = list(st.session_state["conversation_ids"].keys())
|
||||
index = 0
|
||||
if st.session_state.get("cur_conv_name") in conv_names:
|
||||
@ -236,7 +276,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
# Display chat messages from history on app rerun
|
||||
chat_box.output_messages()
|
||||
|
||||
chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter "
|
||||
chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "
|
||||
|
||||
def on_feedback(
|
||||
feedback,
|
||||
@ -256,139 +296,142 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
}
|
||||
|
||||
if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
|
||||
history = get_messages_history(history_len)
|
||||
chat_box.user_say(prompt)
|
||||
if dialogue_mode == "LLM 对话":
|
||||
chat_box.ai_say("正在思考...")
|
||||
text = ""
|
||||
message_id = ""
|
||||
r = api.chat_chat(prompt,
|
||||
history=history,
|
||||
conversation_id=conversation_id,
|
||||
model=llm_model,
|
||||
prompt_name=prompt_template_name,
|
||||
temperature=temperature)
|
||||
for t in r:
|
||||
if error_msg := check_error_msg(t): # check whether error occured
|
||||
st.error(error_msg)
|
||||
break
|
||||
text += t.get("text", "")
|
||||
chat_box.update_msg(text)
|
||||
message_id = t.get("message_id", "")
|
||||
if parse_command(text=prompt, modal=modal): # 用户输入自定义命令
|
||||
st.rerun()
|
||||
else:
|
||||
history = get_messages_history(history_len)
|
||||
chat_box.user_say(prompt)
|
||||
if dialogue_mode == "LLM 对话":
|
||||
chat_box.ai_say("正在思考...")
|
||||
text = ""
|
||||
message_id = ""
|
||||
r = api.chat_chat(prompt,
|
||||
history=history,
|
||||
conversation_id=conversation_id,
|
||||
model=llm_model,
|
||||
prompt_name=prompt_template_name,
|
||||
temperature=temperature)
|
||||
for t in r:
|
||||
if error_msg := check_error_msg(t): # check whether error occured
|
||||
st.error(error_msg)
|
||||
break
|
||||
text += t.get("text", "")
|
||||
chat_box.update_msg(text)
|
||||
message_id = t.get("message_id", "")
|
||||
|
||||
metadata = {
|
||||
"message_id": message_id,
|
||||
}
|
||||
chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标
|
||||
chat_box.show_feedback(**feedback_kwargs,
|
||||
key=message_id,
|
||||
on_submit=on_feedback,
|
||||
kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
|
||||
metadata = {
|
||||
"message_id": message_id,
|
||||
}
|
||||
chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标
|
||||
chat_box.show_feedback(**feedback_kwargs,
|
||||
key=message_id,
|
||||
on_submit=on_feedback,
|
||||
kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
|
||||
|
||||
elif dialogue_mode == "自定义Agent问答":
|
||||
if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
|
||||
elif dialogue_mode == "自定义Agent问答":
|
||||
if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
|
||||
chat_box.ai_say([
|
||||
f"正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!</span>\n\n\n",
|
||||
Markdown("...", in_expander=True, title="思考过程", state="complete"),
|
||||
|
||||
])
|
||||
else:
|
||||
chat_box.ai_say([
|
||||
f"正在思考...",
|
||||
Markdown("...", in_expander=True, title="思考过程", state="complete"),
|
||||
|
||||
])
|
||||
text = ""
|
||||
ans = ""
|
||||
for d in api.agent_chat(prompt,
|
||||
history=history,
|
||||
model=llm_model,
|
||||
prompt_name=prompt_template_name,
|
||||
temperature=temperature,
|
||||
):
|
||||
try:
|
||||
d = json.loads(d)
|
||||
except:
|
||||
pass
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
if chunk := d.get("answer"):
|
||||
text += chunk
|
||||
chat_box.update_msg(text, element_index=1)
|
||||
if chunk := d.get("final_answer"):
|
||||
ans += chunk
|
||||
chat_box.update_msg(ans, element_index=0)
|
||||
if chunk := d.get("tools"):
|
||||
text += "\n\n".join(d.get("tools", []))
|
||||
chat_box.update_msg(text, element_index=1)
|
||||
chat_box.update_msg(ans, element_index=0, streaming=False)
|
||||
chat_box.update_msg(text, element_index=1, streaming=False)
|
||||
elif dialogue_mode == "知识库问答":
|
||||
chat_box.ai_say([
|
||||
f"正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!</span>\n\n\n",
|
||||
Markdown("...", in_expander=True, title="思考过程", state="complete"),
|
||||
|
||||
f"正在查询知识库 `{selected_kb}` ...",
|
||||
Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
|
||||
])
|
||||
else:
|
||||
text = ""
|
||||
for d in api.knowledge_base_chat(prompt,
|
||||
knowledge_base_name=selected_kb,
|
||||
top_k=kb_top_k,
|
||||
score_threshold=score_threshold,
|
||||
history=history,
|
||||
model=llm_model,
|
||||
prompt_name=prompt_template_name,
|
||||
temperature=temperature):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
elif chunk := d.get("answer"):
|
||||
text += chunk
|
||||
chat_box.update_msg(text, element_index=0)
|
||||
chat_box.update_msg(text, element_index=0, streaming=False)
|
||||
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
|
||||
elif dialogue_mode == "文件对话":
|
||||
if st.session_state["file_chat_id"] is None:
|
||||
st.error("请先上传文件再进行对话")
|
||||
st.stop()
|
||||
chat_box.ai_say([
|
||||
f"正在思考...",
|
||||
Markdown("...", in_expander=True, title="思考过程", state="complete"),
|
||||
|
||||
f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
|
||||
Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
|
||||
])
|
||||
text = ""
|
||||
ans = ""
|
||||
for d in api.agent_chat(prompt,
|
||||
history=history,
|
||||
model=llm_model,
|
||||
prompt_name=prompt_template_name,
|
||||
temperature=temperature,
|
||||
):
|
||||
try:
|
||||
d = json.loads(d)
|
||||
except:
|
||||
pass
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
if chunk := d.get("answer"):
|
||||
text += chunk
|
||||
chat_box.update_msg(text, element_index=1)
|
||||
if chunk := d.get("final_answer"):
|
||||
ans += chunk
|
||||
chat_box.update_msg(ans, element_index=0)
|
||||
if chunk := d.get("tools"):
|
||||
text += "\n\n".join(d.get("tools", []))
|
||||
chat_box.update_msg(text, element_index=1)
|
||||
chat_box.update_msg(ans, element_index=0, streaming=False)
|
||||
chat_box.update_msg(text, element_index=1, streaming=False)
|
||||
elif dialogue_mode == "知识库问答":
|
||||
chat_box.ai_say([
|
||||
f"正在查询知识库 `{selected_kb}` ...",
|
||||
Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
|
||||
])
|
||||
text = ""
|
||||
for d in api.knowledge_base_chat(prompt,
|
||||
knowledge_base_name=selected_kb,
|
||||
top_k=kb_top_k,
|
||||
score_threshold=score_threshold,
|
||||
history=history,
|
||||
model=llm_model,
|
||||
prompt_name=prompt_template_name,
|
||||
temperature=temperature):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
elif chunk := d.get("answer"):
|
||||
text += chunk
|
||||
chat_box.update_msg(text, element_index=0)
|
||||
chat_box.update_msg(text, element_index=0, streaming=False)
|
||||
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
|
||||
elif dialogue_mode == "文件对话":
|
||||
if st.session_state["file_chat_id"] is None:
|
||||
st.error("请先上传文件再进行对话")
|
||||
st.stop()
|
||||
chat_box.ai_say([
|
||||
f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
|
||||
Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
|
||||
])
|
||||
text = ""
|
||||
for d in api.file_chat(prompt,
|
||||
knowledge_id=st.session_state["file_chat_id"],
|
||||
top_k=kb_top_k,
|
||||
score_threshold=score_threshold,
|
||||
history=history,
|
||||
model=llm_model,
|
||||
prompt_name=prompt_template_name,
|
||||
temperature=temperature):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
elif chunk := d.get("answer"):
|
||||
text += chunk
|
||||
chat_box.update_msg(text, element_index=0)
|
||||
chat_box.update_msg(text, element_index=0, streaming=False)
|
||||
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
|
||||
elif dialogue_mode == "搜索引擎问答":
|
||||
chat_box.ai_say([
|
||||
f"正在执行 `{search_engine}` 搜索...",
|
||||
Markdown("...", in_expander=True, title="网络搜索结果", state="complete"),
|
||||
])
|
||||
text = ""
|
||||
for d in api.search_engine_chat(prompt,
|
||||
search_engine_name=search_engine,
|
||||
top_k=se_top_k,
|
||||
history=history,
|
||||
model=llm_model,
|
||||
prompt_name=prompt_template_name,
|
||||
temperature=temperature,
|
||||
split_result=se_top_k > 1):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
elif chunk := d.get("answer"):
|
||||
text += chunk
|
||||
chat_box.update_msg(text, element_index=0)
|
||||
chat_box.update_msg(text, element_index=0, streaming=False)
|
||||
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
|
||||
text = ""
|
||||
for d in api.file_chat(prompt,
|
||||
knowledge_id=st.session_state["file_chat_id"],
|
||||
top_k=kb_top_k,
|
||||
score_threshold=score_threshold,
|
||||
history=history,
|
||||
model=llm_model,
|
||||
prompt_name=prompt_template_name,
|
||||
temperature=temperature):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
elif chunk := d.get("answer"):
|
||||
text += chunk
|
||||
chat_box.update_msg(text, element_index=0)
|
||||
chat_box.update_msg(text, element_index=0, streaming=False)
|
||||
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
|
||||
elif dialogue_mode == "搜索引擎问答":
|
||||
chat_box.ai_say([
|
||||
f"正在执行 `{search_engine}` 搜索...",
|
||||
Markdown("...", in_expander=True, title="网络搜索结果", state="complete"),
|
||||
])
|
||||
text = ""
|
||||
for d in api.search_engine_chat(prompt,
|
||||
search_engine_name=search_engine,
|
||||
top_k=se_top_k,
|
||||
history=history,
|
||||
model=llm_model,
|
||||
prompt_name=prompt_template_name,
|
||||
temperature=temperature,
|
||||
split_result=se_top_k > 1):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
elif chunk := d.get("answer"):
|
||||
text += chunk
|
||||
chat_box.update_msg(text, element_index=0)
|
||||
chat_box.update_msg(text, element_index=0, streaming=False)
|
||||
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
|
||||
|
||||
if st.session_state.get("need_rerun"):
|
||||
st.session_state["need_rerun"] = False
|
||||
|
||||
@ -34,14 +34,19 @@ def config_aggrid(
        use_checkbox=use_checkbox,
        # pre_selected_rows=st.session_state.get("selected_rows", [0]),
    )
    gb.configure_pagination(
        enabled=True,
        paginationAutoPageSize=False,
        paginationPageSize=10
    )
    return gb


def file_exists(kb: str, selected_rows: List) -> Tuple[str, str]:
    '''
    """
    check whether a doc file exists in local knowledge base folder.
    return the file's name and path if it exists.
    '''
    """
    if selected_rows:
        file_name = selected_rows[0]["file_name"]
        file_path = get_file_path(kb, file_name)

@ -85,7 +85,7 @@ class ApiRequest:
|
||||
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
|
||||
while retry > 0:
|
||||
try:
|
||||
print(kwargs)
|
||||
# print(kwargs)
|
||||
if stream:
|
||||
return self.client.stream("POST", url, data=data, json=json, **kwargs)
|
||||
else:
|
||||
@ -134,7 +134,7 @@ class ApiRequest:
|
||||
if as_json:
|
||||
try:
|
||||
data = json.loads(chunk)
|
||||
pprint(data, depth=1)
|
||||
# pprint(data, depth=1)
|
||||
yield data
|
||||
except Exception as e:
|
||||
msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
|
||||
@ -166,7 +166,7 @@ class ApiRequest:
|
||||
if as_json:
|
||||
try:
|
||||
data = json.loads(chunk)
|
||||
pprint(data, depth=1)
|
||||
# pprint(data, depth=1)
|
||||
yield data
|
||||
except Exception as e:
|
||||
msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
|
||||
@ -276,8 +276,8 @@ class ApiRequest:
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
|
||||
print(f"received input message:")
|
||||
pprint(data)
|
||||
# print(f"received input message:")
|
||||
# pprint(data)
|
||||
|
||||
response = self.post(
|
||||
"/chat/fastchat",
|
||||
@ -288,23 +288,25 @@ class ApiRequest:
|
||||
return self._httpx_stream2generator(response)
|
||||
|
||||
def chat_chat(
|
||||
self,
|
||||
query: str,
|
||||
conversation_id: str = None,
|
||||
history: List[Dict] = [],
|
||||
stream: bool = True,
|
||||
model: str = LLM_MODELS[0],
|
||||
temperature: float = TEMPERATURE,
|
||||
max_tokens: int = None,
|
||||
prompt_name: str = "default",
|
||||
**kwargs,
|
||||
self,
|
||||
query: str,
|
||||
conversation_id: str = None,
|
||||
history_len: int = -1,
|
||||
history: List[Dict] = [],
|
||||
stream: bool = True,
|
||||
model: str = LLM_MODELS[0],
|
||||
temperature: float = TEMPERATURE,
|
||||
max_tokens: int = None,
|
||||
prompt_name: str = "default",
|
||||
**kwargs,
|
||||
):
|
||||
'''
|
||||
对应api.py/chat/chat接口 #TODO: 考虑是否返回json
|
||||
对应api.py/chat/chat接口
|
||||
'''
|
||||
data = {
|
||||
"query": query,
|
||||
"conversation_id": conversation_id,
|
||||
"history_len": history_len,
|
||||
"history": history,
|
||||
"stream": stream,
|
||||
"model_name": model,
|
||||
@ -313,8 +315,8 @@ class ApiRequest:
|
||||
"prompt_name": prompt_name,
|
||||
}
|
||||
|
||||
print(f"received input message:")
|
||||
pprint(data)
|
||||
# print(f"received input message:")
|
||||
# pprint(data)
|
||||
|
||||
response = self.post("/chat/chat", json=data, stream=True, **kwargs)
|
||||
return self._httpx_stream2generator(response, as_json=True)
|
||||
@ -342,8 +344,8 @@ class ApiRequest:
|
||||
"prompt_name": prompt_name,
|
||||
}
|
||||
|
||||
print(f"received input message:")
|
||||
pprint(data)
|
||||
# print(f"received input message:")
|
||||
# pprint(data)
|
||||
|
||||
response = self.post("/chat/agent_chat", json=data, stream=True)
|
||||
return self._httpx_stream2generator(response)
|
||||
@ -377,8 +379,8 @@ class ApiRequest:
|
||||
"prompt_name": prompt_name,
|
||||
}
|
||||
|
||||
print(f"received input message:")
|
||||
pprint(data)
|
||||
# print(f"received input message:")
|
||||
# pprint(data)
|
||||
|
||||
response = self.post(
|
||||
"/chat/knowledge_base_chat",
|
||||
@ -452,8 +454,8 @@ class ApiRequest:
|
||||
"prompt_name": prompt_name,
|
||||
}
|
||||
|
||||
print(f"received input message:")
|
||||
pprint(data)
|
||||
# print(f"received input message:")
|
||||
# pprint(data)
|
||||
|
||||
response = self.post(
|
||||
"/chat/file_chat",
|
||||
@ -491,8 +493,8 @@ class ApiRequest:
|
||||
"split_result": split_result,
|
||||
}
|
||||
|
||||
print(f"received input message:")
|
||||
pprint(data)
|
||||
# print(f"received input message:")
|
||||
# pprint(data)
|
||||
|
||||
response = self.post(
|
||||
"/chat/search_engine_chat",
|
||||
|
||||