* Fix the bug where max tokens was not set for Azure

* Rewrite the agent

1. Changed the Agent implementation to support multi-parameter tools (see the sketch after this list); for now only ChatGLM3-6B and OpenAI GPT-4 are supported, and the remaining models temporarily lose Agent functionality
2. Removed agent_chat and merged it into llm_chat
3. Rewrote most of the tools to fit the new Agent

* Update the architecture

* Remove web_chat; it is now merged in automatically

* Remove all separate chat modes; everything is now controlled by the Agent

* Update the configuration files

* Update the configuration templates and prompts

* Fix the parameter selection bug
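As context for items 1 and 3 above, here is a rough sketch of the multi-parameter tool style the rewritten Agent uses, modeled on the `calculate` tool and the `StructuredTool.from_function` registration visible in the diff below; the wiring is illustrative rather than the exact project code:

```python
# Illustrative sketch only: mirrors the calculate tool and the tools_factory
# registration shown later in this commit's diff.
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field


class CalculatorInput(BaseModel):
    a: float = Field(description="first number")
    b: float = Field(description="second number")
    operator: str = Field(description="operator to use (e.g., +, -, *, /, ^)")


def calculate(a: float, b: float, operator: str) -> float:
    """Evaluate a simple binary expression from named arguments."""
    if operator == "+":
        return a + b
    if operator == "-":
        return a - b
    if operator == "*":
        return a * b
    if operator == "/":
        return a / b if b != 0 else float("inf")
    if operator == "^":
        return a ** b
    raise ValueError("Unsupported operator")


# The tool declares a typed args_schema, so the Agent must emit a JSON blob
# with named arguments instead of a single free-form string.
calculate_tool = StructuredTool.from_function(
    func=calculate,
    name="calculate",
    description="Useful for when you need to answer questions about simple calculations",
    args_schema=CalculatorInput,
)
```

This structured JSON-blob calling convention is presumably why only models that emit it reliably (ChatGLM3-6B, GPT-4) keep Agent support for now.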
This commit is contained in:
zR 2023-12-05 17:17:53 +08:00 committed by liunux4odoo
parent 808ca227c5
commit 253168a187
52 changed files with 862 additions and 2293 deletions

View File

@ -1,22 +0,0 @@
from server.utils import get_ChatOpenAI
from configs.model_config import LLM_MODELS, TEMPERATURE
from langchain.chains import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
model = get_ChatOpenAI(model_name=LLM_MODELS[0], temperature=TEMPERATURE)
human_prompt = "{input}"
human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)
chat_prompt = ChatPromptTemplate.from_messages(
[("human", "我们来玩成语接龙,我先来,生龙活虎"),
("ai", "虎头虎脑"),
("human", "{input}")])
chain = LLMChain(prompt=chat_prompt, llm=model, verbose=True)
print(chain({"input": "恼羞成怒"}))

View File

@ -1,14 +1,6 @@
import os
# 可以指定一个绝对路径统一存放所有的Embedding和LLM模型。
# 每个模型可以是一个单独的目录,也可以是某个目录下的二级子目录。
# 如果模型目录名称和 MODEL_PATH 中的 key 或 value 相同,程序会自动检测加载,无需修改 MODEL_PATH 中的路径。
MODEL_ROOT_PATH = ""
# 选用的 Embedding 名称
EMBEDDING_MODEL = "bge-large-zh-v1.5"
# Embedding 模型运行设备。设为 "auto" 会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
EMBEDDING_MODEL = "bge-large-zh-v1.5" # bge-large-zh
EMBEDDING_DEVICE = "auto"
# 选用的reranker模型
@ -21,65 +13,170 @@ RERANKER_MAX_LENGTH = 1024
EMBEDDING_KEYWORD_FILE = "keywords.txt"
EMBEDDING_MODEL_OUTPUT_PATH = "output"
# 要运行的 LLM 名称,可以包括本地模型和在线模型。列表中本地模型将在启动项目时全部加载。
# 列表中第一个模型将作为 API 和 WEBUI 的默认模型。
# 在这里我们使用目前主流的两个离线模型其中chatglm3-6b 为默认加载模型。
# 如果你的显存不足,可使用 Qwen-1_8B-Chat, 该模型 FP16 仅需 3.8G显存。
LLM_MODEL_CONFIG = {
# 意图识别不需要输出,模型后台知道就行
"preprocess_model": {
"zhipu-api": {
"temperature": 0.4,
"max_tokens": 2048,
"history_len": 100,
"prompt_name": "default",
"callbacks": False
},
},
"llm_model": {
"chatglm3-6b": {
"temperature": 0.9,
"max_tokens": 4096,
"history_len": 3,
"prompt_name": "default",
"callbacks": True
},
"zhipu-api": {
"temperature": 0.9,
"max_tokens": 4000,
"history_len": 5,
"prompt_name": "default",
"callbacks": True
},
"Qwen-1_8B-Chat": {
"temperature": 0.4,
"max_tokens": 2048,
"history_len": 100,
"prompt_name": "default",
"callbacks": False
},
},
"action_model": {
"chatglm3-6b": {
"temperature": 0.01,
"max_tokens": 4096,
"prompt_name": "ChatGLM3",
"callbacks": True
},
"zhipu-api": {
"temperature": 0.01,
"max_tokens": 2096,
"history_len": 5,
"prompt_name": "ChatGLM3",
"callbacks": True
},
},
"postprocess_model": {
"chatglm3-6b": {
"temperature": 0.01,
"max_tokens": 4096,
"prompt_name": "default",
"callbacks": True
}
},
LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"]
Agent_MODEL = None
}
TOOL_CONFIG = {
"search_local_knowledgebase": {
"use": True,
"top_k": 10,
"score_threshold": 1,
"conclude_prompt": {
"with_result":
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题"'
'不允许在答案中添加编造成分,答案请使用中文。 </指令>\n'
'<已知信息>{{ context }}</已知信息>\n'
'<问题>{{ question }}</问题>\n',
"without_result":
'请你根据我的提问回答我的问题:\n'
'{{ question }}\n'
'请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n',
}
},
"search_internet": {
"use": True,
"search_engine_name": "bing",
"search_engine_config":
{
"bing": {
"result_len": 3,
"bing_search_url": "https://api.bing.microsoft.com/v7.0/search",
"bing_key": "",
},
"metaphor": {
"result_len": 3,
"metaphor_api_key": "",
"split_result": False,
"chunk_size": 500,
"chunk_overlap": 0,
},
"duckduckgo": {
"result_len": 3
}
},
"top_k": 10,
"verbose": "Origin",
"conclude_prompt":
"<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 "
"</指令>\n<已知信息>{{ context }}</已知信息>\n"
"<问题>\n"
"{{ question }}\n"
"</问题>\n"
},
"arxiv": {
"use": True,
},
"shell": {
"use": True,
},
"weather_check": {
"use": True,
"api-key": "",
},
"search_youtube": {
"use": False,
},
"wolfram": {
"use": False,
},
"calculate": {
"use": False,
},
}
# LLM 模型运行设备。设为"auto"会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
LLM_DEVICE = "auto"
HISTORY_LEN = 3
MAX_TOKENS = 2048
TEMPERATURE = 0.7
ONLINE_LLM_MODEL = {
"openai-api": {
"model_name": "gpt-4",
"model_name": "gpt-4-1106-preview",
"api_base_url": "https://api.openai.com/v1",
"api_key": "",
"openai_proxy": "",
},
# 智谱AI API,具体注册及api key获取请前往 http://open.bigmodel.cn
"zhipu-api": {
"api_key": "",
"version": "glm-4",
"version": "chatglm_turbo",
"provider": "ChatGLMWorker",
},
# 具体注册及api key获取请前往 https://api.minimax.chat/
"minimax-api": {
"group_id": "",
"api_key": "",
"is_pro": False,
"provider": "MiniMaxWorker",
},
# 具体注册及api key获取请前往 https://xinghuo.xfyun.cn/
"xinghuo-api": {
"APPID": "",
"APISecret": "",
"api_key": "",
"version": "v3.5", # 你使用的讯飞星火大模型版本,可选包括 "v3.5","v3.0", "v2.0", "v1.5"
"version": "v3.0",
"provider": "XingHuoWorker",
},
# 百度千帆 API申请方式请参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf
"qianfan-api": {
"version": "ERNIE-Bot", # 注意大小写。当前支持 "ERNIE-Bot" 或 "ERNIE-Bot-turbo" 更多的见官方文档。
"version_url": "", # 也可以不填写version直接填写在千帆申请模型发布的API地址
"version": "ernie-bot-4",
"version_url": "",
"api_key": "",
"secret_key": "",
"provider": "QianFanWorker",
},
# 火山方舟 API文档参考 https://www.volcengine.com/docs/82379
"fangzhou-api": {
"version": "chatglm-6b-model",
"version_url": "",
@ -87,28 +184,22 @@ ONLINE_LLM_MODEL = {
"secret_key": "",
"provider": "FangZhouWorker",
},
# 阿里云通义千问 API文档参考 https://help.aliyun.com/zh/dashscope/developer-reference/api-details
"qwen-api": {
"version": "qwen-max",
"api_key": "",
"provider": "QwenWorker",
"embed_model": "text-embedding-v1" # embedding 模型名称
},
# 百川 API申请方式请参考 https://www.baichuan-ai.com/home#api-enter
"baichuan-api": {
"version": "Baichuan2-53B",
"api_key": "",
"secret_key": "",
"provider": "BaiChuanWorker",
},
# Azure API
"azure-api": {
"deployment_name": "", # 部署容器的名字
"resource_name": "", # https://{resource_name}.openai.azure.com/openai/ 填写resource_name的部分其他部分不要填写
"api_version": "", # API的版本不是模型版本
"deployment_name": "",
"resource_name": "",
"api_version": "2023-07-01-preview",
"api_key": "",
"provider": "AzureWorker",
},
@ -153,50 +244,33 @@ MODEL_PATH = {
"bge-small-zh": "BAAI/bge-small-zh",
"bge-base-zh": "BAAI/bge-base-zh",
"bge-large-zh": "BAAI/bge-large-zh",
"bge-large-zh": "/media/zr/Data/Models/Embedding/bge-large-zh",
"bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
"bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5",
"bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5",
"bge-m3": "BAAI/bge-m3",
"bge-large-zh-v1.5": "/Models/bge-large-zh-v1.5",
"piccolo-base-zh": "sensenova/piccolo-base-zh",
"piccolo-large-zh": "sensenova/piccolo-large-zh",
"nlp_gte_sentence-embedding_chinese-large": "damo/nlp_gte_sentence-embedding_chinese-large",
"text-embedding-ada-002": "your OPENAI_API_KEY",
"nlp_gte_sentence-embedding_chinese-large": "/Models/nlp_gte_sentence-embedding_chinese-large",
"text-embedding-ada-002": "Just write your OpenAI key like "sk-o3IGBhC9g8AiFvTGWVKsT*****" ",
},
"llm_model": {
"chatglm2-6b": "THUDM/chatglm2-6b",
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
"chatglm3-6b": "THUDM/chatglm3-6b",
"chatglm3-6b": "/Models/chatglm3-6b",
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
"Orion-14B-Chat": "OrionStarAI/Orion-14B-Chat",
"Orion-14B-Chat-Plugin": "OrionStarAI/Orion-14B-Chat-Plugin",
"Orion-14B-LongChat": "OrionStarAI/Orion-14B-LongChat",
"Yi-34B-Chat": "/data/share/models/Yi-34B-Chat",
"BlueLM-7B-Chat": "/Models/BlueLM-7B-Chat",
"Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
"Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
"Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
"baichuan2-13b": "/media/zr/Data/Models/LLM/Baichuan2-13B-Chat",
"baichuan2-7b": "/media/zr/Data/Models/LLM/Baichuan2-7B-Chat",
"baichuan-7b": "baichuan-inc/Baichuan-7B",
"baichuan-13b": "baichuan-inc/Baichuan-13B",
'baichuan-13b-chat': 'baichuan-inc/Baichuan-13B-Chat',
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
# Qwen1.5 模型 VLLM可能出现问题
"Qwen1.5-0.5B-Chat": "Qwen/Qwen1.5-0.5B-Chat",
"Qwen1.5-1.8B-Chat": "Qwen/Qwen1.5-1.8B-Chat",
"Qwen1.5-4B-Chat": "Qwen/Qwen1.5-4B-Chat",
"Qwen1.5-7B-Chat": "Qwen/Qwen1.5-7B-Chat",
"Qwen1.5-14B-Chat": "Qwen/Qwen1.5-14B-Chat",
"Qwen1.5-72B-Chat": "Qwen/Qwen1.5-72B-Chat",
"baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
"baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
"baichuan2-7b-chat": "baichuan-inc/Baichuan2-7B-Chat",
"baichuan2-13b-chat": "baichuan-inc/Baichuan2-13B-Chat",
"aquila-7b": "BAAI/Aquila-7B",
"aquilachat-7b": "BAAI/AquilaChat-7B",
"internlm-7b": "internlm/internlm-7b",
"internlm-chat-7b": "internlm/internlm-chat-7b",
@ -235,42 +309,43 @@ MODEL_PATH = {
"oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
"dolly-v2-12b": "databricks/dolly-v2-12b",
"stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b",
"Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf",
"Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
"open_llama_13b": "openlm-research/open_llama_13b",
"vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3",
"koala": "young-geng/koala",
"mpt-7b": "mosaicml/mpt-7b",
"mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter",
"mpt-30b": "mosaicml/mpt-30b",
"opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
"Qwen-1_8B-Chat":"Qwen/Qwen-1_8B-Chat"
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8", # 确保已经安装了auto-gptq optimum flash-attn
"Qwen-14B-Chat-Int4": "/media/zr/Data/Models/LLM/Qwen-14B-Chat-Int4", # 确保已经安装了auto-gptq optimum flash-attn
},
"reranker": {
"bge-reranker-large": "BAAI/bge-reranker-large",
"bge-reranker-base": "BAAI/bge-reranker-base",
}
}
# 通常情况下不需要更改以下内容
# nltk 模型存储路径
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
# 使用VLLM可能导致模型推理能力下降无法完成Agent任务
VLLM_MODEL_DICT = {
"chatglm2-6b": "THUDM/chatglm2-6b",
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
"chatglm3-6b": "THUDM/chatglm3-6b",
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
"aquila-7b": "BAAI/Aquila-7B",
"aquilachat-7b": "BAAI/AquilaChat-7B",
"Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
"Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
"Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
"baichuan-7b": "baichuan-inc/Baichuan-7B",
"baichuan-13b": "baichuan-inc/Baichuan-13B",
'baichuan-13b-chat': 'baichuan-inc/Baichuan-13B-Chat',
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
"baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
"baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
"baichuan2-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
"baichuan2-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
"BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat",
"BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k",
'chatglm2-6b': 'THUDM/chatglm2-6b',
'chatglm2-6b-32k': 'THUDM/chatglm2-6b-32k',
'chatglm3-6b': 'THUDM/chatglm3-6b',
'chatglm3-6b-32k': 'THUDM/chatglm3-6b-32k',
"internlm-7b": "internlm/internlm-7b",
"internlm-chat-7b": "internlm/internlm-chat-7b",
@ -301,14 +376,13 @@ VLLM_MODEL_DICT = {
"opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
}
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
SUPPORT_AGENT_MODEL = [
"openai-api", # GPT4 模型
"qwen-api", # Qwen Max模型
"zhipu-api", # 智谱AI GLM4模型
"Qwen", # 所有Qwen系列本地模型
"chatglm3-6b",
"internlm2-chat-20b",
"Orion-14B-Chat-Plugin",
]
"agentlm-7b": "THUDM/agentlm-7b",
"agentlm-13b": "THUDM/agentlm-13b",
"agentlm-70b": "THUDM/agentlm-70b",
}

View File

@ -1,26 +1,19 @@
# prompt模板使用Jinja2语法简单点就是用双大括号代替f-string的单大括号
# 本配置文件支持热加载修改prompt模板后无需重启服务。
# LLM对话支持的变量
# - input: 用户输入内容
# 知识库和搜索引擎对话支持的变量:
# - context: 从检索结果拼接的知识文本
# - question: 用户提出的问题
# Agent对话支持的变量
# - tools: 可用的工具列表
# - tool_names: 可用的工具名称列表
# - history: 用户和Agent的对话历史
# - input: 用户输入内容
# - agent_scratchpad: Agent的思维记录
PROMPT_TEMPLATES = {
"llm_chat": {
"preprocess_model": {
"default":
'请你根据我的描述和我们对话的历史来判断本次跟我交流是否需要使用工具还是可以直接凭借你的知识或者历史记录跟我对话。你只要回答一个数字，1 或者 0。1代表需要使用工具，0代表不需要使用工具。\n'
'以下几种情况要使用工具,请返回1\n'
'1. 实时性的问题,例如天气,日期,地点等信息\n'
'2. 需要数学计算的问题\n'
'3. 需要查询数据,地点等精确数据\n'
'4. 需要行业知识的问题\n'
'<question>'
'{input}'
'</question>'
},
"llm_model": {
"default":
'{{ input }}',
"with_history":
'The following is a friendly conversation between a human and an AI. '
'The AI is talkative and provides lots of specific details from its context. '
@ -29,72 +22,42 @@ PROMPT_TEMPLATES = {
'{history}\n'
'Human: {input}\n'
'AI:',
"py":
'你是一个聪明的代码助手请你给我写出简单的py代码。 \n'
'{{ input }}',
},
"knowledge_base_chat": {
"default":
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,'
'不允许在答案中添加编造成分,答案请使用中文。 </指令>\n'
'<已知信息>{{ context }}</已知信息>\n'
'<问题>{{ question }}</问题>\n',
"text":
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>\n'
'<已知信息>{{ context }}</已知信息>\n'
'<问题>{{ question }}</问题>\n',
"empty": # 搜不到知识库的时候使用
'请你回答我的问题:\n'
'{{ question }}\n\n',
},
"search_engine_chat": {
"default":
'<指令>这是我搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。'
'如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 </指令>\n'
'<已知信息>{{ context }}</已知信息>\n'
'<问题>{{ question }}</问题>\n',
"search":
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>\n'
'<已知信息>{{ context }}</已知信息>\n'
'<问题>{{ question }}</问题>\n',
},
"agent_chat": {
"default":
'Answer the following questions as best you can. If it is in order, you can use some tools appropriately. '
'You have access to the following tools:\n\n'
'{tools}\n\n'
'Use the following format:\n'
'Question: the input question you must answer1\n'
'Thought: you should always think about what to do and what tools to use.\n'
'Action: the action to take, should be one of [{tool_names}]\n'
'Action Input: the input to the action\n'
"action_model": {
"GPT-4":
'Answer the following questions as best you can. You have access to the following tools:\n'
'The way you use the tools is by specifying a json blob.\n'
'Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here).\n'
'The only values that should be in the "action" field are: {tool_names}\n'
'The $JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB:\n'
'```\n\n'
'{{{{\n'
' "action": $TOOL_NAME,\n'
' "action_input": $INPUT\n'
'}}}}\n'
'```\n\n'
'ALWAYS use the following format:\n'
'Question: the input question you must answer\n'
'Thought: you should always think about what to do\n'
'Action:\n'
'```\n\n'
'$JSON_BLOB'
'```\n\n'
'Observation: the result of the action\n'
'... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n'
'... (this Thought/Action/Observation can repeat N times)\n'
'Thought: I now know the final answer\n'
'Final Answer: the final answer to the original input question\n'
'Begin!\n\n'
'history: {history}\n\n'
'Question: {input}\n\n'
'Thought: {agent_scratchpad}\n',
'Begin! Reminder to always use the exact characters `Final Answer` when responding.\n'
'history: {history}\n'
'Question:{input}\n'
'Thought:{agent_scratchpad}\n',
"ChatGLM3":
'You can answer using the tools, or answer directly using your knowledge without using the tools. '
'Respond to the human as helpfully and accurately as possible.\n'
'You can answer using the tools.Respond to the human as helpfully and accurately as possible.\n'
'You have access to the following tools:\n'
'{tools}\n'
'Use a json blob to specify a tool by providing an action key (tool name) '
'Use a json blob to specify a tool by providing an action key (tool name)\n'
'and an action_input key (tool input).\n'
'Valid "action" values: "Final Answer" or [{tool_names}]'
'Valid "action" values: "Final Answer" or [{tool_names}]\n'
'Provide only ONE action per $JSON_BLOB, as shown:\n\n'
'```\n'
'{{{{\n'
@ -118,10 +81,13 @@ PROMPT_TEMPLATES = {
' "action": "Final Answer",\n'
' "action_input": "Final response to human"\n'
'}}}}\n'
'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. '
'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary.\n'
'Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.\n'
'history: {history}\n\n'
'Question: {input}\n\n'
'Thought: {agent_scratchpad}',
'Thought: {agent_scratchpad}\n',
},
"postprocess_model": {
"default": "{{input}}",
}
}

View File

@ -57,7 +57,7 @@ FSCHAT_MODEL_WORKERS = {
# "awq_ckpt": None,
# "awq_wbits": 16,
# "awq_groupsize": -1,
# "model_names": LLM_MODELS,
# "model_names": None,
# "conv_template": None,
# "limit_worker_concurrency": 5,
# "stream_interval": 2,

View File

@ -1,29 +0,0 @@
# 实现基于ES的数据插入、检索、删除、更新
```shell
author: 唐国梁Tommy
e-mail: flytang186@qq.com
如果遇到任何问题,可以与我联系,我这边部署后服务是没有问题的。
```
## 第1步ES docker部署
```shell
docker network create elastic
docker run -id --name elasticsearch --net elastic -p 9200:9200 -p 9300:9300 -e "discovery.type=single-node" -e "xpack.security.enabled=false" -e "xpack.security.http.ssl.enabled=false" -t docker.elastic.co/elasticsearch/elasticsearch:8.8.2
```
### 第2步Kibana docker部署
**注意Kibana版本与ES保持一致**
```shell
docker pull docker.elastic.co/kibana/kibana:{version}
docker run --name kibana --net elastic -p 5601:5601 docker.elastic.co/kibana/kibana:{version}
```
### 第3步核心代码
```shell
1. 核心代码路径
server/knowledge_base/kb_service/es_kb_service.py
2. 需要在 configs/model_config.py 中 配置 ES参数IP PORT
```
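For reference, a minimal sketch of what the ES parameters mentioned in step 3 might look like; the `kbs_config` name and the host/port/index_name/user/password keys are assumptions rather than something shown in this diff, so check the actual configuration file:

```python
# Hypothetical ES settings; the real key names in configs/model_config.py (or kb_config.py) may differ.
kbs_config = {
    "es": {
        "host": "127.0.0.1",               # ES IP
        "port": "9200",                    # ES PORT
        "index_name": "langchain_chatchat",
        "user": "",                        # may stay empty when xpack.security is disabled, as in the docker command above
        "password": "",
    },
}
```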

View File

@ -1,11 +1,7 @@
torch==2.1.2
torchvision==0.16.2
torchaudio==2.1.2
xformers==0.0.23.post1
transformers==4.37.2
sentence_transformers==2.2.2
langchain==0.0.354
langchain-experimental==0.0.47
# API requirements
langchain>=0.0.346
langchain-experimental>=0.0.42
pydantic==1.10.13
fschat==0.2.35
openai==1.9.0

View File

@ -1,11 +1,7 @@
torch~=2.1.2
torchvision~=0.16.2
torchaudio~=2.1.2
xformers>=0.0.23.post1
transformers==4.37.2
sentence_transformers==2.2.2
langchain==0.0.354
langchain-experimental==0.0.47
# API requirements
langchain>=0.0.346
langchain-experimental>=0.0.42
pydantic==1.10.13
fschat==0.2.35
openai~=1.9.0

View File

@ -1,4 +0,0 @@
from .model_contain import *
from .callbacks import *
from .custom_template import *
from .tools import *

View File

@ -0,0 +1 @@
from .glm3_agent import initialize_glm3_agent

View File

@ -3,6 +3,14 @@ This file is a modified version for ChatGLM3-6B the original glm3_agent.py file
"""
from __future__ import annotations
import yaml
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from typing import Any, List, Sequence, Tuple, Optional, Union
import os
from langchain.agents.agent import Agent
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
import json
import logging
from typing import Any, List, Sequence, Tuple, Optional, Union
@ -22,6 +30,7 @@ from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool
from pydantic.schema import model_schema
HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
logger = logging.getLogger(__name__)
@ -42,8 +51,9 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
if "tool_call" in text:
action_end = text.find("```")
action = text[:action_end].strip()
params_str_start = text.find("(") + 1
params_str_end = text.rfind(")")
params_str_end = text.find(")")
params_str = text[params_str_start:params_str_end]
params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param]
@ -225,4 +235,4 @@ def initialize_glm3_agent(
memory=memory,
tags=tags_,
**kwargs,
)
)

View File

@ -1,67 +0,0 @@
from __future__ import annotations
from langchain.agents import Tool, AgentOutputParser
from langchain.prompts import StringPromptTemplate
from typing import List
from langchain.schema import AgentAction, AgentFinish
from configs import SUPPORT_AGENT_MODEL
from server.agent import model_container
class CustomPromptTemplate(StringPromptTemplate):
template: str
tools: List[Tool]
def format(self, **kwargs) -> str:
intermediate_steps = kwargs.pop("intermediate_steps")
thoughts = ""
for action, observation in intermediate_steps:
thoughts += action.log
thoughts += f"\nObservation: {observation}\nThought: "
kwargs["agent_scratchpad"] = thoughts
kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
return self.template.format(**kwargs)
class CustomOutputParser(AgentOutputParser):
begin: bool = False
def __init__(self):
super().__init__()
self.begin = True
def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin:
self.begin = False
stop_words = ["Observation:"]
min_index = len(llm_output)
for stop_word in stop_words:
index = llm_output.find(stop_word)
if index != -1 and index < min_index:
min_index = index
llm_output = llm_output[:min_index]
if "Final Answer:" in llm_output:
self.begin = True
return AgentFinish(
return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()},
log=llm_output,
)
parts = llm_output.split("Action:")
if len(parts) < 2:
return AgentFinish(
return_values={"output": f"调用agent工具失败该回答为大模型自身能力的回答:\n\n `{llm_output}`"},
log=llm_output,
)
action = parts[1].split("Action Input:")[0].strip()
action_input = parts[1].split("Action Input:")[1].strip()
try:
ans = AgentAction(
tool=action,
tool_input=action_input.strip(" ").strip('"'),
log=llm_output
)
return ans
except:
return AgentFinish(
return_values={"output": f"调用agent失败: `{llm_output}`"},
log=llm_output,
)

View File

@ -1,6 +0,0 @@
class ModelContainer:
def __init__(self):
self.MODEL = None
self.DATABASE = None
model_container = ModelContainer()

View File

@ -1,11 +0,0 @@
## 导入所有的工具类
from .search_knowledgebase_simple import search_knowledgebase_simple
from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput
from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput
from .calculate import calculate, CalculatorInput
from .weather_check import weathercheck, WeatherInput
from .shell import shell, ShellInput
from .search_internet import search_internet, SearchInternetInput
from .wolfram import wolfram, WolframInput
from .search_youtube import search_youtube, YoutubeInput
from .arxiv import arxiv, ArxivInput

View File

@ -1,76 +0,0 @@
from langchain.prompts import PromptTemplate
from langchain.chains import LLMMathChain
from server.agent import model_container
from pydantic import BaseModel, Field
_PROMPT_TEMPLATE = """
将数学问题翻译成可以使用Python的numexpr库执行的表达式使用运行此代码的输出来回答问题
问题: ${{包含数学问题的问题}}
```text
${{解决问题的单行数学表达式}}
```
...numexpr.evaluate(query)...
```output
${{运行代码的输出}}
```
答案: ${{答案}}
这是两个例子
问题: 37593 * 67是多少
```text
37593 * 67
```
...numexpr.evaluate("37593 * 67")...
```output
2518731
答案: 2518731
问题: 37593的五次方根是多少
```text
37593**(1/5)
```
...numexpr.evaluate("37593**(1/5)")...
```output
8.222831614237718
答案: 8.222831614237718
问题: 2的平方是多少
```text
2 ** 2
```
...numexpr.evaluate("2 ** 2")...
```output
4
答案: 4
现在这是我的问题
问题: {question}
"""
PROMPT = PromptTemplate(
input_variables=["question"],
template=_PROMPT_TEMPLATE,
)
class CalculatorInput(BaseModel):
query: str = Field()
def calculate(query: str):
model = model_container.MODEL
llm_math = LLMMathChain.from_llm(model, verbose=True, prompt=PROMPT)
ans = llm_math.run(query)
return ans
if __name__ == "__main__":
result = calculate("2的三次方")
print("答案:",result)

View File

@ -1,37 +0,0 @@
import json
from server.chat.search_engine_chat import search_engine_chat
from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS
import asyncio
from server.agent import model_container
from pydantic import BaseModel, Field
async def search_engine_iter(query: str):
response = await search_engine_chat(query=query,
search_engine_name="bing", # 这里切换搜索引擎
model_name=model_container.MODEL.model_name,
temperature=0.01, # Agent 搜索互联网的时候温度设置为0.01
history=[],
top_k = VECTOR_SEARCH_TOP_K,
max_tokens= MAX_TOKENS,
prompt_name = "default",
stream=False)
contents = ""
async for data in response.body_iterator: # 这里的data是一个json字符串
data = json.loads(data)
contents = data["answer"]
docs = data["docs"]
return contents
def search_internet(query: str):
return asyncio.run(search_engine_iter(query))
class SearchInternetInput(BaseModel):
location: str = Field(description="Query for Internet search")
if __name__ == "__main__":
result = search_internet("今天星期几")
print("答案:",result)

View File

@ -1,287 +0,0 @@
from __future__ import annotations
import json
import re
import warnings
from typing import Dict
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun
from langchain.chains.llm import LLMChain
from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from typing import List, Any, Optional
from langchain.prompts import PromptTemplate
from server.chat.knowledge_base_chat import knowledge_base_chat
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
import asyncio
from server.agent import model_container
from pydantic import BaseModel, Field
async def search_knowledge_base_iter(database: str, query: str) -> str:
response = await knowledge_base_chat(query=query,
knowledge_base_name=database,
model_name=model_container.MODEL.model_name,
temperature=0.01,
history=[],
top_k=VECTOR_SEARCH_TOP_K,
max_tokens=MAX_TOKENS,
prompt_name="default",
score_threshold=SCORE_THRESHOLD,
stream=False)
contents = ""
async for data in response.body_iterator: # 这里的data是一个json字符串
data = json.loads(data)
contents += data["answer"]
docs = data["docs"]
return contents
async def search_knowledge_multiple(queries) -> List[str]:
# queries 应该是一个包含多个 (database, query) 元组的列表
tasks = [search_knowledge_base_iter(database, query) for database, query in queries]
results = await asyncio.gather(*tasks)
# 结合每个查询结果,并在每个查询结果前添加一个自定义的消息
combined_results = []
for (database, _), result in zip(queries, results):
message = f"\n查询到 {database} 知识库的相关信息:\n{result}"
combined_results.append(message)
return combined_results
def search_knowledge(queries) -> str:
responses = asyncio.run(search_knowledge_multiple(queries))
# 输出每个整合的查询结果
contents = ""
for response in responses:
contents += response + "\n\n"
return contents
_PROMPT_TEMPLATE = """
用户会提出一个需要你查询知识库的问题你应该对问题进行理解和拆解并在知识库中查询相关的内容
对于每个知识库你输出的内容应该是一个一行的字符串这行字符串包含知识库名称和查询内容中间用逗号隔开不要有多余的文字和符号你可以同时查询多个知识库下面这个例子就是同时查询两个知识库的内容
例子:
robotic,机器人男女比例是多少
bigdata,大数据的就业情况如何
这些数据库是你能访问的冒号之前是他们的名字冒号之后是他们的功能你应该参考他们的功能来帮助你思考
{database_names}
你的回答格式应该按照下面的内容请注意```text 等标记都必须输出这是我用来提取答案的标记
不要输出中文的逗号不要输出引号
Question: ${{用户的问题}}
```text
${{知识库名称,查询问题,不要带有任何除了,之外的符号,比如不要输出中文的逗号不要输出引号}}
```output
数据库查询的结果
现在我们开始作答
问题: {question}
"""
PROMPT = PromptTemplate(
input_variables=["question", "database_names"],
template=_PROMPT_TEMPLATE,
)
class LLMKnowledgeChain(LLMChain):
llm_chain: LLMChain
llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use."""
prompt: BasePromptTemplate = PROMPT
"""[Deprecated] Prompt to use to translate to python if necessary."""
database_names: Dict[str, str] = None
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
"Please instantiate with llm_chain argument or using the from_llm "
"class method."
)
if "llm_chain" not in values and values["llm"] is not None:
prompt = values.get("prompt", PROMPT)
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
return values
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
def _evaluate_expression(self, queries) -> str:
try:
output = search_knowledge(queries)
except Exception as e:
output = "输入的信息有误或不存在知识库,错误信息如下:\n"
return output + str(e)
return output
def _process_llm_result(
self,
llm_output: str,
run_manager: CallbackManagerForChainRun
) -> Dict[str, str]:
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
# text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1).strip()
cleaned_input_str = (expression.replace("\"", "").replace("“", "").
replace("”", "").replace("```", "").strip())
lines = cleaned_input_str.split("\n")
# 使用逗号分割每一行,然后形成一个(数据库,查询)元组的列表
try:
queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
except:
queries = [(line.split("，")[0].strip(), line.split("，")[1].strip()) for line in lines]
run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", verbose=self.verbose)
output = self._evaluate_expression(queries)
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
elif "Answer:" in llm_output:
answer = llm_output.split("Answer:")[-1]
else:
return {self.output_key: f"输入的格式不对:\n {llm_output}"}
return {self.output_key: answer}
async def _aprocess_llm_result(
self,
llm_output: str,
run_manager: AsyncCallbackManagerForChainRun,
) -> Dict[str, str]:
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1).strip()
cleaned_input_str = (
expression.replace("\"", "").replace("“", "").replace("”", "").replace("```", "").strip())
lines = cleaned_input_str.split("\n")
try:
queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
except:
queries = [(line.split("，")[0].strip(), line.split("，")[1].strip()) for line in lines]
await run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue",
verbose=self.verbose)
output = self._evaluate_expression(queries)
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
elif "Answer:" in llm_output:
answer = "Answer: " + llm_output.split("Answer:")[-1]
else:
raise ValueError(f"unknown format from LLM: {llm_output}")
return {self.output_key: answer}
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_run_manager.on_text(inputs[self.input_key])
self.database_names = model_container.DATABASE
data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
llm_output = self.llm_chain.predict(
database_names=data_formatted_str,
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
return self._process_llm_result(llm_output, _run_manager)
async def _acall(
self,
inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
await _run_manager.on_text(inputs[self.input_key])
self.database_names = model_container.DATABASE
data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
llm_output = await self.llm_chain.apredict(
database_names=data_formatted_str,
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)
@property
def _chain_type(self) -> str:
return "llm_knowledge_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = PROMPT,
**kwargs: Any,
) -> LLMKnowledgeChain:
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(llm_chain=llm_chain, **kwargs)
def search_knowledgebase_complex(query: str):
model = model_container.MODEL
llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT)
ans = llm_knowledge.run(query)
return ans
class KnowledgeSearchInput(BaseModel):
location: str = Field(description="The query to be searched")
if __name__ == "__main__":
result = search_knowledgebase_complex("机器人和大数据在代码教学上有什么区别")
print(result)
# 这是一个正常的切割
# queries = [
# ("bigdata", "大数据专业的男女比例"),
# ("robotic", "机器人专业的优势")
# ]
# result = search_knowledge(queries)
# print(result)

View File

@ -1,234 +0,0 @@
from __future__ import annotations
import re
import warnings
from typing import Dict
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.llm import LLMChain
from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from typing import List, Any, Optional
from langchain.prompts import PromptTemplate
import sys
import os
import json
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from server.chat.knowledge_base_chat import knowledge_base_chat
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
import asyncio
from server.agent import model_container
from pydantic import BaseModel, Field
async def search_knowledge_base_iter(database: str, query: str):
response = await knowledge_base_chat(query=query,
knowledge_base_name=database,
model_name=model_container.MODEL.model_name,
temperature=0.01,
history=[],
top_k=VECTOR_SEARCH_TOP_K,
max_tokens=MAX_TOKENS,
prompt_name="knowledge_base_chat",
score_threshold=SCORE_THRESHOLD,
stream=False)
contents = ""
async for data in response.body_iterator: # 这里的data是一个json字符串
data = json.loads(data)
contents += data["answer"]
docs = data["docs"]
return contents
_PROMPT_TEMPLATE = """
用户会提出一个需要你查询知识库的问题你应该按照我提供的思想进行思考
Question: ${{用户的问题}}
这些数据库是你能访问的冒号之前是他们的名字冒号之后是他们的功能
{database_names}
你的回答格式应该按照下面的内容请注意格式内的```text 等标记都必须输出这是我用来提取答案的标记
```text
${{知识库的名称}}
```
```output
数据库查询的结果
```
答案: ${{答案}}
现在这是我的问题
问题: {question}
"""
PROMPT = PromptTemplate(
input_variables=["question", "database_names"],
template=_PROMPT_TEMPLATE,
)
class LLMKnowledgeChain(LLMChain):
llm_chain: LLMChain
llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use."""
prompt: BasePromptTemplate = PROMPT
"""[Deprecated] Prompt to use to translate to python if necessary."""
database_names: Dict[str, str] = model_container.DATABASE
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
"Please instantiate with llm_chain argument or using the from_llm "
"class method."
)
if "llm_chain" not in values and values["llm"] is not None:
prompt = values.get("prompt", PROMPT)
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
return values
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
def _evaluate_expression(self, dataset, query) -> str:
try:
output = asyncio.run(search_knowledge_base_iter(dataset, query))
except Exception as e:
output = "输入的信息有误或不存在知识库"
return output
return output
def _process_llm_result(
self,
llm_output: str,
llm_input: str,
run_manager: CallbackManagerForChainRun
) -> Dict[str, str]:
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
database = text_match.group(1).strip()
output = self._evaluate_expression(database, llm_input)
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
elif "Answer:" in llm_output:
answer = "Answer: " + llm_output.split("Answer:")[-1]
else:
return {self.output_key: f"输入的格式不对: {llm_output}"}
return {self.output_key: answer}
async def _aprocess_llm_result(
self,
llm_output: str,
run_manager: AsyncCallbackManagerForChainRun,
) -> Dict[str, str]:
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1)
output = self._evaluate_expression(expression)
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
elif "Answer:" in llm_output:
answer = "Answer: " + llm_output.split("Answer:")[-1]
else:
raise ValueError(f"unknown format from LLM: {llm_output}")
return {self.output_key: answer}
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_run_manager.on_text(inputs[self.input_key])
data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
llm_output = self.llm_chain.predict(
database_names=data_formatted_str,
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
return self._process_llm_result(llm_output, inputs[self.input_key], _run_manager)
async def _acall(
self,
inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
await _run_manager.on_text(inputs[self.input_key])
data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
llm_output = await self.llm_chain.apredict(
database_names=data_formatted_str,
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)
@property
def _chain_type(self) -> str:
return "llm_knowledge_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = PROMPT,
**kwargs: Any,
) -> LLMKnowledgeChain:
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(llm_chain=llm_chain, **kwargs)
def search_knowledgebase_once(query: str):
model = model_container.MODEL
llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT)
ans = llm_knowledge.run(query)
return ans
class KnowledgeSearchInput(BaseModel):
location: str = Field(description="The query to be searched")
if __name__ == "__main__":
result = search_knowledgebase_once("大数据的男女比例")
print(result)

View File

@ -1,32 +0,0 @@
from server.chat.knowledge_base_chat import knowledge_base_chat
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
import json
import asyncio
from server.agent import model_container
async def search_knowledge_base_iter(database: str, query: str) -> str:
response = await knowledge_base_chat(query=query,
knowledge_base_name=database,
model_name=model_container.MODEL.model_name,
temperature=0.01,
history=[],
top_k=VECTOR_SEARCH_TOP_K,
max_tokens=MAX_TOKENS,
prompt_name="knowledge_base_chat",
score_threshold=SCORE_THRESHOLD,
stream=False)
contents = ""
async for data in response.body_iterator: # 这里的data是一个json字符串
data = json.loads(data)
contents = data["answer"]
docs = data["docs"]
return contents
def search_knowledgebase_simple(query: str):
return asyncio.run(search_knowledge_base_iter(query))
if __name__ == "__main__":
result = search_knowledgebase_simple("大数据男女比例")
print("答案:",result)

View File

@ -0,0 +1,8 @@
from .search_local_knowledgebase import search_local_knowledgebase, SearchKnowledgeInput
from .calculate import calculate, CalculatorInput
from .weather_check import weather_check, WeatherInput
from .shell import shell, ShellInput
from .search_internet import search_internet, SearchInternetInput
from .wolfram import wolfram, WolframInput
from .search_youtube import search_youtube, YoutubeInput
from .arxiv import arxiv, ArxivInput

View File

@ -0,0 +1,23 @@
from pydantic import BaseModel, Field
def calculate(a: float, b: float, operator: str) -> float:
if operator == "+":
return a + b
elif operator == "-":
return a - b
elif operator == "*":
return a * b
elif operator == "/":
if b != 0:
return a / b
else:
return float('inf') # 防止除以零
elif operator == "^":
return a ** b
else:
raise ValueError("Unsupported operator")
class CalculatorInput(BaseModel):
a: float = Field(description="first number")
b: float = Field(description="second number")
operator: str = Field(description="operator to use (e.g., +, -, *, /, ^)")

View File

@ -0,0 +1,99 @@
from pydantic import BaseModel, Field
from langchain.utilities.bing_search import BingSearchAPIWrapper
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from configs import TOOL_CONFIG
from langchain.text_splitter import RecursiveCharacterTextSplitter
from typing import List, Dict
from langchain.docstore.document import Document
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from markdownify import markdownify
def bing_search(text, config):
search = BingSearchAPIWrapper(bing_subscription_key=config["bing_key"],
bing_search_url=config["bing_search_url"])
return search.results(text, config["result_len"])
def duckduckgo_search(text, config):
search = DuckDuckGoSearchAPIWrapper()
return search.results(text, config["result_len"])
def metaphor_search(
text: str,
config: dict,
) -> List[Dict]:
from metaphor_python import Metaphor
client = Metaphor(config["metaphor_api_key"])
search = client.search(text, num_results=config["result_len"], use_autoprompt=True)
contents = search.get_contents().contents
for x in contents:
x.extract = markdownify(x.extract)
if config["split_result"]:
docs = [Document(page_content=x.extract,
metadata={"link": x.url, "title": x.title})
for x in contents]
text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
chunk_size=config["chunk_size"],
chunk_overlap=config["chunk_overlap"])
splitted_docs = text_splitter.split_documents(docs)
if len(splitted_docs) > config["result_len"]:
normal = NormalizedLevenshtein()
for x in splitted_docs:
x.metadata["score"] = normal.similarity(text, x.page_content)
splitted_docs.sort(key=lambda x: x.metadata["score"], reverse=True)
splitted_docs = splitted_docs[:config["result_len"]]
docs = [{"snippet": x.page_content,
"link": x.metadata["link"],
"title": x.metadata["title"]}
for x in splitted_docs]
else:
docs = [{"snippet": x.extract,
"link": x.url,
"title": x.title}
for x in contents]
return docs
SEARCH_ENGINES = {"bing": bing_search,
"duckduckgo": duckduckgo_search,
"metaphor": metaphor_search,
}
def search_result2docs(search_results):
docs = []
for result in search_results:
doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "",
metadata={"source": result["link"] if "link" in result.keys() else "",
"filename": result["title"] if "title" in result.keys() else ""})
docs.append(doc)
return docs
def search_engine(query: str,
config: dict):
search_engine_use = SEARCH_ENGINES[config["search_engine_name"]]
results = search_engine_use(text=query,
config=config["search_engine_config"][
config["search_engine_name"]])
docs = search_result2docs(results)
context = ""
docs = [
f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
for inum, doc in enumerate(docs)
]
for doc in docs:
context += doc + "\n"
return context
def search_internet(query: str):
tool_config = TOOL_CONFIG["search_internet"]
return search_engine(query=query, config=tool_config)
class SearchInternetInput(BaseModel):
query: str = Field(description="Query for Internet search")

View File

@ -0,0 +1,40 @@
from urllib.parse import urlencode
from pydantic import BaseModel, Field
from server.knowledge_base.kb_doc_api import search_docs
from configs import TOOL_CONFIG
def search_knowledgebase(query: str, database: str, config: dict):
docs = search_docs(
query=query,
knowledge_base_name=database,
top_k=config["top_k"],
score_threshold=config["score_threshold"])
context = ""
source_documents = []
for inum, doc in enumerate(docs):
filename = doc.metadata.get("source")
parameters = urlencode({"knowledge_base_name": database, "file_name": filename})
url = f"download_doc?" + parameters
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
source_documents.append(text)
if len(source_documents) == 0:
context= "没有找到相关文档,请更换关键词重试"
else:
for doc in source_documents:
context += doc + "\n"
return context
class SearchKnowledgeInput(BaseModel):
database: str = Field(description="Database for Knowledge Search")
query: str = Field(description="Query for Knowledge Search")
def search_local_knowledgebase(database: str, query: str):
tool_config = TOOL_CONFIG["search_local_knowledgebase"]
return search_knowledgebase(query=query, database=database, config=tool_config)

View File

@ -6,4 +6,4 @@ def search_youtube(query: str):
return tool.run(tool_input=query)
class YoutubeInput(BaseModel):
location: str = Field(description="Query for Videos search")
query: str = Field(description="Query for Videos search")

View File

@ -6,4 +6,4 @@ def shell(query: str):
return tool.run(tool_input=query)
class ShellInput(BaseModel):
query: str = Field(description="一个能在Linux命令行运行的Shell命令")
query: str = Field(description="The command to execute")

View File

@ -0,0 +1,59 @@
from langchain_core.tools import StructuredTool
from server.agent.tools_factory import *
from configs import KB_INFO
template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get informationOnly local data on this knowledge use this tool."
KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
template_knowledge = template.format(KB_info=KB_info_str)
all_tools = [
StructuredTool.from_function(
func=calculate,
name="calculate",
description="Useful for when you need to answer questions about simple calculations",
args_schema=CalculatorInput,
),
StructuredTool.from_function(
func=arxiv,
name="arxiv",
description="A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.",
args_schema=ArxivInput,
),
StructuredTool.from_function(
func=shell,
name="shell",
description="Use Shell to execute Linux commands",
args_schema=ShellInput,
),
StructuredTool.from_function(
func=wolfram,
name="wolfram",
description="Useful for when you need to calculate difficult formulas",
args_schema=WolframInput,
),
StructuredTool.from_function(
func=search_youtube,
name="search_youtube",
description="use this tools_factory to search youtube videos",
args_schema=YoutubeInput,
),
StructuredTool.from_function(
func=weather_check,
name="weather_check",
description="Use this tool to check the weather at a specific location",
args_schema=WeatherInput,
),
StructuredTool.from_function(
func=search_internet,
name="search_internet",
description="Use this tool to use bing search engine to search the internet and get information",
args_schema=SearchInternetInput,
),
StructuredTool.from_function(
func=search_local_knowledgebase,
name="search_local_knowledgebase",
description=template_knowledge,
args_schema=SearchKnowledgeInput,
),
]

View File

@ -1,10 +1,8 @@
"""
简单的单参数输入工具实现用于查询现在天气的情况
简单的单参数输入工具实现用于查询现在天气的情况
"""
from pydantic import BaseModel, Field
import requests
from configs.kb_config import SENIVERSE_API_KEY
def weather(location: str, api_key: str):
url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
@ -21,9 +19,7 @@ def weather(location: str, api_key: str):
f"Failed to retrieve weather: {response.status_code}")
def weathercheck(location: str):
return weather(location, SENIVERSE_API_KEY)
def weather_check(location: str):
return weather(location, "S8vrB4U_-c5mvAMiK")
class WeatherInput(BaseModel):
location: str = Field(description="City name,include city and county")
location: str = Field(description="City name,include city and county,like '厦门'")

View File

@ -8,4 +8,4 @@ def wolfram(query: str):
return ans
class WolframInput(BaseModel):
location: str = Field(description="需要运算的具体问题")
formula: str = Field(description="The formula to be calculated")

View File

@ -1,55 +0,0 @@
from langchain.tools import Tool
from server.agent.tools import *
tools = [
Tool.from_function(
func=calculate,
name="calculate",
description="Useful for when you need to answer questions about simple calculations",
args_schema=CalculatorInput,
),
Tool.from_function(
func=arxiv,
name="arxiv",
description="A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.",
args_schema=ArxivInput,
),
Tool.from_function(
func=weathercheck,
name="weather_check",
description="",
args_schema=WeatherInput,
),
Tool.from_function(
func=shell,
name="shell",
description="Use Shell to execute Linux commands",
args_schema=ShellInput,
),
Tool.from_function(
func=search_knowledgebase_complex,
name="search_knowledgebase_complex",
description="Use Use this tool to search local knowledgebase and get information",
args_schema=KnowledgeSearchInput,
),
Tool.from_function(
func=search_internet,
name="search_internet",
description="Use this tool to use bing search engine to search the internet",
args_schema=SearchInternetInput,
),
Tool.from_function(
func=wolfram,
name="Wolfram",
description="Useful for when you need to calculate difficult formulas",
args_schema=WolframInput,
),
Tool.from_function(
func=search_youtube,
name="search_youtube",
description="use this tools to search youtube videos",
args_schema=YoutubeInput,
),
]
tool_names = [tool.name for tool in tools]

View File

@ -13,13 +13,12 @@ from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat.chat import chat
from server.chat.search_engine_chat import search_engine_chat
from server.chat.completion import completion
from server.chat.feedback import chat_feedback
from server.embeddings_api import embed_texts_endpoint
from server.llm_api import (list_running_models, list_config_models,
change_llm_model, stop_llm_model,
get_model_config, list_search_engines)
get_model_config)
from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
get_server_configs, get_prompt_template)
from typing import List, Literal
@ -63,11 +62,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
summary="与llm模型对话(通过LLMChain)",
)(chat)
app.post("/chat/search_engine_chat",
tags=["Chat"],
summary="与搜索引擎对话",
)(search_engine_chat)
app.post("/chat/feedback",
tags=["Chat"],
summary="返回llm模型对话评分",
@ -110,16 +104,12 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
summary="获取服务器原始配置信息",
)(get_server_configs)
app.post("/server/list_search_engines",
tags=["Server State"],
summary="获取服务器支持的搜索引擎",
)(list_search_engines)
@app.post("/server/get_prompt_template",
tags=["Server State"],
summary="获取服务区配置的 prompt 模板")
def get_server_prompt_template(
type: Literal["llm_chat", "knowledge_base_chat", "search_engine_chat", "agent_chat"]=Body("llm_chat", description="模板类型可选值llm_chatknowledge_base_chatsearch_engine_chatagent_chat"),
type: Literal["llm_chat", "knowledge_base_chat"]=Body("llm_chat", description="模板类型可选值llm_chatknowledge_base_chat"),
name: str = Body("default", description="模板名称"),
) -> str:
return get_prompt_template(type=type, name=name)
@ -139,7 +129,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
def mount_knowledge_routes(app: FastAPI):
from server.chat.knowledge_base_chat import knowledge_base_chat
from server.chat.file_chat import upload_temp_docs, file_chat
from server.chat.agent_chat import agent_chat
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
update_docs, download_doc, recreate_vector_store,
@ -155,11 +144,6 @@ def mount_knowledge_routes(app: FastAPI):
summary="文件对话"
)(file_chat)
app.post("/chat/agent_chat",
tags=["Chat"],
summary="与agent对话")(agent_chat)
# Tag: Knowledge Base Management
app.get("/knowledge_base/list_knowledge_bases",
tags=["Knowledge Base Management"],
response_model=ListResponse,

View File

@ -1,13 +1,11 @@
from __future__ import annotations
from uuid import UUID
from langchain.callbacks import AsyncIteratorCallbackHandler
import json
import asyncio
from typing import Any, Dict, List, Optional
from langchain.schema import AgentFinish, AgentAction
from langchain.schema.output import LLMResult
import asyncio
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast, Optional
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import AsyncCallbackHandler
def dumps(obj: Dict) -> str:
return json.dumps(obj, ensure_ascii=False)
@ -23,7 +21,7 @@ class Status:
tool_finish: int = 7
class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
def __init__(self):
super().__init__()
self.queue = asyncio.Queue()
@ -31,40 +29,29 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.cur_tool = {}
self.out = True
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
parent_run_id: UUID | None = None, tags: List[str] | None = None,
metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
# 对于截断不能自理的大模型,我来帮他截断
stop_words = ["Observation:", "Thought","\"","", "\n","\t"]
for stop_word in stop_words:
index = input_str.find(stop_word)
if index != -1:
input_str = input_str[:index]
break
self.cur_tool = {
"tool_name": serialized["name"],
"input_str": input_str,
"output_str": "",
"status": Status.agent_action,
"run_id": run_id.hex,
"llm_token": "",
"final_answer": "",
"error": "",
}
# print("\nInput Str:",self.cur_tool["input_str"])
self.queue.put_nowait(dumps(self.cur_tool))
async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
tags: List[str] | None = None, **kwargs: Any) -> None:
self.out = True ## 重置输出
self.cur_tool.update(
status=Status.tool_finish,
output_str=output.replace("Answer:", ""),
)
self.queue.put_nowait(dumps(self.cur_tool))
async def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
print("on_tool_start")
async def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
print("on_tool_end")
async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
self.cur_tool.update(
@ -73,23 +60,6 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
)
self.queue.put_nowait(dumps(self.cur_tool))
# async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
# if "Action" in token: ## 减少重复输出
# before_action = token.split("Action")[0]
# self.cur_tool.update(
# status=Status.running,
# llm_token=before_action + "\n",
# )
# self.queue.put_nowait(dumps(self.cur_tool))
#
# self.out = False
#
# if token and self.out:
# self.cur_tool.update(
# status=Status.running,
# llm_token=token,
# )
# self.queue.put_nowait(dumps(self.cur_tool))
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
special_tokens = ["Action", "<|observation|>"]
for stoken in special_tokens:
@ -103,7 +73,7 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.out = False
break
if token and self.out:
if token is not None and token != "" and self.out:
self.cur_tool.update(
status=Status.running,
llm_token=token,
@ -116,16 +86,17 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
llm_token="",
)
self.queue.put_nowait(dumps(self.cur_tool))
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
self,
serialized: Dict[str, Any],
messages: List[List],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
self.cur_tool.update(
status=Status.start,
@ -136,8 +107,9 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
self.cur_tool.update(
status=Status.complete,
llm_token="\n",
llm_token="",
)
self.out = True
self.queue.put_nowait(dumps(self.cur_tool))
async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
@ -147,15 +119,44 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
)
self.queue.put_nowait(dumps(self.cur_tool))
async def on_agent_action(
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
self.cur_tool.update(
status=Status.agent_action,
tool_name=action.tool,
tool_input=action.tool_input,
)
self.queue.put_nowait(dumps(self.cur_tool))
async def on_agent_finish(
self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
# return the final answer
self.cur_tool.update(
status=Status.agent_finish,
final_answer=finish.return_values["output"],
agent_finish=finish.return_values["output"],
)
self.queue.put_nowait(dumps(self.cur_tool))
self.cur_tool = {}
async def aiter(self) -> AsyncIterator[str]:
while not self.queue.empty() or not self.done.is_set():
done, other = await asyncio.wait(
[
asyncio.ensure_future(self.queue.get()),
asyncio.ensure_future(self.done.wait()),
],
return_when=asyncio.FIRST_COMPLETED,
)
if other:
other.pop().cancel()
token_or_done = cast(Union[str, Literal[True]], done.pop().result())
if token_or_done is True:
break
yield token_or_done
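A minimal consumption sketch (illustrative, not taken from this commit): the agent run is wrapped with wrap_done so that aiter() can terminate, and every queued item is a JSON string keyed by status. The names consume and agent_coro are placeholders, and the handler is assumed to already be attached to the underlying models through their callbacks.

```python
import asyncio
import json

from server.callback_handler.agent_callback_handler import (
    CustomAsyncIteratorCallbackHandler, Status)
from server.utils import wrap_done


async def consume(agent_coro, callback: CustomAsyncIteratorCallbackHandler):
    # agent_coro is e.g. full_chain.ainvoke({"input": query}); wrap_done sets
    # callback.done when the coroutine finishes, which ends the aiter() loop.
    task = asyncio.create_task(wrap_done(agent_coro, callback.done))
    async for chunk in callback.aiter():  # each chunk was produced by dumps()
        event = json.loads(chunk)
        if event["status"] == Status.agent_action:
            print("tool call:", event["tool_name"], event["tool_input"])
        elif event["status"] == Status.agent_finish:
            print("final answer:", event["agent_finish"])
        elif event["status"] == Status.running:
            print(event["llm_token"], end="", flush=True)
    await task
```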

View File

@ -1,178 +0,0 @@
import json
import asyncio
from fastapi import Body
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents import LLMSingleActionAgent, AgentExecutor
from typing import AsyncIterable, Optional, List
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from server.knowledge_base.kb_service.base import get_kb_details
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from server.chat.utils import History
from server.agent import model_container
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user", "content": "请使用知识库工具查询今天北京天气"},
{"role": "assistant",
"content": "使用天气查询工具查询到今天北京多云10-14摄氏度东北风2级易感冒"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
history = [History.from_data(h) for h in history]
async def agent_chat_iterator(
query: str,
history: Optional[List[History]],
model_name: str = LLM_MODELS[0],
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
nonlocal max_tokens
callback = CustomAsyncIteratorCallbackHandler()
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=[callback],
)
kb_list = {x["kb_name"]: x for x in get_kb_details()}
model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}
if Agent_MODEL:
model_agent = get_ChatOpenAI(
model_name=Agent_MODEL,
temperature=temperature,
max_tokens=max_tokens,
callbacks=[callback],
)
model_container.MODEL = model_agent
else:
model_container.MODEL = model
prompt_template = get_prompt_template("agent_chat", prompt_name)
prompt_template_agent = CustomPromptTemplate(
template=prompt_template,
tools=tools,
input_variables=["input", "intermediate_steps", "history"]
)
output_parser = CustomOutputParser()
llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)
memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
for message in history:
if message.role == 'user':
memory.chat_memory.add_user_message(message.content)
else:
memory.chat_memory.add_ai_message(message.content)
if "chatglm3" in model_container.MODEL.model_name or "zhipu-api" in model_container.MODEL.model_name:
agent_executor = initialize_glm3_agent(
llm=model,
tools=tools,
callback_manager=None,
prompt=prompt_template,
input_variables=["input", "intermediate_steps", "history"],
memory=memory,
verbose=True,
)
else:
agent = LLMSingleActionAgent(
llm_chain=llm_chain,
output_parser=output_parser,
stop=["\nObservation:", "Observation"],
allowed_tools=tool_names,
)
agent_executor = AgentExecutor.from_agent_and_tools(agent=agent,
tools=tools,
verbose=True,
memory=memory,
)
while True:
try:
task = asyncio.create_task(wrap_done(
agent_executor.acall(query, callbacks=[callback], include_run_info=True),
callback.done))
break
except:
pass
if stream:
async for chunk in callback.aiter():
tools_use = []
# Use server-sent-events to stream the response
data = json.loads(chunk)
if data["status"] == Status.start or data["status"] == Status.complete:
continue
elif data["status"] == Status.error:
tools_use.append("\n```\n")
tools_use.append("工具名称: " + data["tool_name"])
tools_use.append("工具状态: " + "调用失败")
tools_use.append("错误信息: " + data["error"])
tools_use.append("重新开始尝试")
tools_use.append("\n```\n")
yield json.dumps({"tools": tools_use}, ensure_ascii=False)
elif data["status"] == Status.tool_finish:
tools_use.append("\n```\n")
tools_use.append("工具名称: " + data["tool_name"])
tools_use.append("工具状态: " + "调用成功")
tools_use.append("工具输入: " + data["input_str"])
tools_use.append("工具输出: " + data["output_str"])
tools_use.append("\n```\n")
yield json.dumps({"tools": tools_use}, ensure_ascii=False)
elif data["status"] == Status.agent_finish:
yield json.dumps({"final_answer": data["final_answer"]}, ensure_ascii=False)
else:
yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)
else:
answer = ""
final_answer = ""
async for chunk in callback.aiter():
data = json.loads(chunk)
if data["status"] == Status.start or data["status"] == Status.complete:
continue
if data["status"] == Status.error:
answer += "\n```\n"
answer += "工具名称: " + data["tool_name"] + "\n"
answer += "工具状态: " + "调用失败" + "\n"
answer += "错误信息: " + data["error"] + "\n"
answer += "\n```\n"
if data["status"] == Status.tool_finish:
answer += "\n```\n"
answer += "工具名称: " + data["tool_name"] + "\n"
answer += "工具状态: " + "调用成功" + "\n"
answer += "工具输入: " + data["input_str"] + "\n"
answer += "工具输出: " + data["output_str"] + "\n"
answer += "\n```\n"
if data["status"] == Status.agent_finish:
final_answer = data["final_answer"]
else:
answer += data["llm_token"]
yield json.dumps({"answer": answer, "final_answer": final_answer}, ensure_ascii=False)
await task
return EventSourceResponse(agent_chat_iterator(query=query,
history=history,
model_name=model_name,
prompt_name=prompt_name),
)

View File

@ -1,20 +1,43 @@
from fastapi import Body
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE
from fastapi.responses import StreamingResponse
from langchain.agents import initialize_agent, AgentType
from langchain_core.callbacks import AsyncCallbackManager, BaseCallbackManager
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableBranch
from server.agent.agent_factory import initialize_glm3_agent
from server.agent.tools_factory.tools_registry import all_tools
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
from typing import AsyncIterable, Dict
import asyncio
import json
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional, Union
from typing import List, Union
from server.chat.utils import History
from langchain.prompts import PromptTemplate
from server.utils import get_prompt_template
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
def create_models_from_config(configs: dict = {}, callbacks: list = []):
models = {}
prompts = {}
for model_type, model_configs in configs.items():
for model_name, params in model_configs.items():
callback = callbacks if params.get('callbacks', False) else None
model_instance = get_ChatOpenAI(
model_name=model_name,
temperature=params.get('temperature', 0.5),
max_tokens=params.get('max_tokens', 1000),
callbacks=callback
)
models[model_type] = model_instance
prompt_name = params.get('prompt_name', 'default')
prompt_template = get_prompt_template(type=model_type, name=prompt_name)
prompts[model_type] = prompt_template
return models, prompts
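For orientation, a hedged sketch of the dict shape this helper expects, read off the loop above ({model type: {model name: params}}). The model name and parameter values are placeholders, and the snippet assumes it runs in this same module so the imports above are available.

```python
example_config = {
    "llm_model": {
        "chatglm3-6b": {"temperature": 0.7, "max_tokens": 2048,
                        "prompt_name": "default", "callbacks": True},
    },
    "action_model": {
        "chatglm3-6b": {"temperature": 0.01, "max_tokens": 2048,
                        "prompt_name": "default", "callbacks": False},
    },
}
callback = CustomAsyncIteratorCallbackHandler()
models, prompts = create_models_from_config(configs=example_config, callbacks=[callback])
# models["llm_model"] / models["action_model"] are chat model instances;
# prompts[...] holds the raw template strings resolved via get_prompt_template.
```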
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
@ -28,76 +51,106 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
{"role": "assistant", "content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
model_config: Dict = Body({}, description="LLM 模型配置。"),
tool_config: Dict = Body({}, description="工具配置"),
):
async def chat_iterator() -> AsyncIterable[str]:
nonlocal history, max_tokens
callback = AsyncIteratorCallbackHandler()
callbacks = [callback]
nonlocal history
memory = None
message_id = None
chat_prompt = None
# persist the llm response to the message table in the DB
message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
chat_type="llm_chat",
query=query)
callbacks.append(conversation_callback)
callback = CustomAsyncIteratorCallbackHandler()
callbacks = [callback]
models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
if conversation_id:
message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
tools = [tool for tool in all_tools if tool.name in tool_config]
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=callbacks,
)
if history: # prefer the history passed in from the frontend
if history:
history = [History.from_data(h) for h in history]
prompt_template = get_prompt_template("llm_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
elif conversation_id and history_len > 0: # the frontend asked to pull history from the database
# when memory is used, the prompt must contain the variable matching memory.memory_key
prompt = get_prompt_template("llm_chat", "with_history")
chat_prompt = PromptTemplate.from_template(prompt)
# fetch the message list by conversation_id and assemble the memory from it
memory = ConversationBufferDBMemory(conversation_id=conversation_id,
llm=model,
elif conversation_id and history_len > 0:
memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=models["llm_model"],
message_limit=history_len)
else:
prompt_template = get_prompt_template("llm_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages([input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
chain.acall({"input": query}),
callback.done),
chain = LLMChain(prompt=chat_prompt, llm=models["llm_model"], memory=memory)
classifier_chain = (
PromptTemplate.from_template(prompts["preprocess_model"])
| models["preprocess_model"]
| StrOutputParser()
)
if stream:
async for token in callback.aiter():
# Use server-sent-events to stream the response
yield json.dumps(
{"text": token, "message_id": message_id},
ensure_ascii=False)
if "chatglm3" in models["action_model"].model_name.lower():
agent_executor = initialize_glm3_agent(
llm=models["action_model"],
tools=tools,
prompt=prompts["action_model"],
input_variables=["input", "intermediate_steps", "history"],
memory=memory,
callback_manager=BaseCallbackManager(handlers=callbacks),
verbose=True,
)
else:
answer = ""
async for token in callback.aiter():
answer += token
yield json.dumps(
{"text": answer, "message_id": message_id},
ensure_ascii=False)
agent_executor = initialize_agent(
llm=models["action_model"],
tools=tools,
callbacks=callbacks,
agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
memory=memory,
verbose=True,
)
branch = RunnableBranch(
(lambda x: "1" in x["topic"].lower(), agent_executor),
chain
)
full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
if stream:
async for chunk in callback.aiter():
data = json.loads(chunk)
if data["status"] == Status.start:
continue
elif data["status"] == Status.agent_action:
tool_info = {
"tool_name": data["tool_name"],
"tool_input": data["tool_input"]
}
yield json.dumps({"agent_action": tool_info, "message_id": message_id}, ensure_ascii=False)
elif data["status"] == Status.agent_finish:
yield json.dumps({"agent_finish": data["agent_finish"], "message_id": message_id},
ensure_ascii=False)
else:
yield json.dumps({"text": data["llm_token"], "message_id": message_id}, ensure_ascii=False)
else:
text = ""
agent_finish = ""
tool_info = None
async for chunk in callback.aiter():
# Use server-sent-events to stream the response
data = json.loads(chunk)
if data["status"] == Status.agent_action:
tool_info = {
"tool_name": data["tool_name"],
"tool_input": data["tool_input"]
}
if data["status"] == Status.agent_finish:
agent_finish = data["agent_finish"]
else:
text += data["llm_token"]
if tool_info:
yield json.dumps(
{"text": text, "agent_action": tool_info, "agent_finish": agent_finish, "message_id": message_id},
ensure_ascii=False)
else:
yield json.dumps(
{"text": text, "message_id": message_id},
ensure_ascii=False)
await task
return EventSourceResponse(chat_iterator())
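For reference, a hedged client-side sketch of the reworked endpoint: /chat/chat now takes model_config and tool_config dicts instead of the old flat parameters and streams JSON events over SSE. The URL, port and model name are placeholders for a local deployment; chat() reads the llm_model, preprocess_model and action_model entries, so all three are supplied.

```python
import json
import httpx

model_entry = {"temperature": 0.7, "max_tokens": 2048, "history_len": 3,
               "prompt_name": "default", "callbacks": True}
payload = {
    "query": "北京今天天气怎么样?",
    "stream": True,
    "model_config": {
        "llm_model": {"chatglm3-6b": dict(model_entry)},
        "preprocess_model": {"chatglm3-6b": dict(model_entry, callbacks=False)},
        "action_model": {"chatglm3-6b": dict(model_entry, temperature=0.01)},
    },
    "tool_config": {},  # e.g. the entries the WebUI selects from TOOL_CONFIG
}
with httpx.stream("POST", "http://127.0.0.1:7861/chat/chat", json=payload, timeout=300) as resp:
    for line in resp.iter_lines():
        if not line.startswith("data: "):  # EventSourceResponse prefixes each event
            continue
        event = json.loads(line[len("data: "):])
        if "agent_action" in event:
            print("\ntool:", event["agent_action"])
        elif "agent_finish" in event:
            print("\nfinal:", event["agent_finish"])
        elif "text" in event:
            print(event["text"], end="", flush=True)
```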

View File

@ -1,6 +1,6 @@
from fastapi import Body
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE
from fastapi.responses import StreamingResponse
from configs import LLM_MODEL_CONFIG
from server.utils import wrap_done, get_OpenAI
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
@ -14,8 +14,8 @@ from server.utils import get_prompt_template
async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
stream: bool = Body(False, description="流式输出"),
echo: bool = Body(False, description="除了输出之外,还回显输入"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
model_name: str = Body(None, description="LLM 模型名称。"),
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量默认None代表模型最大值"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
prompt_name: str = Body("default",
@ -24,7 +24,7 @@ async def completion(query: str = Body(..., description="用户输入", examples
# TODO: ApiModelWorker treats requests as chat by default and parses params["prompt"] into messages, so corresponding handling is needed when an ApiModelWorker is used here
async def completion_iterator(query: str,
model_name: str = LLM_MODELS[0],
model_name: str = None,
prompt_name: str = prompt_name,
echo: bool = echo,
) -> AsyncIterable[str]:

View File

@ -1,7 +1,6 @@
from fastapi import Body, File, Form, UploadFile
from sse_starlette.sse import EventSourceResponse
from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
from fastapi.responses import StreamingResponse
from configs import (VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
from server.utils import (wrap_done, get_ChatOpenAI,
BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool)
from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
@ -15,8 +14,6 @@ from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
from server.knowledge_base.utils import KnowledgeFile
import json
import os
from pathlib import Path
def _parse_files_in_thread(
files: List[UploadFile],
@ -102,8 +99,8 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
model_name: str = Body(None, description="LLM 模型名称。"),
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):

View File

@ -1,147 +0,0 @@
from fastapi import Body, Request
from sse_starlette.sse import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS,
VECTOR_SEARCH_TOP_K,
SCORE_THRESHOLD,
TEMPERATURE,
USE_RERANKER,
RERANKER_MODEL,
RERANKER_MAX_LENGTH,
MODEL_PATH)
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
import json
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
from server.reranker.reranker import LangchainReranker
from server.utils import embedding_device
async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(
SCORE_THRESHOLD,
description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右",
ge=0,
le=2
),
history: List[History] = Body(
[],
description="历史对话",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(
None,
description="限制LLM生成Token数量默认None代表模型最大值"
),
prompt_name: str = Body(
"default",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
),
request: Request = None,
):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
history = [History.from_data(h) for h in history]
async def knowledge_base_chat_iterator(
query: str,
top_k: int,
history: Optional[List[History]],
model_name: str = model_name,
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
nonlocal max_tokens
callback = AsyncIteratorCallbackHandler()
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=[callback],
)
docs = await run_in_threadpool(search_docs,
query=query,
knowledge_base_name=knowledge_base_name,
top_k=top_k,
score_threshold=score_threshold)
# apply the reranker
if USE_RERANKER:
reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
print("-----------------model path------------------")
print(reranker_model_path)
reranker_model = LangchainReranker(top_n=top_k,
device=embedding_device(),
max_length=RERANKER_MAX_LENGTH,
model_name_or_path=reranker_model_path
)
print(docs)
docs = reranker_model.compress_documents(documents=docs,
query=query)
print("---------after rerank------------------")
print(docs)
context = "\n".join([doc.page_content for doc in docs])
if len(docs) == 0: # fall back to the "empty" prompt template when no relevant documents are found
prompt_template = get_prompt_template("knowledge_base_chat", "empty")
else:
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
chain.acall({"context": context, "question": query}),
callback.done),
)
source_documents = []
for inum, doc in enumerate(docs):
filename = doc.metadata.get("source")
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
base_url = request.base_url
url = f"{base_url}knowledge_base/download_doc?" + parameters
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
source_documents.append(text)
if len(source_documents) == 0: # no relevant documents were found
source_documents.append(f"<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>")
if stream:
async for token in callback.aiter():
# Use server-sent-events to stream the response
yield json.dumps({"answer": token}, ensure_ascii=False)
yield json.dumps({"docs": source_documents}, ensure_ascii=False)
else:
answer = ""
async for token in callback.aiter():
answer += token
yield json.dumps({"answer": answer,
"docs": source_documents},
ensure_ascii=False)
await task
return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name))
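With this dedicated endpoint removed, equivalent retrieval is expected to reach the model as an agent tool. A hypothetical sketch (not part of this commit; the tool name, knowledge base name and thresholds are placeholders) wrapping the retained search_docs API as a LangChain tool:

```python
from langchain.tools import Tool
from server.knowledge_base.kb_doc_api import search_docs


def _kb_search(query: str) -> str:
    # "samples", top_k and score_threshold are placeholder values
    docs = search_docs(query=query, knowledge_base_name="samples",
                       top_k=3, score_threshold=1.0)
    return "\n\n".join(doc.page_content for doc in docs)


kb_tool = Tool.from_function(
    func=_kb_search,
    name="search_local_knowledgebase",  # placeholder tool name
    description="Look up passages in the local knowledge base relevant to the query.",
)
```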

View File

@ -1,208 +0,0 @@
from langchain.utilities.bing_search import BingSearchAPIWrapper
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE, OVERLAP_SIZE)
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.prompts.chat import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from fastapi import Body
from fastapi.concurrency import run_in_threadpool
from sse_starlette import EventSourceResponse
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from server.chat.utils import History
from typing import AsyncIterable
import asyncio
import json
from typing import List, Optional, Dict
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from markdownify import markdownify
def bing_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
"title": "env info is not found",
"link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
search = BingSearchAPIWrapper(bing_subscription_key=BING_SUBSCRIPTION_KEY,
bing_search_url=BING_SEARCH_URL)
return search.results(text, result_len)
def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
search = DuckDuckGoSearchAPIWrapper()
return search.results(text, result_len)
def metaphor_search(
text: str,
result_len: int = SEARCH_ENGINE_TOP_K,
split_result: bool = False,
chunk_size: int = 500,
chunk_overlap: int = OVERLAP_SIZE,
) -> List[Dict]:
from metaphor_python import Metaphor
if not METAPHOR_API_KEY:
return []
client = Metaphor(METAPHOR_API_KEY)
search = client.search(text, num_results=result_len, use_autoprompt=True)
contents = search.get_contents().contents
for x in contents:
x.extract = markdownify(x.extract)
# metaphor returns long passages, so split them before retrieval
if split_result:
docs = [Document(page_content=x.extract,
metadata={"link": x.url, "title": x.title})
for x in contents]
text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
chunk_size=chunk_size,
chunk_overlap=chunk_overlap)
splitted_docs = text_splitter.split_documents(docs)
# re-rank the split chunks against the query and keep only the top result_len documents
if len(splitted_docs) > result_len:
normal = NormalizedLevenshtein()
for x in splitted_docs:
x.metadata["score"] = normal.similarity(text, x.page_content)
splitted_docs.sort(key=lambda x: x.metadata["score"], reverse=True)
splitted_docs = splitted_docs[:result_len]
docs = [{"snippet": x.page_content,
"link": x.metadata["link"],
"title": x.metadata["title"]}
for x in splitted_docs]
else:
docs = [{"snippet": x.extract,
"link": x.url,
"title": x.title}
for x in contents]
return docs
SEARCH_ENGINES = {"bing": bing_search,
"duckduckgo": duckduckgo_search,
"metaphor": metaphor_search,
}
def search_result2docs(search_results):
docs = []
for result in search_results:
doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "",
metadata={"source": result["link"] if "link" in result.keys() else "",
"filename": result["title"] if "title" in result.keys() else ""})
docs.append(doc)
return docs
async def lookup_search_engine(
query: str,
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
split_result: bool = False,
):
search_engine = SEARCH_ENGINES[search_engine_name]
results = await run_in_threadpool(search_engine, query, result_len=top_k, split_result=split_result)
docs = search_result2docs(results)
return docs
async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None,
description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
split_result: bool = Body(False,
description="是否对搜索结果进行拆分主要用于metaphor搜索引擎")
):
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
if search_engine_name == "bing" and not BING_SUBSCRIPTION_KEY:
return BaseResponse(code=404, msg=f"要使用Bing搜索引擎需要设置 `BING_SUBSCRIPTION_KEY`")
history = [History.from_data(h) for h in history]
async def search_engine_chat_iterator(query: str,
search_engine_name: str,
top_k: int,
history: Optional[List[History]],
model_name: str = LLM_MODELS[0],
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
nonlocal max_tokens
callback = AsyncIteratorCallbackHandler()
if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=[callback],
)
docs = await lookup_search_engine(query, search_engine_name, top_k, split_result=split_result)
context = "\n".join([doc.page_content for doc in docs])
prompt_template = get_prompt_template("search_engine_chat", prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
chain.acall({"context": context, "question": query}),
callback.done),
)
source_documents = [
f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
for inum, doc in enumerate(docs)
]
if len(source_documents) == 0: # no relevant results were found (unlikely)
source_documents.append(f"""<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>""")
if stream:
async for token in callback.aiter():
# Use server-sent-events to stream the response
yield json.dumps({"answer": token}, ensure_ascii=False)
yield json.dumps({"docs": source_documents}, ensure_ascii=False)
else:
answer = ""
async for token in callback.aiter():
answer += token
yield json.dumps({"answer": answer,
"docs": source_documents},
ensure_ascii=False)
await task
return EventSourceResponse(search_engine_chat_iterator(query=query,
search_engine_name=search_engine_name,
top_k=top_k,
history=history,
model_name=model_name,
prompt_name=prompt_name),
)

View File

@ -9,7 +9,6 @@ class ConversationModel(Base):
__tablename__ = 'conversation'
id = Column(String(32), primary_key=True, comment='对话框ID')
name = Column(String(50), comment='对话框名称')
# e.g. chat / agent_chat
chat_type = Column(String(50), comment='聊天类型')
create_time = Column(DateTime, default=func.now(), comment='创建时间')

View File

@ -10,7 +10,6 @@ class MessageModel(Base):
__tablename__ = 'message'
id = Column(String(32), primary_key=True, comment='聊天记录ID')
conversation_id = Column(String(32), default=None, index=True, comment='对话框ID')
# e.g. chat / agent_chat
chat_type = Column(String(50), comment='聊天类型')
query = Column(String(4096), comment='用户问题')
response = Column(String(4096), comment='模型回答')

View File

@ -10,7 +10,6 @@ from typing import List, Optional
from server.knowledge_base.kb_summary.base import KBSummaryService
from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
from server.utils import wrap_done, get_ChatOpenAI, BaseResponse
from configs import LLM_MODELS, TEMPERATURE
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
def recreate_summary_vector_store(
@ -19,8 +18,8 @@ def recreate_summary_vector_store(
vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL),
file_description: str = Body(''),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
model_name: str = Body(None, description="LLM 模型名称。"),
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
):
"""
@ -100,8 +99,8 @@ def summary_file_to_vector_store(
vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL),
file_description: str = Body(''),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
model_name: str = Body(None, description="LLM 模型名称。"),
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
):
"""
@ -172,8 +171,8 @@ def summary_doc_ids_to_vector_store(
vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL),
file_description: str = Body(''),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
model_name: str = Body(None, description="LLM 模型名称。"),
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
) -> BaseResponse:
"""

View File

@ -7,7 +7,6 @@ from configs import (
logger,
log_verbose,
text_splitter_dict,
LLM_MODELS,
TEXT_SPLITTER_NAME,
)
import importlib
@ -187,10 +186,10 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
def make_text_splitter(
splitter_name: str = TEXT_SPLITTER_NAME,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
llm_model: str = LLM_MODELS[0],
splitter_name,
chunk_size,
chunk_overlap,
llm_model,
):
"""
Get the specified text splitter according to the parameters

View File

@ -1,5 +1,5 @@
from fastapi import Body
from configs import logger, log_verbose, LLM_MODELS, HTTPX_DEFAULT_TIMEOUT
from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
get_httpx_client, get_model_worker_config)
from typing import List
@ -62,7 +62,7 @@ def get_model_config(
def stop_llm_model(
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODELS[0]]),
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
@ -86,8 +86,8 @@ def stop_llm_model(
def change_llm_model(
model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODELS[0]]),
new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODELS[0]]),
model_name: str = Body(..., description="当前运行模型"),
new_model_name: str = Body(..., description="要切换的新模型"),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
):
'''
@ -108,9 +108,3 @@ def change_llm_model(
return BaseResponse(
code=500,
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
def list_search_engines() -> BaseResponse:
from server.chat.search_engine_chat import SEARCH_ENGINES
return BaseResponse(data=list(SEARCH_ENGINES))

View File

@ -1,5 +1,5 @@
from fastchat.conversation import Conversation
from configs import LOG_PATH, TEMPERATURE
from configs import LOG_PATH
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.base_model_worker import BaseModelWorker
@ -63,7 +63,7 @@ class ApiModelParams(ApiConfigParams):
deployment_name: Optional[str] = None # for azure
resource_name: Optional[str] = None # for azure
temperature: float = TEMPERATURE
temperature: float = 0.9
max_tokens: Optional[int] = None
top_p: Optional[float] = 1.0

View File

@ -1,11 +1,6 @@
import json
import sys
from fastchat.conversation import Conversation
from configs import TEMPERATURE
from http import HTTPStatus
from typing import List, Literal, Dict
from fastchat import conversation as conv
from server.model_workers.base import *
from server.model_workers.base import ApiEmbeddingsParams

View File

@ -4,7 +4,7 @@ from typing import List
from fastapi import FastAPI
from pathlib import Path
import asyncio
from configs import (LLM_MODELS, LLM_DEVICE, EMBEDDING_DEVICE,
from configs import (LLM_MODEL_CONFIG, LLM_DEVICE, EMBEDDING_DEVICE,
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose,
FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
import os
@ -402,8 +402,8 @@ def fschat_controller_address() -> str:
return f"http://{host}:{port}"
def fschat_model_worker_address(model_name: str = LLM_MODELS[0]) -> str:
if model := get_model_worker_config(model_name):
def fschat_model_worker_address(model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model']))) -> str:
if model := get_model_worker_config(model_name): # TODO: depends fastchat
host = model["host"]
if host == "0.0.0.0":
host = "127.0.0.1"
@ -443,7 +443,7 @@ def webui_address() -> str:
def get_prompt_template(type: str, name: str) -> Optional[str]:
'''
Load the prompt template content from prompt_config.
type: one of "llm_chat", "agent_chat", "knowledge_base_chat", "search_engine_chat"; add new entries here when new features are added
type: one of "llm_chat", "knowledge_base_chat", "search_engine_chat"; add new entries here when new features are added
'''
from configs import prompt_config
@ -617,26 +617,6 @@ def get_server_configs() -> Dict:
'''
Return the raw config items from configs for the frontend to use
'''
from configs.kb_config import (
DEFAULT_KNOWLEDGE_BASE,
DEFAULT_SEARCH_ENGINE,
DEFAULT_VS_TYPE,
CHUNK_SIZE,
OVERLAP_SIZE,
SCORE_THRESHOLD,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
ZH_TITLE_ENHANCE,
text_splitter_dict,
TEXT_SPLITTER_NAME,
)
from configs.model_config import (
LLM_MODELS,
HISTORY_LEN,
TEMPERATURE,
)
from configs.prompt_config import PROMPT_TEMPLATES
_custom = {
"controller_address": fschat_controller_address(),
"openai_api_address": fschat_openai_api_address(),

View File

@ -8,6 +8,7 @@ from datetime import datetime
from pprint import pprint
from langchain_core._api import deprecated
# set the maximum number of numexpr threads (defaults to the number of CPU cores)
try:
import numexpr
@ -21,7 +22,7 @@ from configs import (
LOG_PATH,
log_verbose,
logger,
LLM_MODELS,
LLM_MODEL_CONFIG,
EMBEDDING_MODEL,
TEXT_SPLITTER_NAME,
FSCHAT_CONTROLLER,
@ -32,13 +33,21 @@ from configs import (
HTTPX_DEFAULT_TIMEOUT,
)
from server.utils import (fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, get_httpx_client, get_model_worker_config,
fschat_openai_api_address, get_httpx_client,
get_model_worker_config,
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
from server.knowledge_base.migrate import create_tables
import argparse
from typing import List, Dict
from configs import VERSION
all_model_names = set()
for model_category in LLM_MODEL_CONFIG.values():
for model_name in model_category.keys():
if model_name not in all_model_names:
all_model_names.add(model_name)
all_model_names_list = list(all_model_names)
@deprecated(
since="0.3.0",
@ -109,9 +118,9 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
import fastchat.serve.vllm_worker
from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
args.tokenizer = args.model_path
args.tokenizer = args.model_path # set this here if the tokenizer differs from model_path
args.tokenizer_mode = 'auto'
args.trust_remote_code = True
args.download_dir = None
@ -130,7 +139,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
args.conv_template = None
args.limit_worker_concurrency = 5
args.no_register = False
args.num_gpus = 1 # vllm workers shard via tensor parallelism; set this to the number of GPUs
args.num_gpus = 4 # vllm workers shard via tensor parallelism; set this to the number of GPUs
args.engine_use_ray = False
args.disable_log_requests = False
@ -365,7 +374,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
def run_model_worker(
model_name: str = LLM_MODELS[0],
model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model'])),
controller_address: str = "",
log_level: str = "INFO",
q: mp.Queue = None,
@ -502,7 +511,7 @@ def parse_args() -> argparse.ArgumentParser:
"--model-worker",
action="store_true",
help="run fastchat's model_worker server with specified model name. "
"specify --model-name if not using default LLM_MODELS",
"specify --model-name if not using default llm models",
dest="model_worker",
)
parser.add_argument(
@ -510,7 +519,7 @@ def parse_args() -> argparse.ArgumentParser:
"--model-name",
type=str,
nargs="+",
default=LLM_MODELS,
default=all_model_names_list,
help="specify model name for model worker. "
"add addition names with space seperated to start multiple model workers.",
dest="model_name",
@ -574,7 +583,7 @@ def dump_server_info(after_start=False, args=None):
print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}")
print("\n")
models = LLM_MODELS
models = list(LLM_MODEL_CONFIG['llm_model'].keys())
if args and args.model_name:
models = args.model_name
@ -769,17 +778,17 @@ async def start_main_server():
if p := processes.get("api"):
p.start()
p.name = f"{p.name} ({p.pid})"
api_started.wait()
api_started.wait() # wait for api.py to finish starting
if p := processes.get("webui"):
p.start()
p.name = f"{p.name} ({p.pid})"
webui_started.wait()
webui_started.wait() # wait for webui.py to finish starting
dump_server_info(after_start=True, args=args)
while True:
cmd = queue.get()
cmd = queue.get() # receive model-switch messages
e = manager.Event()
if isinstance(cmd, list):
model_name, cmd, new_model_name = cmd
@ -877,20 +886,5 @@ if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(start_main_server())
# Example API call after the services are started:
# import openai
# openai.api_key = "EMPTY" # Not support yet
# openai.api_base = "http://localhost:8888/v1"
# model = "chatglm3-6b"
# # create a chat completion
# completion = openai.ChatCompletion.create(
# model=model,
# messages=[{"role": "user", "content": "Hello! What is your name?"}]
# )
# # print the completion
# print(completion.choices[0].message.content)

View File

@ -29,15 +29,7 @@ def test_server_configs():
assert len(configs) > 0
def test_list_search_engines():
engines = api.list_search_engines()
pprint(engines)
assert isinstance(engines, list)
assert len(engines) > 0
@pytest.mark.parametrize("type", ["llm_chat", "agent_chat"])
@pytest.mark.parametrize("type", ["llm_chat"])
def test_get_prompt_template(type):
print(f"prompt template for: {type}")
template = api.get_prompt_template(type=type)

View File

@ -85,29 +85,3 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"):
assert "docs" in data and len(data["docs"]) > 0
assert response.status_code == 200
def test_search_engine_chat(api="/chat/search_engine_chat"):
global data
data["query"] = "室温超导最新进展是什么样?"
url = f"{api_base_url}{api}"
for se in ["bing", "duckduckgo"]:
data["search_engine_name"] = se
dump_input(data, api + f" by {se}")
response = requests.post(url, json=data, stream=True)
if se == "bing" and not BING_SUBSCRIPTION_KEY:
data = response.json()
assert data["code"] == 404
assert data["msg"] == f"要使用Bing搜索引擎需要设置 `BING_SUBSCRIPTION_KEY`"
print("\n")
print("=" * 30 + api + f" by {se} output" + "="*30)
for line in response.iter_content(None, decode_unicode=True):
data = json.loads(line[6:])
if "answer" in data:
print(data["answer"], end="", flush=True)
assert "docs" in data and len(data["docs"]) > 0
pprint(data["docs"])
assert response.status_code == 200

View File

@ -56,15 +56,4 @@ def test_embeddings(worker):
assert isinstance(embeddings, list) and len(embeddings) > 0
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
assert isinstance(embeddings[0][0], float)
print("向量长度:", len(embeddings[0]))
# @pytest.mark.parametrize("worker", workers)
# def test_completion(worker):
# params = ApiCompletionParams(prompt="五十六个民族")
# print(f"\completion with {worker} \n")
# worker_class = get_model_worker_config(worker)["worker_class"]
# resp = worker_class().do_completion(params)
# pprint(resp)
print("向量长度:", len(embeddings[0]))

View File

@ -6,8 +6,7 @@ from datetime import datetime
import os
import re
import time
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES, LLM_MODELS,
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
from configs import (TOOL_CONFIG, LLM_MODEL_CONFIG)
from server.knowledge_base.utils import LOADER_DICT
import uuid
from typing import List, Dict
@ -55,6 +54,7 @@ def parse_command(text: str, modal: Modal) -> bool:
/new {session_name}: create a new session; defaults to "会话X" if no name is given
/del {session_name}: delete the named session (the current one if no name is given, only when more than one session exists)
/clear {session_name}: clear the named session; defaults to the current session if no name is given
/stop {session_name}: stop the named session; defaults to the current session if no name is given
/help: show command help
Returns: True if the input was a command, otherwise False
'''
@ -117,36 +117,38 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
st.write("\n\n".join(cmds))
with st.sidebar:
# multiple sessions
conv_names = list(st.session_state["conversation_ids"].keys())
index = 0
tools = list(TOOL_CONFIG.keys())
selected_tool_configs = {}
with st.expander("工具栏"):
for tool in tools:
is_selected = st.checkbox(tool, value=TOOL_CONFIG[tool]["use"], key=tool)
if is_selected:
selected_tool_configs[tool] = TOOL_CONFIG[tool]
if st.session_state.get("cur_conv_name") in conv_names:
index = conv_names.index(st.session_state.get("cur_conv_name"))
conversation_name = st.selectbox("当前会话:", conv_names, index=index)
conversation_name = st.selectbox("当前会话", conv_names, index=index)
chat_box.use_chat_name(conversation_name)
conversation_id = st.session_state["conversation_ids"][conversation_name]
def on_mode_change():
mode = st.session_state.dialogue_mode
text = f"已切换到 {mode} 模式。"
if mode == "知识库问答":
cur_kb = st.session_state.get("selected_kb")
if cur_kb:
text = f"{text} 当前知识库: `{cur_kb}`。"
st.toast(text)
# def on_mode_change():
# mode = st.session_state.dialogue_mode
# text = f"已切换到 {mode} 模式。"
# st.toast(text)
dialogue_modes = ["LLM 对话",
"知识库问答",
"文件对话",
"搜索引擎问答",
"自定义Agent问答",
]
dialogue_mode = st.selectbox("请选择对话模式:",
dialogue_modes,
index=0,
on_change=on_mode_change,
key="dialogue_mode",
)
# dialogue_modes = ["智能对话",
# "文件对话",
# ]
# dialogue_mode = st.selectbox("请选择对话模式:",
# dialogue_modes,
# index=0,
# on_change=on_mode_change,
# key="dialogue_mode",
# )
def on_llm_change():
if llm_model:
@ -164,7 +166,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
available_models = []
config_models = api.list_config_models()
if not is_lite:
for k, v in config_models.get("local", {}).items():
for k, v in config_models.get("local", {}).items(): # list models configured with a valid local path
if (v.get("model_path_exists")
and k not in running_models):
available_models.append(k)
@ -177,103 +179,47 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
index = llm_models.index(cur_llm_model)
else:
index = 0
llm_model = st.selectbox("选择LLM模型",
llm_model = st.selectbox("选择LLM模型",
llm_models,
index,
format_func=llm_model_format_func,
on_change=on_llm_change,
key="llm_model",
)
if (st.session_state.get("prev_llm_model") != llm_model
and not is_lite
and not llm_model in config_models.get("online", {})
and not llm_model in config_models.get("langchain", {})
and llm_model not in running_models):
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
prev_model = st.session_state.get("prev_llm_model")
r = api.change_llm_model(prev_model, llm_model)
if msg := check_error_msg(r):
st.error(msg)
elif msg := check_success_msg(r):
st.success(msg)
st.session_state["prev_llm_model"] = llm_model
index_prompt = {
"LLM 对话": "llm_chat",
"自定义Agent问答": "agent_chat",
"搜索引擎问答": "search_engine_chat",
"知识库问答": "knowledge_base_chat",
"文件对话": "knowledge_base_chat",
}
prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys())
prompt_template_name = prompt_templates_kb_list[0]
if "prompt_template_select" not in st.session_state:
st.session_state.prompt_template_select = prompt_templates_kb_list[0]
def prompt_change():
text = f"已切换为 {prompt_template_name} 模板。"
st.toast(text)
# payload to pass to the backend
model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
prompt_template_select = st.selectbox(
"请选择Prompt模板",
prompt_templates_kb_list,
index=0,
on_change=prompt_change,
key="prompt_template_select",
)
prompt_template_name = st.session_state.prompt_template_select
temperature = st.slider("Temperature", 0.0, 2.0, TEMPERATURE, 0.05)
history_len = st.number_input("历史对话轮数:", 0, 20, HISTORY_LEN)
for key in LLM_MODEL_CONFIG:
if key == 'llm_model':
continue
if LLM_MODEL_CONFIG[key]:
first_key = next(iter(LLM_MODEL_CONFIG[key]))
model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
def on_kb_change():
st.toast(f"已加载知识库: {st.session_state.selected_kb}")
if llm_model is not None:
model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]
if dialogue_mode == "知识库问答":
with st.expander("知识库配置", True):
kb_list = api.list_knowledge_bases()
index = 0
if DEFAULT_KNOWLEDGE_BASE in kb_list:
index = kb_list.index(DEFAULT_KNOWLEDGE_BASE)
selected_kb = st.selectbox(
"请选择知识库:",
kb_list,
index=index,
on_change=on_kb_change,
key="selected_kb",
)
kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
## BGE scores can exceed 1
score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
elif dialogue_mode == "文件对话":
with st.expander("文件对话配置", True):
files = st.file_uploader("上传知识文件:",
[i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True,
)
kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
## BGE scores can exceed 1
score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
if st.button("开始上传", disabled=len(files) == 0):
st.session_state["file_chat_id"] = upload_temp_docs(files, api)
elif dialogue_mode == "搜索引擎问答":
search_engine_list = api.list_search_engines()
if DEFAULT_SEARCH_ENGINE in search_engine_list:
index = search_engine_list.index(DEFAULT_SEARCH_ENGINE)
else:
index = search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0
with st.expander("搜索引擎配置", True):
search_engine = st.selectbox(
label="请选择搜索引擎",
options=search_engine_list,
index=index,
)
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)
print(model_config)
files = st.file_uploader("上传附件",
type=[i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True)
# if dialogue_mode == "文件对话":
# with st.expander("文件对话配置", True):
# files = st.file_uploader("上传知识文件:",
# [i for ls in LOADER_DICT.values() for i in ls],
# accept_multiple_files=True,
# )
# kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
# score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
# if st.button("开始上传", disabled=len(files) == 0):
# st.session_state["file_chat_id"] = upload_temp_docs(files, api)
# Display chat messages from history on app rerun
chat_box.output_messages()
chat_box.output_messages()
chat_input_placeholder = "请输入对话内容换行请使用Shift+Enter。输入/help查看自定义命令 "
def on_feedback(
@ -297,140 +243,76 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
if parse_command(text=prompt, modal=modal): # the user entered a custom command
st.rerun()
else:
history = get_messages_history(history_len)
history = get_messages_history(
model_config["llm_model"][next(iter(model_config["llm_model"]))]["history_len"])
chat_box.user_say(prompt)
if dialogue_mode == "LLM 对话":
chat_box.ai_say("正在思考...")
text = ""
message_id = ""
r = api.chat_chat(prompt,
history=history,
conversation_id=conversation_id,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature)
for t in r:
if error_msg := check_error_msg(t): # check whether error occured
st.error(error_msg)
break
text += t.get("text", "")
chat_box.update_msg(text)
message_id = t.get("message_id", "")
chat_box.ai_say("正在思考...")
text = ""
message_id = ""
element_index = 0
for d in api.chat_chat(query=prompt,
history=history,
model_config=model_config,
conversation_id=conversation_id,
tool_config=selected_tool_configs,
):
try:
d = json.loads(d)
except:
pass
message_id = d.get("message_id", "")
metadata = {
"message_id": message_id,
}
chat_box.update_msg(text, streaming=False, metadata=metadata) # update with the final string and remove the cursor
chat_box.show_feedback(**feedback_kwargs,
key=message_id,
on_submit=on_feedback,
kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
elif dialogue_mode == "自定义Agent问答":
if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
chat_box.ai_say([
f"正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐请更换支持Agent的模型获得更好的体验</span>\n\n\n",
Markdown("...", in_expander=True, title="思考过程", state="complete"),
])
else:
chat_box.ai_say([
f"正在思考...",
Markdown("...", in_expander=True, title="思考过程", state="complete"),
])
text = ""
ans = ""
for d in api.agent_chat(prompt,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature,
):
try:
d = json.loads(d)
except:
pass
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
if chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, element_index=1)
if chunk := d.get("final_answer"):
ans += chunk
chat_box.update_msg(ans, element_index=0)
if chunk := d.get("tools"):
text += "\n\n".join(d.get("tools", []))
chat_box.update_msg(text, element_index=1)
chat_box.update_msg(ans, element_index=0, streaming=False)
chat_box.update_msg(text, element_index=1, streaming=False)
elif dialogue_mode == "知识库问答":
chat_box.ai_say([
f"正在查询知识库 `{selected_kb}` ...",
Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
])
text = ""
for d in api.knowledge_base_chat(prompt,
knowledge_base_name=selected_kb,
top_k=kb_top_k,
score_threshold=score_threshold,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
elif chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, element_index=0)
chat_box.update_msg(text, element_index=0, streaming=False)
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
elif dialogue_mode == "文件对话":
if st.session_state["file_chat_id"] is None:
st.error("请先上传文件再进行对话")
st.stop()
chat_box.ai_say([
f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
])
text = ""
for d in api.file_chat(prompt,
knowledge_id=st.session_state["file_chat_id"],
top_k=kb_top_k,
score_threshold=score_threshold,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
elif chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, element_index=0)
chat_box.update_msg(text, element_index=0, streaming=False)
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
elif dialogue_mode == "搜索引擎问答":
chat_box.ai_say([
f"正在执行 `{search_engine}` 搜索...",
Markdown("...", in_expander=True, title="网络搜索结果", state="complete"),
])
text = ""
for d in api.search_engine_chat(prompt,
search_engine_name=search_engine,
top_k=se_top_k,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature,
split_result=se_top_k > 1):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
elif chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, element_index=0)
chat_box.update_msg(text, element_index=0, streaming=False)
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
if error_msg := check_error_msg(d):
st.error(error_msg)
if chunk := d.get("agent_action"):
chat_box.insert_msg(Markdown("...", in_expander=True, title="Tools", state="complete"))
element_index = 1
formatted_data = {
"action": chunk["tool_name"],
"action_input": chunk["tool_input"]
}
formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
text += f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"
chat_box.update_msg(text, element_index=element_index, metadata=metadata)
if chunk := d.get("text"):
text += chunk
chat_box.update_msg(text, element_index=element_index, metadata=metadata)
if chunk := d.get("agent_finish"):
element_index = 0
text = chunk
chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
chat_box.show_feedback(**feedback_kwargs,
key=message_id,
on_submit=on_feedback,
kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
# elif dialogue_mode == "文件对话":
# if st.session_state["file_chat_id"] is None:
# st.error("请先上传文件再进行对话")
# st.stop()
# chat_box.ai_say([
# f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
# Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
# ])
# text = ""
# for d in api.file_chat(prompt,
# knowledge_id=st.session_state["file_chat_id"],
# top_k=kb_top_k,
# score_threshold=score_threshold,
# history=history,
# model=llm_model,
# prompt_name=prompt_template_name,
# temperature=temperature):
# if error_msg := check_error_msg(d):
# st.error(error_msg)
# elif chunk := d.get("answer"):
# text += chunk
# chat_box.update_msg(text, element_index=0)
# chat_box.update_msg(text, element_index=0, streaming=False)
# chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
if st.session_state.get("need_rerun"):
st.session_state["need_rerun"] = False
st.rerun()

View File

@ -3,18 +3,15 @@
from typing import *
from pathlib import Path
# The config imported here belongs to the machine issuing the requests (e.g. the WEBUI machine) and is mainly used to set frontend defaults. In a distributed deployment it may differ from the server-side config.
from configs import (
EMBEDDING_MODEL,
DEFAULT_VS_TYPE,
LLM_MODELS,
TEMPERATURE,
LLM_MODEL_CONFIG,
SCORE_THRESHOLD,
CHUNK_SIZE,
OVERLAP_SIZE,
ZH_TITLE_ENHANCE,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
HTTPX_DEFAULT_TIMEOUT,
logger, log_verbose,
)
@ -26,7 +23,6 @@ from io import BytesIO
from server.utils import set_httpx_config, api_address, get_httpx_client
from pprint import pprint
from langchain_core._api import deprecated
set_httpx_config()
@ -247,10 +243,6 @@ class ApiRequest:
response = self.post("/server/configs", **kwargs)
return self._get_response_value(response, as_json=True)
def list_search_engines(self, **kwargs) -> List:
response = self.post("/server/list_search_engines", **kwargs)
return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"])
def get_prompt_template(
self,
type: str = "llm_chat",
@ -272,10 +264,8 @@ class ApiRequest:
history_len: int = -1,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
model_config: Dict = None,
tool_config: Dict = None,
**kwargs,
):
'''
@ -287,10 +277,8 @@ class ApiRequest:
"history_len": history_len,
"history": history,
"stream": stream,
"model_name": model,
"temperature": temperature,
"max_tokens": max_tokens,
"prompt_name": prompt_name,
"model_config": model_config,
"tool_config": tool_config,
}
# print(f"received input message:")
@ -299,78 +287,6 @@ class ApiRequest:
response = self.post("/chat/chat", json=data, stream=True, **kwargs)
return self._httpx_stream2generator(response, as_json=True)
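A minimal usage sketch for the reworked chat() call, assuming model_config and tool_config are forwarded verbatim to /chat/chat and follow the shape of LLM_MODEL_CONFIG and TOOL_CONFIG from the config template; the base_url and all concrete values are illustrative, not authoritative.

# Sketch only: per-role model settings plus a tool config for the new signature.
from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")  # assumed local API address

model_config = {
    "llm_model": {
        "chatglm3-6b": {"temperature": 0.9, "max_tokens": 4096, "prompt_name": "default"},
    },
}
tool_config = {
    "search_local_knowledgebase": {"use": True, "top_k": 10, "score_threshold": 1},
}

for chunk in api.chat("你好", stream=True, model_config=model_config, tool_config=tool_config):
    print(chunk)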
@deprecated(
since="0.3.0",
message="自定义Agent问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
removal="0.3.0")
def agent_chat(
self,
query: str,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
Corresponds to the /chat/agent_chat endpoint in api.py
'''
data = {
"query": query,
"history": history,
"stream": stream,
"model_name": model,
"temperature": temperature,
"max_tokens": max_tokens,
"prompt_name": prompt_name,
}
# print(f"received input message:")
# pprint(data)
response = self.post("/chat/agent_chat", json=data, stream=True)
return self._httpx_stream2generator(response, as_json=True)
def knowledge_base_chat(
self,
query: str,
knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
Corresponds to the /chat/knowledge_base_chat endpoint in api.py
'''
data = {
"query": query,
"knowledge_base_name": knowledge_base_name,
"top_k": top_k,
"score_threshold": score_threshold,
"history": history,
"stream": stream,
"model_name": model,
"temperature": temperature,
"max_tokens": max_tokens,
"prompt_name": prompt_name,
}
# print(f"received input message:")
# pprint(data)
response = self.post(
"/chat/knowledge_base_chat",
json=data,
stream=True,
)
return self._httpx_stream2generator(response, as_json=True)
def upload_temp_docs(
self,
files: List[Union[str, Path, bytes]],
@ -416,8 +332,8 @@ class ApiRequest:
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
model: str = None,
temperature: float = 0.9,
max_tokens: int = None,
prompt_name: str = "default",
):
@ -444,50 +360,6 @@ class ApiRequest:
)
return self._httpx_stream2generator(response, as_json=True)
@deprecated(
since="0.3.0",
message="搜索引擎问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
removal="0.3.0"
)
def search_engine_chat(
self,
query: str,
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
split_result: bool = False,
):
'''
Corresponds to the /chat/search_engine_chat endpoint in api.py
'''
data = {
"query": query,
"search_engine_name": search_engine_name,
"top_k": top_k,
"history": history,
"stream": stream,
"model_name": model,
"temperature": temperature,
"max_tokens": max_tokens,
"prompt_name": prompt_name,
"split_result": split_result,
}
# print(f"received input message:")
# pprint(data)
response = self.post(
"/chat/search_engine_chat",
json=data,
stream=True,
)
return self._httpx_stream2generator(response, as_json=True)
# Knowledge base operations
def list_knowledge_bases(
@ -769,7 +641,7 @@ class ApiRequest:
def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]:
'''
Get the currently running LLM model from the server.
local_first=True: prefer a running local model, otherwise fall back to the order configured in LLM_MODELS
local_first=True: prefer a running local model, otherwise fall back to the order configured in LLM_MODEL_CONFIG['llm_model']
Returns (model_name, is_local_model)
'''
@ -779,7 +651,7 @@ class ApiRequest:
return "", False
model = ""
for m in LLM_MODELS:
for m in LLM_MODEL_CONFIG['llm_model']:
if m not in running_models:
continue
is_local = not running_models[m].get("online_api")
@ -789,7 +661,7 @@ class ApiRequest:
model = m
break
if not model: # none of the models configured in LLM_MODELS is in running_models
if not model: # none of the models configured in LLM_MODEL_CONFIG['llm_model'] is in running_models
model = list(running_models)[0]
is_local = not running_models[model].get("online_api")
return model, is_local
@ -800,7 +672,7 @@ class ApiRequest:
return "", False
model = ""
for m in LLM_MODELS:
for m in LLM_MODEL_CONFIG['llm_model']:
if m not in running_models:
continue
is_local = not running_models[m].get("online_api")
@ -810,7 +682,7 @@ class ApiRequest:
model = m
break
if not model: # none of the models configured in LLM_MODELS is in running_models
if not model: # none of the models configured in LLM_MODEL_CONFIG['llm_model'] is in running_models
model = list(running_models)[0]
is_local = not running_models[model].get("online_api")
return model, is_local
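For clarity, the default-model selection above restated as a standalone sketch; it assumes running_models maps each model name to an info dict where a truthy "online_api" marks a remote (non-local) model.

# Standalone sketch of the selection loop over LLM_MODEL_CONFIG['llm_model'].
def pick_default_model(configured, running_models, local_first=True):
    model, is_local = "", False
    for m in configured:
        if m not in running_models:
            continue
        is_local = not running_models[m].get("online_api")
        if local_first and not is_local:
            continue
        model = m
        break
    if not model and running_models:  # none of the configured models is running
        model = next(iter(running_models))
        is_local = not running_models[model].get("online_api")
    return model, is_local

print(pick_default_model(["chatglm3-6b", "zhipu-api"],
                         {"zhipu-api": {"online_api": True}},
                         local_first=False))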
@ -852,15 +724,6 @@ class ApiRequest:
)
return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
def list_search_engines(self) -> List[str]:
'''
Get the search engines supported by the server
'''
response = self.post(
"/server/list_search_engines",
)
return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
def stop_llm_model(
self,
model_name: str,