commit 4c2fda7200
hzg0601 2023-12-06 20:42:32 +08:00
48 changed files with 1435 additions and 550 deletions

View File

@ -148,7 +148,7 @@ $ python startup.py -a
[![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white "langchain-chatglm")](https://t.me/+RjliQ3jnJ1YyN2E9)
### 项目交流群
<img src="img/qr_code_74.jpg" alt="二维码" width="300" />
<img src="img/qr_code_76.jpg" alt="二维码" width="300" />
🎉 Langchain-Chatchat 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。

View File

@ -5,4 +5,4 @@ from .server_config import *
from .prompt_config import *
VERSION = "v0.2.8-preview"
VERSION = "v0.2.9-preview"

View File

@ -106,7 +106,7 @@ kbs_config = {
# TextSplitter配置项如果你不明白其中的含义就不要修改。
text_splitter_dict = {
"ChineseRecursiveTextSplitter": {
"source": "huggingface", ## 选择tiktoken则使用openai的方法
"source": "huggingface", # 选择tiktoken则使用openai的方法
"tokenizer_name_or_path": "",
},
"SpacyTextSplitter": {

View File

@ -15,9 +15,11 @@ EMBEDDING_DEVICE = "auto"
EMBEDDING_KEYWORD_FILE = "keywords.txt"
EMBEDDING_MODEL_OUTPUT_PATH = "output"
# 要运行的 LLM 名称,可以包括本地模型和在线模型。
# 第一个将作为 API 和 WEBUI 的默认模型
LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"]
# 要运行的 LLM 名称,可以包括本地模型和在线模型。列表中本地模型将在启动项目时全部加载。
# 列表中第一个模型将作为 API 和 WEBUI 的默认模型。
# 在这里我们使用目前主流的两个离线模型其中chatglm3-6b 为默认加载模型。
# 如果你的显存不足,可使用 Qwen-1_8B-Chat, 该模型 FP16 仅需 3.8G显存。
LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"] # "Qwen-1_8B-Chat",
# AgentLM模型的名称 (可以不指定指定之后就锁定进入Agent之后的Chain的模型不指定就是LLM_MODELS[0])
Agent_MODEL = None
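As the comments above note, every local model listed in LLM_MODELS is loaded at startup and the first entry becomes the default for the API and WebUI. A minimal low-VRAM variant of this setting, as a sketch (the surrounding keys stay unchanged):

    # Hypothetical alternative for machines with limited GPU memory:
    # run only the 1.8B chat model locally and keep the online APIs.
    LLM_MODELS = ["Qwen-1_8B-Chat", "zhipu-api", "openai-api"]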
@ -112,10 +114,10 @@ ONLINE_LLM_MODEL = {
"api_key": "",
"provider": "AzureWorker",
},
# 昆仑万维天工 API https://model-platform.tiangong.cn/
"tiangong-api": {
"version":"SkyChat-MegaVerse",
"version": "SkyChat-MegaVerse",
"api_key": "",
"secret_key": "",
"provider": "TianGongWorker",
@ -163,6 +165,25 @@ MODEL_PATH = {
"chatglm3-6b": "THUDM/chatglm3-6b",
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
"chatglm3-6b-base": "THUDM/chatglm3-6b-base",
"Qwen-1_8B": "Qwen/Qwen-1_8B",
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
"Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8",
"Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4",
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8",
"Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4",
"Qwen-72B": "Qwen/Qwen-72B",
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
"Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8",
"Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4",
"baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat",
"baichuan2-7b": "baichuan-inc/Baichuan2-7B-Chat",
@ -204,18 +225,11 @@ MODEL_PATH = {
"opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8", # 确保已经安装了auto-gptq optimum flash-attn
"Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4", # 确保已经安装了auto-gptq optimum flash-attn
"agentlm-7b": "THUDM/agentlm-7b",
"agentlm-13b": "THUDM/agentlm-13b",
"agentlm-70b": "THUDM/agentlm-70b",
"Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat", # 更多01-ai模型尚未进行测试。如果需要使用请自行测试。
"Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat",
},
}
@ -242,11 +256,11 @@ VLLM_MODEL_DICT = {
"BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k",
# 注意bloom系列的tokenizer与model是分离的因此虽然vllm支持但与fschat框架不兼容
# "bloom":"bigscience/bloom",
# "bloomz":"bigscience/bloomz",
# "bloomz-560m":"bigscience/bloomz-560m",
# "bloomz-7b1":"bigscience/bloomz-7b1",
# "bloomz-1b7":"bigscience/bloomz-1b7",
# "bloom": "bigscience/bloom",
# "bloomz": "bigscience/bloomz",
# "bloomz-560m": "bigscience/bloomz-560m",
# "bloomz-7b1": "bigscience/bloomz-7b1",
# "bloomz-1b7": "bigscience/bloomz-1b7",
"internlm-7b": "internlm/internlm-7b",
"internlm-chat-7b": "internlm/internlm-chat-7b",
@ -273,10 +287,23 @@ VLLM_MODEL_DICT = {
"opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
"Qwen-1_8B": "Qwen/Qwen-1_8B",
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
"Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8",
"Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4",
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8",
"Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4",
"Qwen-72B": "Qwen/Qwen-72B",
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
"Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8",
"Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4",
"agentlm-7b": "THUDM/agentlm-7b",
"agentlm-13b": "THUDM/agentlm-13b",

View File

@ -1,7 +1,6 @@
# prompt模板使用Jinja2语法简单点就是用双大括号代替f-string的单大括号
# 本配置文件支持热加载修改prompt模板后无需重启服务。
# LLM对话支持的变量
# - input: 用户输入内容
@ -17,125 +16,112 @@
# - input: 用户输入内容
# - agent_scratchpad: Agent的思维记录
PROMPT_TEMPLATES = {}
PROMPT_TEMPLATES = {
"llm_chat": {
"default":
'{{ input }}',
PROMPT_TEMPLATES["llm_chat"] = {
"default": "{{ input }}",
"with_history":
"""The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.
"with_history":
'The following is a friendly conversation between a human and an AI. '
'The AI is talkative and provides lots of specific details from its context. '
'If the AI does not know the answer to a question, it truthfully says it does not know.\n\n'
'Current conversation:\n'
'{history}\n'
'Human: {input}\n'
'AI:',
Current conversation:
{history}
Human: {input}
AI:""",
"py":
"""
你是一个聪明的代码助手请你给我写出简单的py代码。 \n
{{ input }}
""",
}
PROMPT_TEMPLATES["knowledge_base_chat"] = {
"default":
"""
<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
<已知信息>{{ context }}</已知信息>、
<问题>{{ question }}</问题>
""",
"text":
"""
<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>
<已知信息>{{ context }}</已知信息>、
<问题>{{ question }}</问题>
""",
"Empty": # 搜不到知识库的时候使用
"""
请你回答我的问题:
{{ question }}
\n
""",
}
PROMPT_TEMPLATES["search_engine_chat"] = {
"default":
"""
<指令>这是我搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 </指令>
<已知信息>{{ context }}</已知信息>
<问题>{{ question }}</问题>
""",
"search":
"""
<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>
<已知信息>{{ context }}</已知信息>、
<问题>{{ question }}</问题>
""",
}
PROMPT_TEMPLATES["agent_chat"] = {
"default":
"""
Answer the following questions as best you can. If it is in order, you can use some tools appropriately.You have access to the following tools:
{tools}
Please note that the "知识库查询工具" is information about the "西交利物浦大学" ,and if a question is asked about it, you must answer with the knowledge base
Please note that the "天气查询工具" can only be used once since Question begin.
Use the following format:
Question: the input question you must answer1
Thought: you should always think about what to do and what tools to use.
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!
history: {history}
Question: {input}
Thought: {agent_scratchpad}
""",
"ChatGLM3":
"""
You can answer using the tools, or answer directly using your knowledge without using the tools.Respond to the human as helpfully and accurately as possible.
You have access to the following tools:
{tools}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: "Final Answer" or [{tool_names}]
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
"action": "Final Answer",
"action_input": "Final response to human"
}}}}
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
history: {history}
Question: {input}
Thought: {agent_scratchpad}
""",
"py":
'你是一个聪明的代码助手请你给我写出简单的py代码。 \n'
'{{ input }}',
},
"knowledge_base_chat": {
"default":
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,'
'不允许在答案中添加编造成分,答案请使用中文。 </指令>\n'
'<已知信息>{{ context }}</已知信息>\n'
'<问题>{{ question }}</问题>\n',
"text":
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>\n'
'<已知信息>{{ context }}</已知信息>\n'
'<问题>{{ question }}</问题>\n',
"empty": # 搜不到知识库的时候使用
'请你回答我的问题:\n'
'{{ question }}\n\n',
},
"search_engine_chat": {
"default":
'<指令>这是我搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。'
'如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 </指令>\n'
'<已知信息>{{ context }}</已知信息>\n'
'<问题>{{ question }}</问题>\n',
"search":
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>\n'
'<已知信息>{{ context }}</已知信息>\n'
'<问题>{{ question }}</问题>\n',
},
"agent_chat": {
"default":
'Answer the following questions as best you can. If it is in order, you can use some tools appropriately. '
'You have access to the following tools:\n\n'
'{tools}\n\n'
'Use the following format:\n'
'Question: the input question you must answer1\n'
'Thought: you should always think about what to do and what tools to use.\n'
'Action: the action to take, should be one of [{tool_names}]\n'
'Action Input: the input to the action\n'
'Observation: the result of the action\n'
'... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n'
'Thought: I now know the final answer\n'
'Final Answer: the final answer to the original input question\n'
'Begin!\n\n'
'history: {history}\n\n'
'Question: {input}\n\n'
'Thought: {agent_scratchpad}\n',
"ChatGLM3":
'You can answer using the tools, or answer directly using your knowledge without using the tools. '
'Respond to the human as helpfully and accurately as possible.\n'
'You have access to the following tools:\n'
'{tools}\n'
'Use a json blob to specify a tool by providing an action key (tool name) '
'and an action_input key (tool input).\n'
'Valid "action" values: "Final Answer" or [{tool_names}]'
'Provide only ONE action per $JSON_BLOB, as shown:\n\n'
'```\n'
'{{{{\n'
' "action": $TOOL_NAME,\n'
' "action_input": $INPUT\n'
'}}}}\n'
'```\n\n'
'Follow this format:\n\n'
'Question: input question to answer\n'
'Thought: consider previous and subsequent steps\n'
'Action:\n'
'```\n'
'$JSON_BLOB\n'
'```\n'
'Observation: action result\n'
'... (repeat Thought/Action/Observation N times)\n'
'Thought: I know what to respond\n'
'Action:\n'
'```\n'
'{{{{\n'
' "action": "Final Answer",\n'
' "action_input": "Final response to human"\n'
'}}}}\n'
'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. '
'Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.\n'
'history: {history}\n\n'
'Question: {input}\n\n'
'Thought: {agent_scratchpad}',
}
}
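The templates above use Jinja2 double-brace placeholders rather than f-strings, as the header comments explain, and are hot-reloaded when this file changes. A minimal sketch of how one of them expands, assuming the jinja2 package is installed (the sample values are illustrative):

    from jinja2 import Template

    # PROMPT_TEMPLATES is the dict defined above
    tpl = Template(PROMPT_TEMPLATES["knowledge_base_chat"]["default"])
    print(tpl.render(context="已知信息……", question="样例问题"))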

View File

@ -66,7 +66,7 @@ FSCHAT_MODEL_WORKERS = {
# "no_register": False,
# "embed_in_truncate": False,
# 以下为vllm_woker配置参数,注意使用vllm必须有gpu仅在Linux测试通过
# 以下为vllm_worker配置参数,注意使用vllm必须有gpu仅在Linux测试通过
# tokenizer = model_path # 如果tokenizer与model_path不一致在此处添加
# 'tokenizer_mode':'auto',
@ -93,14 +93,14 @@ FSCHAT_MODEL_WORKERS = {
},
# 可以如下示例方式更改默认配置
# "Qwen-7B-Chat": { # 使用default中的IP和端口
# "Qwen-1_8B-Chat": { # 使用default中的IP和端口
# "device": "cpu",
# },
"chatglm3-6b": { # 使用default中的IP和端口
"chatglm3-6b": { # 使用default中的IP和端口
"device": "cuda",
},
#以下配置可以不用修改在model_config中设置启动的模型
# 以下配置可以不用修改在model_config中设置启动的模型
"zhipu-api": {
"port": 21001,
},

View File

@ -1,13 +1,13 @@
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from document_loaders.ocr import get_ocr
class RapidOCRLoader(UnstructuredFileLoader):
def _get_elements(self) -> List:
def img2text(filepath):
from rapidocr_onnxruntime import RapidOCR
resp = ""
ocr = RapidOCR()
ocr = get_ocr()
result, _ = ocr(filepath)
if result:
ocr_result = [line[1] for line in result]

View File

@ -1,5 +1,6 @@
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from document_loaders.ocr import get_ocr
import tqdm
@ -7,9 +8,8 @@ class RapidOCRPDFLoader(UnstructuredFileLoader):
def _get_elements(self) -> List:
def pdf2text(filepath):
import fitz # pyMuPDF里面的fitz包不要与pip install fitz混淆
from rapidocr_onnxruntime import RapidOCR
import numpy as np
ocr = RapidOCR()
ocr = get_ocr()
doc = fitz.open(filepath)
resp = ""

document_loaders/ocr.py (new file, 18 lines)
View File

@ -0,0 +1,18 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
try:
from rapidocr_paddle import RapidOCR
except ImportError:
from rapidocr_onnxruntime import RapidOCR
def get_ocr(use_cuda: bool = True) -> "RapidOCR":
try:
from rapidocr_paddle import RapidOCR
ocr = RapidOCR(det_use_cuda=use_cuda, cls_use_cuda=use_cuda, rec_use_cuda=use_cuda)
except ImportError:
from rapidocr_onnxruntime import RapidOCR
ocr = RapidOCR()
return ocr
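Both OCR loaders above now obtain their engine from this helper, which prefers the GPU-enabled rapidocr_paddle backend and falls back to rapidocr_onnxruntime. A usage sketch (the image path is hypothetical):

    from document_loaders.ocr import get_ocr

    ocr = get_ocr(use_cuda=False)        # CPU fallback also works
    result, _ = ocr("sample_page.png")   # same call pattern as in the loaders above
    if result:
        print([line[1] for line in result])  # recognized text, one entry per line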

Binary files changed (contents not shown): four images removed (267 KiB, 198 KiB, 270 KiB, 174 KiB); img/qr_code_76.jpg added (182 KiB).

@ -1 +1 @@
Subproject commit f789e5dde10f91136012f3470c020c8d34572436
Subproject commit 9a3fa7a77f8748748b1c656fe8919ad5c4c63e3f

View File

@ -1,63 +1,75 @@
# API requirements
langchain>=0.0.334
langchain==0.0.344
langchain-experimental>=0.0.42
fschat[model_worker]>=0.2.33
pydantic==1.10.13
fschat>=0.2.33
xformers>=0.0.22.post7
openai~=0.28.1
openai>=1.3.7
sentence_transformers
transformers>=4.35.2
torch==2.1.0 ##on win, install the cuda version manually if you want use gpu
torchvision #on win, install the cuda version manually if you want use gpu
torchaudio #on win, install the cuda version manually if you want use gpu
torch==2.1.0 ##on Windows system, install the cuda version manually from https://pytorch.org/
torchvision #on Windows system, install the cuda version manually from https://pytorch.org/
torchaudio #on Windows system, install the cuda version manually from https://pytorch.org/
fastapi>=0.104
nltk>=3.8.1
uvicorn~=0.23.1
uvicorn>=0.24.0.post1
starlette~=0.27.0
pydantic<2
unstructured[all-docs]>=0.11.0
unstructured[all-docs]==0.11.0
python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19
faiss-cpu
accelerate
spacy
faiss-cpu # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus
accelerate>=0.24.1
spacy>=3.7.2
PyMuPDF
rapidocr_onnxruntime
requests
pathlib
pytest
numexpr
strsimpy
markdownify
tiktoken
tqdm
requests>=2.31.0
pathlib>=1.0.1
pytest>=7.4.3
numexpr>=2.8.7
strsimpy>=0.2.1
markdownify>=0.11.6
tiktoken>=0.5.1
tqdm>=4.66.1
websockets
numpy~=1.24.4
pandas~=2.0.3
einops
einops>=0.7.0
transformers_stream_generator==0.0.4
vllm==0.2.2; sys_platform == "linux"
# online api libs dependencies
# optional document loaders
# rapidocr_paddle[gpu] # gpu acceleration for ocr of pdf and image files
# jq # for .json and .jsonl files. suggest `conda install jq` on windows
# html2text # for .enex files
# beautifulsoup4 # for .mhtml files
# pysrt # for .srt files
# Online api libs dependencies
# zhipuai>=1.0.7
# dashscope>=1.10.0
# qianfan>=0.2.0
# volcengine>=1.0.106
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
# pymilvus>=2.3.3
# psycopg2
# pgvector
# pgvector>=0.2.4
# Agent and Search Tools
arxiv>=2.0.0
youtube-search>=2.1.2
duckduckgo-search>=3.9.9
metaphor-python>=0.1.23
# WebUI requirements
streamlit~=1.28.2 # # on win, make sure write its path in environment variable
streamlit>=1.29.0 # do remember to add streamlit to environment variables if you use windows
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.2.3
streamlit-chatbox>=1.1.11
streamlit-modal>=0.1.0
streamlit-aggrid>=0.3.4.post3
httpx[brotli,http2,socks]~=0.24.1
watchdog
httpx[brotli,http2,socks]>=0.25.2
watchdog>=3.0.0

View File

@ -1,52 +1,65 @@
# API requirements
langchain>=0.0.334
langchain==0.0.344
langchain-experimental>=0.0.42
fschat[model_worker]>=0.2.33
pydantic==1.10.13
fschat>=0.2.33
xformers>=0.0.22.post7
openai~=0.28.1
openai>=1.3.7
sentence_transformers
transformers>=4.35.2
torch==2.1.0
torchvision
torchaudio
torch==2.1.0 ##on Windows system, install the cuda version manually from https://pytorch.org/
torchvision #on Windows system, install the cuda version manually from https://pytorch.org/
torchaudio #on Windows system, install the cuda version manually from https://pytorch.org/
fastapi>=0.104
nltk>=3.8.1
uvicorn~=0.23.1
uvicorn>=0.24.0.post1
starlette~=0.27.0
pydantic<2
unstructured[all-docs]>=0.11.0
unstructured[all-docs]==0.11.0
python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19
faiss-cpu
accelerate>=0.24.1
spacy
spacy>=3.7.2
PyMuPDF
rapidocr_onnxruntime
requests
pathlib
pytest
numexpr
strsimpy
markdownify
tiktoken
tqdm
requests>=2.31.0
pathlib>=1.0.1
pytest>=7.4.3
numexpr>=2.8.7
strsimpy>=0.2.1
markdownify>=0.11.6
tiktoken>=0.5.1
tqdm>=4.66.1
websockets
numpy~=1.24.4
pandas~=2.0.3
einops
transformers_stream_generator>=0.0.4
einops>=0.7.0
transformers_stream_generator==0.0.4
vllm==0.2.2; sys_platform == "linux"
httpx[brotli,http2,socks]>=0.25.2
vllm>=0.2.0; sys_platform == "linux"
# online api libs
zhipuai
dashscope>=1.10.0 # qwen
qianfan
# volcengine>=1.0.106 # fangzhou
# optional document loaders
# rapidocr_paddle[gpu] # gpu acceleration for ocr of pdf and image files
# jq # for .json and .jsonl files. suggest `conda install jq` on windows
# html2text # for .enex files
# beautifulsoup4 # for .mhtml files
# pysrt # for .srt files
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
# Online api libs dependencies
# zhipuai>=1.0.7
# dashscope>=1.10.0
# qianfan>=0.2.0
# volcengine>=1.0.106
# pymilvus>=2.3.3
# psycopg2
# pgvector
# pgvector>=0.2.4
# Agent and Search Tools
arxiv>=2.0.0
youtube-search>=2.1.2
duckduckgo-search>=3.9.9
metaphor-python>=0.1.23

View File

@ -1,62 +1,70 @@
langchain>=0.0.334
fschat>=0.2.32
openai
# sentence_transformers
# transformers>=4.35.0
# torch>=2.0.1
# torchvision
# torchaudio
langchain==0.0.344
pydantic==1.10.13
fschat>=0.2.33
openai>=1.3.7
fastapi>=0.104.1
python-multipart
nltk~=3.8.1
uvicorn~=0.23.1
uvicorn>=0.24.0.post1
starlette~=0.27.0
pydantic~=1.10.11
unstructured[docx,csv]>=0.10.4 # add pdf if need
unstructured[docx,csv]==0.11.0 # add pdf if need
python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19
numexpr>=2.8.7
strsimpy>=0.2.1
faiss-cpu
# accelerate
# spacy
# accelerate>=0.24.1
# spacy>=3.7.2
# PyMuPDF==1.22.5 # install if need pdf
# rapidocr_onnxruntime>=1.3.2 # install if need pdf
# optional document loaders
# rapidocr_paddle[gpu] # gpu acceleration for ocr of pdf and image files
# jq # for .json and .jsonl files. suggest `conda install jq` on windows
# html2text # for .enex files
# beautifulsoup4 # for .mhtml files
# pysrt # for .srt files
requests
pathlib
pytest
# scikit-learn
# numexpr
# vllm==0.1.7; sys_platform == "linux"
# vllm==0.2.2; sys_platform == "linux"
# online api libs
zhipuai
dashscope>=1.10.0 # qwen
# qianfan
zhipuai>=1.0.7 # zhipu
# dashscope>=1.10.0 # qwen
# volcengine>=1.0.106 # fangzhou
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
# psycopg2
# pgvector
# pgvector>=0.2.4
numpy~=1.24.4
pandas~=2.0.3
streamlit~=1.28.1
streamlit>=1.29.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox==1.1.11
streamlit-antd-components>=0.2.3
streamlit-chatbox>=1.1.11
streamlit-modal>=0.1.0
streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
watchdog
tqdm
httpx[brotli,http2,socks]>=0.25.2
watchdog>=3.0.0
tqdm>=4.66.1
websockets
einops>=0.7.0
# tiktoken
einops
# scipy
# scipy>=1.11.4
# transformers_stream_generator==0.0.4
# search engine libs
duckduckgo-search
metaphor-python
strsimpy
markdownify
# Agent and Search Tools
arxiv>=2.0.0
youtube-search>=2.1.2
duckduckgo-search>=3.9.9
metaphor-python>=0.1.23

View File

@ -1,10 +1,10 @@
# WebUI requirements
streamlit~=1.28.2
streamlit>=1.29.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.2.3
streamlit-chatbox>=1.1.11
streamlit-modal>=0.1.0
streamlit-aggrid>=0.3.4.post3
httpx[brotli,http2,socks]~=0.24.1
watchdog
httpx[brotli,http2,socks]>=0.25.2
watchdog>=3.0.0

View File

@ -170,7 +170,6 @@ class LLMKnowledgeChain(LLMChain):
queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
except:
queries = [(line.split("")[0].strip(), line.split("")[1].strip()) for line in lines]
print(queries)
run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", verbose=self.verbose)
output = self._evaluate_expression(queries)
run_manager.on_text("\nAnswer: ", verbose=self.verbose)

View File

@ -81,6 +81,8 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
# 知识库相关接口
mount_knowledge_routes(app)
# 摘要相关接口
mount_filename_summary_routes(app)
# LLM模型相关接口
app.post("/llm_model/list_running_models",
@ -230,6 +232,26 @@ def mount_knowledge_routes(app: FastAPI):
)(upload_temp_docs)
def mount_filename_summary_routes(app: FastAPI):
from server.knowledge_base.kb_summary_api import (summary_file_to_vector_store, recreate_summary_vector_store,
summary_doc_ids_to_vector_store)
app.post("/knowledge_base/kb_summary_api/summary_file_to_vector_store",
tags=["Knowledge kb_summary_api Management"],
summary="单个知识库根据文件名称摘要"
)(summary_file_to_vector_store)
app.post("/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store",
tags=["Knowledge kb_summary_api Management"],
summary="单个知识库根据doc_ids摘要",
response_model=BaseResponse,
)(summary_doc_ids_to_vector_store)
app.post("/knowledge_base/kb_summary_api/recreate_summary_vector_store",
tags=["Knowledge kb_summary_api Management"],
summary="重建单个知识库文件摘要"
)(recreate_summary_vector_store)
def run_api(host, port, **kwargs):
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
uvicorn.run(app,

View File

@ -43,7 +43,11 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
model_name: str = LLM_MODELS[0],
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
nonlocal max_tokens
callback = CustomAsyncIteratorCallbackHandler()
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
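The same guard is added to every chat endpoint in this commit (llm chat, completion, file chat, knowledge base chat, search engine chat): a non-positive max_tokens is coerced to None so the model's own maximum applies. Expressed as a standalone sketch (the endpoints inline the check; the helper name is illustrative):

    def normalize_max_tokens(max_tokens):
        # 0 or negative means "no explicit limit", mirroring the inline guard above
        if isinstance(max_tokens, int) and max_tokens <= 0:
            return None
        return max_tokens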

View File

@ -19,6 +19,7 @@ from server.callback_handler.conversation_callback_handler import ConversationCa
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
conversation_id: str = Body("", description="对话框ID"),
history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
history: Union[int, List[History]] = Body([],
description="历史对话,设为一个整数可以从数据库中读取历史消息",
examples=[[
@ -34,18 +35,20 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
async def chat_iterator() -> AsyncIterable[str]:
nonlocal history
nonlocal history, max_tokens
callback = AsyncIteratorCallbackHandler()
callbacks = [callback]
memory = None
message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
# 负责保存llm response到message db
message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
chat_type="llm_chat",
query=query)
callbacks.append(conversation_callback)
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
model = get_ChatOpenAI(
model_name=model_name,
@ -54,18 +57,24 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
callbacks=callbacks,
)
if not conversation_id:
if history: # 优先使用前端传入的历史消息
history = [History.from_data(h) for h in history]
prompt_template = get_prompt_template("llm_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
else:
elif conversation_id and history_len > 0: # 前端要求从数据库取历史消息
# 使用memory 时必须 prompt 必须含有memory.memory_key 对应的变量
prompt = get_prompt_template("llm_chat", "with_history")
chat_prompt = PromptTemplate.from_template(prompt)
# 根据conversation_id 获取message 列表进而拼凑 memory
memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=model)
memory = ConversationBufferDBMemory(conversation_id=conversation_id,
llm=model,
message_limit=history_len)
else:
prompt_template = get_prompt_template("llm_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages([input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)

View File

@ -6,7 +6,8 @@ from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, Optional
import asyncio
from langchain.prompts.chat import PromptTemplate
from langchain.prompts import PromptTemplate
from server.utils import get_prompt_template
@ -27,7 +28,11 @@ async def completion(query: str = Body(..., description="用户输入", examples
prompt_name: str = prompt_name,
echo: bool = echo,
) -> AsyncIterable[str]:
nonlocal max_tokens
callback = AsyncIteratorCallbackHandler()
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
model = get_OpenAI(
model_name=model_name,
temperature=temperature,

View File

@ -113,7 +113,11 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
history = [History.from_data(h) for h in history]
async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
nonlocal max_tokens
callback = AsyncIteratorCallbackHandler()
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
@ -121,7 +125,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
callbacks=[callback],
)
embed_func = EmbeddingsFunAdapter()
embeddings = embed_func.embed_query(query)
embeddings = await embed_func.aembed_query(query)
with memo_faiss_pool.acquire(knowledge_id) as vs:
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
docs = [x[0] for x in docs]

View File

@ -1,5 +1,6 @@
from fastapi import Body, Request
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
@ -10,55 +11,74 @@ import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.knowledge_base.utils import get_doc_path
import json
from pathlib import Path
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右", ge=0, le=2),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
request: Request = None,
):
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(
SCORE_THRESHOLD,
description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右",
ge=0,
le=2
),
history: List[History] = Body(
[],
description="历史对话",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(
None,
description="限制LLM生成Token数量默认None代表模型最大值"
),
prompt_name: str = Body(
"default",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
),
request: Request = None,
):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
history = [History.from_data(h) for h in history]
async def knowledge_base_chat_iterator(query: str,
top_k: int,
history: Optional[List[History]],
model_name: str = LLM_MODELS[0],
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
async def knowledge_base_chat_iterator(
query: str,
top_k: int,
history: Optional[List[History]],
model_name: str = LLM_MODELS[0],
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
nonlocal max_tokens
callback = AsyncIteratorCallbackHandler()
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=[callback],
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
docs = await run_in_threadpool(search_docs,
query=query,
knowledge_base_name=knowledge_base_name,
top_k=top_k,
score_threshold=score_threshold)
context = "\n".join([doc.page_content for doc in docs])
if len(docs) == 0: ## 如果没有找到相关文档使用Empty模板
prompt_template = get_prompt_template("knowledge_base_chat", "Empty")
if len(docs) == 0: # 如果没有找到相关文档使用empty模板
prompt_template = get_prompt_template("knowledge_base_chat", "empty")
else:
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
@ -76,14 +96,14 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
source_documents = []
for inum, doc in enumerate(docs):
filename = doc.metadata.get("source")
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
base_url = request.base_url
url = f"{base_url}knowledge_base/download_doc?" + parameters
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
source_documents.append(text)
if len(source_documents) == 0: # 没有找到相关文档
source_documents.append(f"""<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>""")
if len(source_documents) == 0: # 没有找到相关文档
source_documents.append(f"<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>")
if stream:
async for token in callback.aiter():
@ -104,4 +124,4 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
history=history,
model_name=model_name,
prompt_name=prompt_name),
media_type="text/event-stream")
media_type="text/event-stream")

View File

@ -147,7 +147,11 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
model_name: str = LLM_MODELS[0],
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
nonlocal max_tokens
callback = AsyncIteratorCallbackHandler()
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,

View File

@ -0,0 +1,28 @@
from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func
from server.db.base import Base
class SummaryChunkModel(Base):
"""
chunk summary模型用于存储file_doc中每个doc_id的chunk 片段
数据来源:
用户输入: 用户上传文件可填写文件的描述生成的file_doc中的doc_id存入summary_chunk中
程序自动切分 对file_doc表meta_data字段信息中存储的页码信息按每页的页码切分自定义prompt生成总结文本将对应页码关联的doc_id存入summary_chunk中
后续任务:
矢量库构建: 对数据库表summary_chunk中summary_context创建索引构建矢量库meta_data为矢量库的元数据doc_ids
语义关联 通过用户输入的描述自动切分的总结文本计算
语义相似度
"""
__tablename__ = 'summary_chunk'
id = Column(Integer, primary_key=True, autoincrement=True, comment='ID')
kb_name = Column(String(50), comment='知识库名称')
summary_context = Column(String(255), comment='总结文本')
summary_id = Column(String(255), comment='总结矢量id')
doc_ids = Column(String(1024), comment="向量库id关联列表")
meta_data = Column(JSON, default={})
def __repr__(self):
return (f"<SummaryChunk(id='{self.id}', kb_name='{self.kb_name}', summary_context='{self.summary_context}',"
f" doc_ids='{self.doc_ids}', metadata='{self.metadata}')>")

View File

@ -0,0 +1,66 @@
from server.db.models.knowledge_metadata_model import SummaryChunkModel
from server.db.session import with_session
from typing import List, Dict
@with_session
def list_summary_from_db(session,
kb_name: str,
metadata: Dict = {},
) -> List[Dict]:
'''
列出某知识库chunk summary
返回形式[{"id": str, "summary_context": str, "doc_ids": str}, ...]
'''
docs = session.query(SummaryChunkModel).filter_by(kb_name=kb_name)
for k, v in metadata.items():
docs = docs.filter(SummaryChunkModel.meta_data[k].as_string() == str(v))
return [{"id": x.id,
"summary_context": x.summary_context,
"summary_id": x.summary_id,
"doc_ids": x.doc_ids,
"metadata": x.metadata} for x in docs.all()]
@with_session
def delete_summary_from_db(session,
kb_name: str
) -> List[Dict]:
'''
删除知识库chunk summary并返回被删除的chunk summary
返回形式[{"id": str, "summary_context": str, "doc_ids": str}, ...]
'''
docs = list_summary_from_db(kb_name=kb_name)
query = session.query(SummaryChunkModel).filter_by(kb_name=kb_name)
query.delete()
session.commit()
return docs
@with_session
def add_summary_to_db(session,
kb_name: str,
summary_infos: List[Dict]):
'''
将总结信息添加到数据库
summary_infos形式[{"summary_context": str, "doc_ids": str}, ...]
'''
for summary in summary_infos:
obj = SummaryChunkModel(
kb_name=kb_name,
summary_context=summary["summary_context"],
summary_id=summary["summary_id"],
doc_ids=summary["doc_ids"],
meta_data=summary["metadata"],
)
session.add(obj)
session.commit()
return True
@with_session
def count_summary_from_db(session, kb_name: str) -> int:
return session.query(SummaryChunkModel).filter_by(kb_name=kb_name).count()

View File

@ -3,6 +3,7 @@ from configs import EMBEDDING_MODEL, logger
from server.model_workers.base import ApiEmbeddingsParams
from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models
from fastapi import Body
from fastapi.concurrency import run_in_threadpool
from typing import Dict, List
@ -39,6 +40,32 @@ def embed_texts(
logger.error(e)
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
async def aembed_texts(
texts: List[str],
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> BaseResponse:
'''
对文本进行向量化返回数据格式BaseResponse(data=List[List[float]])
'''
try:
if embed_model in list_embed_models(): # 使用本地Embeddings模型
from server.utils import load_local_embeddings
embeddings = load_local_embeddings(model=embed_model)
return BaseResponse(data=await embeddings.aembed_documents(texts))
if embed_model in list_online_embed_models(): # 使用在线API
return await run_in_threadpool(embed_texts,
texts=texts,
embed_model=embed_model,
to_query=to_query)
except Exception as e:
logger.error(e)
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
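A minimal usage sketch for the new async helper, assuming it is awaited from async code such as the chat endpoints (the sample texts are arbitrary):

    vectors = (await aembed_texts(
        ["你好", "hello world"],       # arbitrary sample texts
        embed_model=EMBEDDING_MODEL,
        to_query=False,
    )).data                            # List[List[float]], one embedding per input text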
def embed_texts_endpoint(
texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型除了本地部署的Embedding模型也支持在线API({online_embed_models})提供的嵌入服务。"),

View File

@ -45,7 +45,7 @@ class _FaissPool(CachePool):
# create an empty vector store
embeddings = EmbeddingsFunAdapter(embed_model)
doc = Document(page_content="init", metadata={})
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
ids = list(vector_store.docstore._dict.keys())
vector_store.delete(ids)
return vector_store
@ -82,7 +82,7 @@ class KBFaissPool(_FaissPool):
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
elif create:
# create an empty vector store
if not os.path.exists(vs_path):

View File

@ -26,8 +26,8 @@ from server.knowledge_base.utils import (
from typing import List, Union, Dict, Optional
from server.embeddings_api import embed_texts
from server.embeddings_api import embed_documents
from server.embeddings_api import embed_texts, aembed_texts, embed_documents
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
def normalize(embeddings: List[List[float]]) -> np.ndarray:
@ -183,12 +183,22 @@ class KBService(ABC):
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
return []
def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]:
def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[DocumentWithVSId]:
'''
通过file_name或metadata检索Document
'''
doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
docs = self.get_doc_by_ids([x["id"] for x in doc_infos])
docs = []
for x in doc_infos:
doc_info_s = self.get_doc_by_ids([x["id"]])
if doc_info_s is not None and doc_info_s != []:
# 处理非空的情况
doc_with_id = DocumentWithVSId(**doc_info_s[0].dict(), id=x["id"])
docs.append(doc_with_id)
else:
# 处理空的情况
# 可以选择跳过当前循环迭代或执行其他操作
pass
return docs
@abstractmethod
@ -394,12 +404,16 @@ class EmbeddingsFunAdapter(Embeddings):
normalized_query_embed = normalize(query_embed_2d)
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
# TODO: 暂不支持异步
# async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
# return normalize(await self.embeddings.aembed_documents(texts))
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
embeddings = (await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)).data
return normalize(embeddings).tolist()
# async def aembed_query(self, text: str) -> List[float]:
# return normalize(await self.embeddings.aembed_query(text))
async def aembed_query(self, text: str) -> List[float]:
embeddings = (await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)).data
query_embed = embeddings[0]
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
normalized_query_embed = normalize(query_embed_2d)
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
def score_threshold_process(score_threshold, k, docs):

View File

@ -0,0 +1,79 @@
from typing import List
from configs import (
EMBEDDING_MODEL,
KB_ROOT_PATH)
from abc import ABC, abstractmethod
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
import os
import shutil
from server.db.repository.knowledge_metadata_repository import add_summary_to_db, delete_summary_from_db
from langchain.docstore.document import Document
# TODO 暂不考虑文件更新,需要重新删除相关文档,再重新添加
class KBSummaryService(ABC):
kb_name: str
embed_model: str
vs_path: str
kb_path: str
def __init__(self,
knowledge_base_name: str,
embed_model: str = EMBEDDING_MODEL
):
self.kb_name = knowledge_base_name
self.embed_model = embed_model
self.kb_path = self.get_kb_path()
self.vs_path = self.get_vs_path()
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
def get_vs_path(self):
return os.path.join(self.get_kb_path(), "summary_vector_store")
def get_kb_path(self):
return os.path.join(KB_ROOT_PATH, self.kb_name)
def load_vector_store(self) -> ThreadSafeFaiss:
return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
vector_name="summary_vector_store",
embed_model=self.embed_model,
create=True)
def add_kb_summary(self, summary_combine_docs: List[Document]):
with self.load_vector_store().acquire() as vs:
ids = vs.add_documents(documents=summary_combine_docs)
vs.save_local(self.vs_path)
summary_infos = [{"summary_context": doc.page_content,
"summary_id": id,
"doc_ids": doc.metadata.get('doc_ids'),
"metadata": doc.metadata} for id, doc in zip(ids, summary_combine_docs)]
status = add_summary_to_db(kb_name=self.kb_name, summary_infos=summary_infos)
return status
def create_kb_summary(self):
"""
创建知识库chunk summary
:return:
"""
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
def drop_kb_summary(self):
"""
删除知识库chunk summary
:param kb_name:
:return:
"""
with kb_faiss_pool.atomic:
kb_faiss_pool.pop(self.kb_name)
shutil.rmtree(self.vs_path)
delete_summary_from_db(kb_name=self.kb_name)

View File

@ -0,0 +1,247 @@
from typing import List, Optional
from langchain.schema.language_model import BaseLanguageModel
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from configs import (logger)
from langchain.chains import StuffDocumentsChain, LLMChain
from langchain.prompts import PromptTemplate
from langchain.docstore.document import Document
from langchain.output_parsers.regex import RegexParser
from langchain.chains.combine_documents.map_reduce import ReduceDocumentsChain, MapReduceDocumentsChain
import sys
import asyncio
class SummaryAdapter:
_OVERLAP_SIZE: int
token_max: int
_separator: str = "\n\n"
chain: MapReduceDocumentsChain
def __init__(self, overlap_size: int, token_max: int,
chain: MapReduceDocumentsChain):
self._OVERLAP_SIZE = overlap_size
self.chain = chain
self.token_max = token_max
@classmethod
def form_summary(cls,
llm: BaseLanguageModel,
reduce_llm: BaseLanguageModel,
overlap_size: int,
token_max: int = 1300):
"""
获取实例
:param reduce_llm: 用于合并摘要的llm
:param llm: 用于生成摘要的llm
:param overlap_size: 重叠部分大小
:param token_max: 最大的chunk数量每个chunk长度小于token_max长度第一次生成摘要时大于token_max长度的摘要会报错
:return:
"""
# This controls how each document will be formatted. Specifically,
document_prompt = PromptTemplate(
input_variables=["page_content"],
template="{page_content}"
)
# The prompt here should take as an input variable the
# `document_variable_name`
prompt_template = (
"根据文本执行任务。以下任务信息"
"{task_briefing}"
"文本内容如下: "
"\r\n"
"{context}"
)
prompt = PromptTemplate(
template=prompt_template,
input_variables=["task_briefing", "context"]
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
# We now define how to combine these summaries
reduce_prompt = PromptTemplate.from_template(
"Combine these summaries: {context}"
)
reduce_llm_chain = LLMChain(llm=reduce_llm, prompt=reduce_prompt)
document_variable_name = "context"
combine_documents_chain = StuffDocumentsChain(
llm_chain=reduce_llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name
)
reduce_documents_chain = ReduceDocumentsChain(
token_max=token_max,
combine_documents_chain=combine_documents_chain,
)
chain = MapReduceDocumentsChain(
llm_chain=llm_chain,
document_variable_name=document_variable_name,
reduce_documents_chain=reduce_documents_chain,
# 返回中间步骤
return_intermediate_steps=True
)
return cls(overlap_size=overlap_size,
chain=chain,
token_max=token_max)
def summarize(self,
file_description: str,
docs: List[DocumentWithVSId] = []
) -> List[Document]:
if sys.version_info < (3, 10):
loop = asyncio.get_event_loop()
else:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 同步调用协程代码
return loop.run_until_complete(self.asummarize(file_description=file_description,
docs=docs))
async def asummarize(self,
file_description: str,
docs: List[DocumentWithVSId] = []) -> List[Document]:
logger.info("start summary")
# TODO 暂不处理文档中涉及语义重复、上下文缺失、document was longer than the context length 的问题
# merge_docs = self._drop_overlap(docs)
# # 将merge_docs中的句子合并成一个文档
# text = self._join_docs(merge_docs)
# 根据段落于句子的分隔符将文档分成chunk每个chunk长度小于token_max长度
"""
这个过程分成两个部分
1. 对每个文档进行处理得到每个文档的摘要
map_results = self.llm_chain.apply(
# FYI - this is parallelized and so it is fast.
[{self.document_variable_name: d.page_content, **kwargs} for d in docs],
callbacks=callbacks,
)
2. 对每个文档的摘要进行合并得到最终的摘要return_intermediate_steps=True返回中间步骤
result, extra_return_dict = self.reduce_documents_chain.combine_docs(
result_docs, token_max=token_max, callbacks=callbacks, **kwargs
)
"""
summary_combine, summary_intermediate_steps = self.chain.combine_docs(docs=docs,
task_briefing="描述不同方法之间的接近度和相似性,"
"以帮助读者理解它们之间的关系。")
print(summary_combine)
print(summary_intermediate_steps)
# if len(summary_combine) == 0:
# # 为空重新生成,数量减半
# result_docs = [
# Document(page_content=question_result_key, metadata=docs[i].metadata)
# # This uses metadata from the docs, and the textual results from `results`
# for i, question_result_key in enumerate(
# summary_intermediate_steps["intermediate_steps"][
# :len(summary_intermediate_steps["intermediate_steps"]) // 2
# ])
# ]
# summary_combine, summary_intermediate_steps = self.chain.reduce_documents_chain.combine_docs(
# result_docs, token_max=self.token_max
# )
logger.info("end summary")
doc_ids = ",".join([doc.id for doc in docs])
_metadata = {
"file_description": file_description,
"summary_intermediate_steps": summary_intermediate_steps,
"doc_ids": doc_ids
}
summary_combine_doc = Document(page_content=summary_combine, metadata=_metadata)
return [summary_combine_doc]
def _drop_overlap(self, docs: List[DocumentWithVSId]) -> List[str]:
"""
# 将文档中page_content句子叠加的部分去掉
:param docs:
:param separator:
:return:
"""
merge_docs = []
pre_doc = None
for doc in docs:
# 第一个文档直接添加
if len(merge_docs) == 0:
pre_doc = doc.page_content
merge_docs.append(doc.page_content)
continue
# 列表中上一个结尾与下一个开头重叠的部分,删除下一个开头重叠的部分
# 迭代递减pre_doc的长度每次迭代删除前面的字符
# 查询重叠部分直到pre_doc的长度小于 self._OVERLAP_SIZE // 2 - 2 * len(separator)
for i in range(len(pre_doc), self._OVERLAP_SIZE // 2 - 2 * len(self._separator), -1):
# 每次迭代删除前面的字符
pre_doc = pre_doc[1:]
if doc.page_content[:len(pre_doc)] == pre_doc:
# 删除下一个开头重叠的部分
merge_docs.append(doc.page_content[len(pre_doc):])
break
pre_doc = doc.page_content
return merge_docs
def _join_docs(self, docs: List[str]) -> Optional[str]:
text = self._separator.join(docs)
text = text.strip()
if text == "":
return None
else:
return text
if __name__ == '__main__':
docs = [
'梦者有特别的作用,也就是说梦是在预卜未来。因此,梦内容的',
'梦内容的多彩多姿以及对梦者本身所遗留的特殊印象,使他们很难想象',
'使他们很难想象出一套系统划一的观念,而需要以其个别的价值与可靠性作各',
'值与可靠性作各种不同的分化与聚合。因此,古代哲学家们对梦的评价也就完全'
]
_OVERLAP_SIZE = 1
separator: str = "\n\n"
merge_docs = []
# 将文档中page_content句子叠加的部分去掉
# 列表中上一个结尾与下一个开头重叠的部分,删除下一个开头重叠的部分
pre_doc = None
for doc in docs:
# 第一个文档直接添加
if len(merge_docs) == 0:
pre_doc = doc
merge_docs.append(doc)
continue
# 列表中上一个结尾与下一个开头重叠的部分,删除下一个开头重叠的部分
# 迭代递减pre_doc的长度每次迭代删除前面的字符
# 查询重叠部分直到pre_doc的长度小于 _OVERLAP_SIZE - 2 * len(separator)
for i in range(len(pre_doc), _OVERLAP_SIZE - 2 * len(separator), -1):
# 每次迭代删除前面的字符
pre_doc = pre_doc[1:]
if doc[:len(pre_doc)] == pre_doc:
# 删除下一个开头重叠的部分
page_content = doc[len(pre_doc):]
merge_docs.append(page_content)
pre_doc = doc
break
# 将merge_docs中的句子合并成一个文档
text = separator.join(merge_docs)
text = text.strip()
print(text)

View File

@ -0,0 +1,220 @@
from fastapi import Body
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
OVERLAP_SIZE,
logger, log_verbose, )
from server.knowledge_base.utils import (list_files_from_folder)
from fastapi.responses import StreamingResponse
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from typing import List, Optional
from server.knowledge_base.kb_summary.base import KBSummaryService
from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
from server.utils import wrap_done, get_ChatOpenAI, BaseResponse
from configs import LLM_MODELS, TEMPERATURE
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
def recreate_summary_vector_store(
knowledge_base_name: str = Body(..., examples=["samples"]),
allow_empty_kb: bool = Body(True),
vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL),
file_description: str = Body(''),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
):
"""
重建单个知识库文件摘要
:param max_tokens:
:param model_name:
:param temperature:
:param file_description:
:param knowledge_base_name:
:param allow_empty_kb:
:param vs_type:
:param embed_model:
:return:
"""
def output():
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists() and not allow_empty_kb:
yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}
else:
# 重新创建知识库
kb_summary = KBSummaryService(knowledge_base_name, embed_model)
kb_summary.drop_kb_summary()
kb_summary.create_kb_summary()
llm = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
)
reduce_llm = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
)
# 文本摘要适配器
summary = SummaryAdapter.form_summary(llm=llm,
reduce_llm=reduce_llm,
overlap_size=OVERLAP_SIZE)
files = list_files_from_folder(knowledge_base_name)
i = 0
for i, file_name in enumerate(files):
doc_infos = kb.list_docs(file_name=file_name)
docs = summary.summarize(file_description=file_description,
docs=doc_infos)
status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
if status_kb_summary:
logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成")
yield json.dumps({
"code": 200,
"msg": f"({i + 1} / {len(files)}): {file_name}",
"total": len(files),
"finished": i + 1,
"doc": file_name,
}, ensure_ascii=False)
else:
msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
logger.error(msg)
yield json.dumps({
"code": 500,
"msg": msg,
})
i += 1
return StreamingResponse(output(), media_type="text/event-stream")
def summary_file_to_vector_store(
knowledge_base_name: str = Body(..., examples=["samples"]),
file_name: str = Body(..., examples=["test.pdf"]),
allow_empty_kb: bool = Body(True),
vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL),
file_description: str = Body(''),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
):
"""
单个知识库根据文件名称摘要
:param model_name:
:param max_tokens:
:param temperature:
:param file_description:
:param file_name:
:param knowledge_base_name:
:param allow_empty_kb:
:param vs_type:
:param embed_model:
:return:
"""
def output():
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists() and not allow_empty_kb:
yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}
else:
# 重新创建知识库
kb_summary = KBSummaryService(knowledge_base_name, embed_model)
kb_summary.create_kb_summary()
llm = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
)
reduce_llm = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
)
# 文本摘要适配器
summary = SummaryAdapter.form_summary(llm=llm,
reduce_llm=reduce_llm,
overlap_size=OVERLAP_SIZE)
doc_infos = kb.list_docs(file_name=file_name)
docs = summary.summarize(file_description=file_description,
docs=doc_infos)
status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
if status_kb_summary:
logger.info(f" {file_name} 总结完成")
yield json.dumps({
"code": 200,
"msg": f"{file_name} 总结完成",
"doc": file_name,
}, ensure_ascii=False)
else:
msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
logger.error(msg)
yield json.dumps({
"code": 500,
"msg": msg,
})
return StreamingResponse(output(), media_type="text/event-stream")
def summary_doc_ids_to_vector_store(
knowledge_base_name: str = Body(..., examples=["samples"]),
doc_ids: List = Body([], examples=[["uuid"]]),
vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL),
file_description: str = Body(''),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
) -> BaseResponse:
"""
单个知识库根据doc_ids摘要
:param knowledge_base_name:
:param doc_ids:
:param model_name:
:param max_tokens:
:param temperature:
:param file_description:
:param vs_type:
:param embed_model:
:return:
"""
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists():
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data={})
else:
llm = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
)
reduce_llm = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
)
# 文本摘要适配器
summary = SummaryAdapter.form_summary(llm=llm,
reduce_llm=reduce_llm,
overlap_size=OVERLAP_SIZE)
doc_infos = kb.get_doc_by_ids(ids=doc_ids)
# doc_infos转换成DocumentWithVSId包装的对象
doc_info_with_ids = [DocumentWithVSId(**doc.dict(), id=with_id) for with_id, doc in zip(doc_ids, doc_infos)]
docs = summary.summarize(file_description=file_description,
docs=doc_info_with_ids)
# 将docs转换成dict
resp_summarize = [{**doc.dict()} for doc in docs]
return BaseResponse(code=200, msg="总结完成", data={"summarize": resp_summarize})
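These endpoints are mounted in server/api.py under /knowledge_base/kb_summary_api/ (see mount_filename_summary_routes above). A client-side sketch with httpx; the base URL assumes the API server's default address and should be adjusted to your deployment:

    import httpx

    api_base = "http://127.0.0.1:7861"   # assumed default; change if your API runs elsewhere
    payload = {"knowledge_base_name": "samples", "file_name": "test.pdf"}
    url = f"{api_base}/knowledge_base/kb_summary_api/summary_file_to_vector_store"
    with httpx.stream("POST", url, json=payload, timeout=None) as r:
        for line in r.iter_lines():      # the endpoint streams JSON status messages
            if line:
                print(line)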

View File

@ -12,6 +12,8 @@ from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.models.conversation_model import ConversationModel
from server.db.models.message_model import MessageModel
from server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported
from server.db.repository.knowledge_metadata_repository import add_summary_to_db
from server.db.base import Base, engine
from server.db.session import session_scope
import os

View File

@ -0,0 +1,10 @@
from langchain.docstore.document import Document
class DocumentWithVSId(Document):
"""
矢量化后的文档
"""
id: str = None

View File

@ -71,7 +71,7 @@ def list_files_from_folder(kb_name: str):
for target_entry in target_it:
process_entry(target_entry)
elif entry.is_file():
result.append(entry.path)
result.append(os.path.relpath(entry.path, doc_path))
elif entry.is_dir():
with os.scandir(entry.path) as it:
for sub_entry in it:
@ -85,6 +85,7 @@ def list_files_from_folder(kb_name: str):
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
"MHTMLLoader": ['.mhtml'],
"UnstructuredMarkdownLoader": ['.md'],
"JSONLoader": [".json"],
"JSONLinesLoader": [".jsonl"],
@ -106,6 +107,7 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
"UnstructuredWordDocumentLoader": ['.docx', 'doc'],
"UnstructuredXMLLoader": ['.xml'],
"UnstructuredPowerPointLoader": ['.ppt', '.pptx'],
"EverNoteLoader": ['.enex'],
"UnstructuredFileLoader": ['.txt'],
}
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
@ -311,6 +313,9 @@ class KnowledgeFile:
else:
docs = text_splitter.split_documents(docs)
if not docs:
return []
print(f"文档切分示例:{docs[0]}")
if zh_title_enhance:
docs = func_zh_title_enhance(docs)

View File

@ -1,4 +1,5 @@
import sys
import os
from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
@ -19,16 +20,16 @@ class AzureWorker(ApiModelWorker):
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 8000) #TODO 16K模型需要改成16384
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
data = dict(
messages=params.messages,
temperature=params.temperature,
max_tokens=params.max_tokens,
max_tokens=params.max_tokens if params.max_tokens else None,
stream=True,
)
url = ("https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}"
@ -47,6 +48,7 @@ class AzureWorker(ApiModelWorker):
with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=data) as response:
print(data)
for line in response.iter_lines():
if not line.strip() or "[DONE]" in line:
continue
@@ -60,6 +62,7 @@ class AzureWorker(ApiModelWorker):
"error_code": 0,
"text": text
}
print(text)
else:
self.logger.error(f"请求 Azure API 时发生错误:{resp}")
@@ -91,4 +94,4 @@ if __name__ == "__main__":
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21008)
uvicorn.run(app, port=21008)
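# 流式消费示意(仅作参考,并非本次提交中的代码;url、headers、payload 均为占位参数):
# 上面的 do_chat 正是按这种 SSE 方式逐行读取增量结果,跳过空行与 [DONE] 标记。
import json
from server.utils import get_httpx_client

def _stream_sse(url: str, headers: dict, payload: dict):
    with get_httpx_client() as client:
        with client.stream("POST", url, headers=headers, json=payload) as response:
            for line in response.iter_lines():
                if not line.strip() or "[DONE]" in line:
                    continue
                if line.startswith("data: "):   # OpenAI 兼容接口的 SSE 行以 "data: " 开头
                    line = line[len("data: "):]
                yield json.loads(line)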

View File

@@ -119,7 +119,9 @@ class MiniMaxWorker(ApiModelWorker):
with get_httpx_client() as client:
result = []
i = 0
for texts in params.texts[i:i+10]:
batch_size = 10
while i < len(params.texts):
texts = params.texts[i:i+batch_size]
data["texts"] = texts
r = client.post(url, headers=headers, json=data).json()
if embeddings := r.get("vectors"):
@ -137,7 +139,7 @@ class MiniMaxWorker(ApiModelWorker):
}
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
return data
i += 10
i += batch_size
return {"code": 200, "data": embeddings}
def get_embeddings(self, params):

View File

@@ -188,9 +188,11 @@ class QianFanWorker(ApiModelWorker):
with get_httpx_client() as client:
result = []
i = 0
for texts in params.texts[i:i+10]:
batch_size = 10
while i < len(params.texts):
texts = params.texts[i:i+batch_size]
resp = client.post(url, json={"input": texts}).json()
if "error_cdoe" in resp:
if "error_code" in resp:
data = {
"code": resp["error_code"],
"msg": resp["error_msg"],
@ -206,7 +208,7 @@ class QianFanWorker(ApiModelWorker):
else:
embeddings = [x["embedding"] for x in resp.get("data", [])]
result += embeddings
i += 10
i += batch_size
return {"code": 200, "data": result}
# TODO: qianfan支持续写模型
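# 分批向量化示意(仅作参考,并非本次提交中的代码;embed_batch 为假设的回调参数):
# MiniMax 与千帆的 do_embeddings 均改为按 batch_size=10 分批请求并累积结果,等价于下面的通用写法。
from typing import Callable, Dict, List

def _batched_embeddings(texts: List[str],
                        embed_batch: Callable[[List[str]], List[List[float]]],
                        batch_size: int = 10) -> Dict:
    result: List[List[float]] = []
    i = 0
    while i < len(texts):
        result += embed_batch(texts[i:i + batch_size])  # 每批最多 batch_size 条文本
        i += batch_size
    return {"code": 200, "data": result}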

View File

@@ -40,11 +40,10 @@ def get_ChatOpenAI(
verbose: bool = True,
**kwargs: Any,
) -> ChatOpenAI:
## 以下模型是Langchain原生支持的模型这些模型不会走Fschat封装
config_models = list_config_llm_models()
## 非Langchain原生支持的模型走Fschat封装
config = get_model_worker_config(model_name)
if model_name == "openai-api":
model_name = config.get("model_name")
model = ChatOpenAI(
streaming=streaming,
verbose=verbose,
@ -57,10 +56,8 @@ def get_ChatOpenAI(
openai_proxy=config.get("openai_proxy"),
**kwargs
)
return model
def get_OpenAI(
model_name: str,
temperature: float,
@ -71,67 +68,22 @@ def get_OpenAI(
verbose: bool = True,
**kwargs: Any,
) -> OpenAI:
## 以下模型是Langchain原生支持的模型这些模型不会走Fschat封装
config_models = list_config_llm_models()
if model_name in config_models.get("langchain", {}):
config = config_models["langchain"][model_name]
if model_name == "Azure-OpenAI":
model = AzureOpenAI(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
deployment_name=config.get("deployment_name"),
model_version=config.get("model_version"),
openai_api_type=config.get("openai_api_type"),
openai_api_base=config.get("api_base_url"),
openai_api_version=config.get("api_version"),
openai_api_key=config.get("api_key"),
openai_proxy=config.get("openai_proxy"),
temperature=temperature,
max_tokens=max_tokens,
echo=echo,
)
elif model_name == "OpenAI":
model = OpenAI(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
model_name=config.get("model_name"),
openai_api_base=config.get("api_base_url"),
openai_api_key=config.get("api_key"),
openai_proxy=config.get("openai_proxy"),
temperature=temperature,
max_tokens=max_tokens,
echo=echo,
)
elif model_name == "Anthropic":
model = Anthropic(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
model_name=config.get("model_name"),
anthropic_api_key=config.get("api_key"),
echo=echo,
)
## TODO 支持其他的Langchain原生支持的模型
else:
## 非Langchain原生支持的模型走Fschat封装
config = get_model_worker_config(model_name)
model = OpenAI(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
openai_api_key=config.get("api_key", "EMPTY"),
openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
openai_proxy=config.get("openai_proxy"),
echo=echo,
**kwargs
)
config = get_model_worker_config(model_name)
if model_name == "openai-api":
model_name = config.get("model_name")
model = OpenAI(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
openai_api_key=config.get("api_key", "EMPTY"),
openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
openai_proxy=config.get("openai_proxy"),
echo=echo,
**kwargs
)
return model
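# 用法示意(仅作参考,并非本次提交中的代码;需先通过 startup.py 启动相应的 model worker):
# 重构后本地模型与在线模型统一经 fschat 的 OpenAI 兼容接口创建,
# 仅当 model_name 为 "openai-api" 时才替换为配置中的真实模型名。
from langchain.schema import HumanMessage
from configs import LLM_MODELS, TEMPERATURE

_chat = get_ChatOpenAI(model_name=LLM_MODELS[0], temperature=TEMPERATURE, max_tokens=1024)
print(_chat([HumanMessage(content="你好")]).content)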
@@ -630,7 +582,7 @@ def get_httpx_client(
for host in os.environ.get("no_proxy", "").split(","):
if host := host.strip():
# default_proxies.update({host: None}) # Origin code
            default_proxies.update({'all://' + host: None}) # PR 1838 fix: without the 'all://' prefix, httpx will raise an error
            default_proxies.update({'all://' + host: None})  # PR 1838 fix: without the 'all://' prefix, httpx will raise an error
# merge default proxies with user provided proxies
if isinstance(proxies, str):
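# 代理挂载示意(仅作参考,并非本次提交中的代码):httpx 的 proxies 键必须是
# "all://"、"http://" 这类 URL pattern,裸主机名会触发异常;因此 no_proxy 中的主机
# 以 "all://<host>" 为键、值为 None 的方式挂载,表示访问这些主机时不走代理。
import os
import httpx

_proxies = {"all://": os.environ.get("http_proxy") or None}
for _host in os.environ.get("no_proxy", "").split(","):
    if _host := _host.strip():
        _proxies["all://" + _host] = None
_client = httpx.Client(proxies=_proxies)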
@@ -714,7 +666,7 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]:
from configs.basic_config import BASE_TEMP_DIR
import tempfile
if id is not None: # 如果指定的临时目录已存在,直接返回
if id is not None: # 如果指定的临时目录已存在,直接返回
path = os.path.join(BASE_TEMP_DIR, id)
if os.path.isdir(path):
return path, id

View File

@@ -36,6 +36,7 @@ from server.utils import (fschat_controller_address, fschat_model_worker_address
fschat_openai_api_address, set_httpx_config, get_httpx_client,
get_model_worker_config, get_all_model_worker_configs,
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
from server.knowledge_base.migrate import create_tables
import argparse
from typing import Tuple, List, Dict
from configs import VERSION
@@ -866,6 +867,8 @@ async def start_main_server():
logger.info("Process status: %s", p)
if __name__ == "__main__":
# 确保数据库表被创建
create_tables()
if sys.version_info < (3, 10):
loop = asyncio.get_event_loop()
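# 说明(仅作参考,并非本次提交中的代码):入口处新增的 create_tables 用于确保 ORM 表存在,
# 其行为大致相当于下面的 SQLAlchemy 调用,实际以 server/knowledge_base/migrate.py 中的实现为准:
from server.db.base import Base, engine

def _create_tables_sketch():
    Base.metadata.create_all(bind=engine)  # 按已注册的 Model 建表,已存在的表会被跳过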

View File

@@ -0,0 +1,44 @@
import requests
import json
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.utils import api_address
api_base_url = api_address()
kb = "samples"
file_name = "/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/knowledge_base/samples/content/llm/大模型技术栈-实战与应用.md"
doc_ids = [
"357d580f-fdf7-495c-b58b-595a398284e8",
"c7338773-2e83-4671-b237-1ad20335b0f0",
"6da613d1-327d-466f-8c1a-b32e6f461f47"
]
def test_summary_file_to_vector_store(api="/knowledge_base/kb_summary_api/summary_file_to_vector_store"):
url = api_base_url + api
print("\n文件摘要:")
r = requests.post(url, json={"knowledge_base_name": kb,
"file_name": file_name
}, stream=True)
for chunk in r.iter_content(None):
data = json.loads(chunk)
assert isinstance(data, dict)
assert data["code"] == 200
print(data["msg"])
def test_summary_doc_ids_to_vector_store(api="/knowledge_base/kb_summary_api/summary_doc_ids_to_vector_store"):
url = api_base_url + api
print("\n文件摘要:")
r = requests.post(url, json={"knowledge_base_name": kb,
"doc_ids": doc_ids
}, stream=True)
for chunk in r.iter_content(None):
data = json.loads(chunk)
assert isinstance(data, dict)
assert data["code"] == 200
print(data)

View File

@@ -1,8 +1,11 @@
import streamlit as st
from webui_pages.utils import *
from streamlit_chatbox import *
from streamlit_modal import Modal
from datetime import datetime
import os
import re
import time
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
from server.knowledge_base.utils import LOADER_DICT
@@ -47,6 +50,53 @@ def upload_temp_docs(files, _api: ApiRequest) -> str:
return _api.upload_temp_docs(files).get("data", {}).get("id")
def parse_command(text: str, modal: Modal) -> bool:
'''
    检查用户是否输入了自定义命令,当前支持:
    /new {session_name}。如果未提供名称,默认为“会话X”
    /del {session_name}。如果未提供名称,在会话数量>1的情况下,删除当前会话。
    /clear {session_name}。如果未提供名称,默认清除当前会话
    /help。查看命令帮助
    返回值:输入的是命令返回True,否则返回False
'''
if m := re.match(r"/([^\s]+)\s*(.*)", text):
cmd, name = m.groups()
name = name.strip()
conv_names = chat_box.get_chat_names()
if cmd == "help":
modal.open()
elif cmd == "new":
if not name:
i = 1
while True:
name = f"会话{i}"
if name not in conv_names:
break
i += 1
if name in st.session_state["conversation_ids"]:
st.error(f"该会话名称 “{name}” 已存在")
time.sleep(1)
else:
st.session_state["conversation_ids"][name] = uuid.uuid4().hex
st.session_state["cur_conv_name"] = name
elif cmd == "del":
name = name or st.session_state.get("cur_conv_name")
if len(conv_names) == 1:
st.error("这是最后一个会话,无法删除")
time.sleep(1)
elif not name or name not in st.session_state["conversation_ids"]:
st.error(f"无效的会话名称:“{name}")
time.sleep(1)
else:
st.session_state["conversation_ids"].pop(name, None)
chat_box.del_chat_name(name)
st.session_state["cur_conv_name"] = ""
elif cmd == "clear":
chat_box.reset_history(name=name or None)
return True
return False
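# 行为示意(仅作参考,并非本次提交中的代码):命令解析依赖上面的正则 r"/([^\s]+)\s*(.*)",
# 命中时拆出命令名与参数,未以 "/" 开头的输入不会被当作命令。
import re

for _text in ("/new 工作会话", "/help", "普通聊天内容"):
    _m = re.match(r"/([^\s]+)\s*(.*)", _text)
    print(_text, "->", _m.groups() if _m else "非命令")
# 依次输出 ('new', '工作会话')、('help', '')、非命令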
def dialogue_page(api: ApiRequest, is_lite: bool = False):
st.session_state.setdefault("conversation_ids", {})
st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex)
@@ -60,25 +110,15 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
)
chat_box.init_session()
# 弹出自定义命令帮助信息
modal = Modal("自定义命令", key="cmd_help", max_width="500")
if modal.is_open():
with modal.container():
cmds = [x for x in parse_command.__doc__.split("\n") if x.strip().startswith("/")]
st.write("\n\n".join(cmds))
with st.sidebar:
# 多会话
cols = st.columns([3, 1])
conv_name = cols[0].text_input("会话名称")
with cols[1]:
if st.button("添加"):
if not conv_name or conv_name in st.session_state["conversation_ids"]:
st.error("请指定有效的会话名称")
else:
st.session_state["conversation_ids"][conv_name] = uuid.uuid4().hex
st.session_state["cur_conv_name"] = conv_name
st.session_state["conv_name"] = ""
if st.button("删除"):
if not conv_name or conv_name not in st.session_state["conversation_ids"]:
st.error("请指定有效的会话名称")
else:
st.session_state["conversation_ids"].pop(conv_name, None)
st.session_state["cur_conv_name"] = ""
conv_names = list(st.session_state["conversation_ids"].keys())
index = 0
if st.session_state.get("cur_conv_name") in conv_names:
@@ -236,7 +276,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
# Display chat messages from history on app rerun
chat_box.output_messages()
    chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter "
    chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "
def on_feedback(
feedback,
@@ -256,139 +296,142 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
}
if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
history = get_messages_history(history_len)
chat_box.user_say(prompt)
if dialogue_mode == "LLM 对话":
chat_box.ai_say("正在思考...")
text = ""
message_id = ""
r = api.chat_chat(prompt,
history=history,
conversation_id=conversation_id,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature)
for t in r:
                if error_msg := check_error_msg(t): # check whether error occurred
st.error(error_msg)
break
text += t.get("text", "")
chat_box.update_msg(text)
message_id = t.get("message_id", "")
if parse_command(text=prompt, modal=modal): # 用户输入自定义命令
st.rerun()
else:
history = get_messages_history(history_len)
chat_box.user_say(prompt)
if dialogue_mode == "LLM 对话":
chat_box.ai_say("正在思考...")
text = ""
message_id = ""
r = api.chat_chat(prompt,
history=history,
conversation_id=conversation_id,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature)
for t in r:
                    if error_msg := check_error_msg(t): # check whether error occurred
st.error(error_msg)
break
text += t.get("text", "")
chat_box.update_msg(text)
message_id = t.get("message_id", "")
metadata = {
"message_id": message_id,
}
chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标
chat_box.show_feedback(**feedback_kwargs,
key=message_id,
on_submit=on_feedback,
kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
metadata = {
"message_id": message_id,
}
chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标
chat_box.show_feedback(**feedback_kwargs,
key=message_id,
on_submit=on_feedback,
kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
elif dialogue_mode == "自定义Agent问答":
if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
elif dialogue_mode == "自定义Agent问答":
if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
chat_box.ai_say([
f"正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐请更换支持Agent的模型获得更好的体验</span>\n\n\n",
Markdown("...", in_expander=True, title="思考过程", state="complete"),
])
else:
chat_box.ai_say([
f"正在思考...",
Markdown("...", in_expander=True, title="思考过程", state="complete"),
])
text = ""
ans = ""
for d in api.agent_chat(prompt,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature,
):
try:
d = json.loads(d)
except:
pass
                    if error_msg := check_error_msg(d): # check whether error occurred
st.error(error_msg)
if chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, element_index=1)
if chunk := d.get("final_answer"):
ans += chunk
chat_box.update_msg(ans, element_index=0)
if chunk := d.get("tools"):
text += "\n\n".join(d.get("tools", []))
chat_box.update_msg(text, element_index=1)
chat_box.update_msg(ans, element_index=0, streaming=False)
chat_box.update_msg(text, element_index=1, streaming=False)
elif dialogue_mode == "知识库问答":
chat_box.ai_say([
f"正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐请更换支持Agent的模型获得更好的体验</span>\n\n\n",
Markdown("...", in_expander=True, title="思考过程", state="complete"),
f"正在查询知识库 `{selected_kb}` ...",
Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
])
else:
text = ""
for d in api.knowledge_base_chat(prompt,
knowledge_base_name=selected_kb,
top_k=kb_top_k,
score_threshold=score_threshold,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature):
                    if error_msg := check_error_msg(d): # check whether error occurred
st.error(error_msg)
elif chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, element_index=0)
chat_box.update_msg(text, element_index=0, streaming=False)
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
elif dialogue_mode == "文件对话":
if st.session_state["file_chat_id"] is None:
st.error("请先上传文件再进行对话")
st.stop()
chat_box.ai_say([
f"正在思考...",
Markdown("...", in_expander=True, title="思考过程", state="complete"),
f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
])
text = ""
ans = ""
for d in api.agent_chat(prompt,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature,
):
try:
d = json.loads(d)
except:
pass
                    if error_msg := check_error_msg(d): # check whether error occurred
st.error(error_msg)
if chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, element_index=1)
if chunk := d.get("final_answer"):
ans += chunk
chat_box.update_msg(ans, element_index=0)
if chunk := d.get("tools"):
text += "\n\n".join(d.get("tools", []))
chat_box.update_msg(text, element_index=1)
chat_box.update_msg(ans, element_index=0, streaming=False)
chat_box.update_msg(text, element_index=1, streaming=False)
elif dialogue_mode == "知识库问答":
chat_box.ai_say([
f"正在查询知识库 `{selected_kb}` ...",
Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
])
text = ""
for d in api.knowledge_base_chat(prompt,
knowledge_base_name=selected_kb,
top_k=kb_top_k,
score_threshold=score_threshold,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature):
                    if error_msg := check_error_msg(d): # check whether error occurred
st.error(error_msg)
elif chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, element_index=0)
chat_box.update_msg(text, element_index=0, streaming=False)
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
elif dialogue_mode == "文件对话":
if st.session_state["file_chat_id"] is None:
st.error("请先上传文件再进行对话")
st.stop()
chat_box.ai_say([
f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
])
text = ""
for d in api.file_chat(prompt,
knowledge_id=st.session_state["file_chat_id"],
top_k=kb_top_k,
score_threshold=score_threshold,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature):
                    if error_msg := check_error_msg(d): # check whether error occurred
st.error(error_msg)
elif chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, element_index=0)
chat_box.update_msg(text, element_index=0, streaming=False)
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
elif dialogue_mode == "搜索引擎问答":
chat_box.ai_say([
f"正在执行 `{search_engine}` 搜索...",
Markdown("...", in_expander=True, title="网络搜索结果", state="complete"),
])
text = ""
for d in api.search_engine_chat(prompt,
search_engine_name=search_engine,
top_k=se_top_k,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature,
split_result=se_top_k > 1):
                    if error_msg := check_error_msg(d): # check whether error occurred
st.error(error_msg)
elif chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, element_index=0)
chat_box.update_msg(text, element_index=0, streaming=False)
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
text = ""
for d in api.file_chat(prompt,
knowledge_id=st.session_state["file_chat_id"],
top_k=kb_top_k,
score_threshold=score_threshold,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature):
                    if error_msg := check_error_msg(d): # check whether error occurred
st.error(error_msg)
elif chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, element_index=0)
chat_box.update_msg(text, element_index=0, streaming=False)
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
elif dialogue_mode == "搜索引擎问答":
chat_box.ai_say([
f"正在执行 `{search_engine}` 搜索...",
Markdown("...", in_expander=True, title="网络搜索结果", state="complete"),
])
text = ""
for d in api.search_engine_chat(prompt,
search_engine_name=search_engine,
top_k=se_top_k,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature,
split_result=se_top_k > 1):
                    if error_msg := check_error_msg(d): # check whether error occurred
st.error(error_msg)
elif chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, element_index=0)
chat_box.update_msg(text, element_index=0, streaming=False)
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
if st.session_state.get("need_rerun"):
st.session_state["need_rerun"] = False

View File

@@ -34,14 +34,19 @@ def config_aggrid(
use_checkbox=use_checkbox,
# pre_selected_rows=st.session_state.get("selected_rows", [0]),
)
gb.configure_pagination(
enabled=True,
paginationAutoPageSize=False,
paginationPageSize=10
)
return gb
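# 用法示意(仅作参考,并非本次提交中的代码;DataFrame 的列与取值仅为示例):
# 新增的 configure_pagination 使文件列表按每页 10 行分页,渲染时随 grid options 一起生效。
import pandas as pd
from st_aggrid import AgGrid

_df = pd.DataFrame({"No": [1, 2], "file_name": ["a.md", "b.pdf"]})
AgGrid(_df, gridOptions=config_aggrid(_df).build())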
def file_exists(kb: str, selected_rows: List) -> Tuple[str, str]:
'''
"""
    Check whether a doc file exists in the local knowledge base folder.
    Return the file's name and path if it exists.
'''
"""
if selected_rows:
file_name = selected_rows[0]["file_name"]
file_path = get_file_path(kb, file_name)

View File

@ -85,7 +85,7 @@ class ApiRequest:
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
while retry > 0:
try:
print(kwargs)
# print(kwargs)
if stream:
return self.client.stream("POST", url, data=data, json=json, **kwargs)
else:
@ -134,7 +134,7 @@ class ApiRequest:
if as_json:
try:
data = json.loads(chunk)
pprint(data, depth=1)
# pprint(data, depth=1)
yield data
except Exception as e:
msg = f"接口返回json错误 {chunk}’。错误信息是:{e}"
@ -166,7 +166,7 @@ class ApiRequest:
if as_json:
try:
data = json.loads(chunk)
pprint(data, depth=1)
# pprint(data, depth=1)
yield data
except Exception as e:
msg = f"接口返回json错误 {chunk}’。错误信息是:{e}"
@ -276,8 +276,8 @@ class ApiRequest:
"max_tokens": max_tokens,
}
print(f"received input message:")
pprint(data)
# print(f"received input message:")
# pprint(data)
response = self.post(
"/chat/fastchat",
@ -288,23 +288,25 @@ class ApiRequest:
return self._httpx_stream2generator(response)
def chat_chat(
self,
query: str,
conversation_id: str = None,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
**kwargs,
self,
query: str,
conversation_id: str = None,
history_len: int = -1,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
**kwargs,
):
'''
对应api.py/chat/chat接口 #TODO: 考虑是否返回json
对应api.py/chat/chat接口
'''
data = {
"query": query,
"conversation_id": conversation_id,
"history_len": history_len,
"history": history,
"stream": stream,
"model_name": model,
@ -313,8 +315,8 @@ class ApiRequest:
"prompt_name": prompt_name,
}
print(f"received input message:")
pprint(data)
# print(f"received input message:")
# pprint(data)
response = self.post("/chat/chat", json=data, stream=True, **kwargs)
return self._httpx_stream2generator(response, as_json=True)
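# 用法示意(仅作参考,并非本次提交中的代码;base_url 与参数取值仅为示例):
# 新增的 history_len 会随请求发送给 /chat/chat,-1 表示沿用服务端默认的历史截取长度;
# 返回值是逐块产出 dict 的生成器,其中包含 "text" 与 "message_id" 字段。
_api = ApiRequest(base_url="http://127.0.0.1:7861")
for _chunk in _api.chat_chat("你好", conversation_id=None, history_len=-1, stream=True):
    print(_chunk.get("text", ""), end="", flush=True)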
@ -342,8 +344,8 @@ class ApiRequest:
"prompt_name": prompt_name,
}
print(f"received input message:")
pprint(data)
# print(f"received input message:")
# pprint(data)
response = self.post("/chat/agent_chat", json=data, stream=True)
return self._httpx_stream2generator(response)
@ -377,8 +379,8 @@ class ApiRequest:
"prompt_name": prompt_name,
}
print(f"received input message:")
pprint(data)
# print(f"received input message:")
# pprint(data)
response = self.post(
"/chat/knowledge_base_chat",
@ -452,8 +454,8 @@ class ApiRequest:
"prompt_name": prompt_name,
}
print(f"received input message:")
pprint(data)
# print(f"received input message:")
# pprint(data)
response = self.post(
"/chat/file_chat",
@ -491,8 +493,8 @@ class ApiRequest:
"split_result": split_result,
}
print(f"received input message:")
pprint(data)
# print(f"received input message:")
# pprint(data)
response = self.post(
"/chat/search_engine_chat",