From 253168a18785717d8168a5a4b1ea417cb190e502 Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Tue, 5 Dec 2023 17:17:53 +0800
Subject: [PATCH] Dev (#2280)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* Fix the bug where Azure did not set max tokens
* Rewrite the agent
1. Change the Agent implementation to support multi-parameter tools; for now only ChatGLM3-6b and OpenAI GPT-4 are supported, and the remaining models will temporarily lack Agent functionality
2. Remove agent_chat and integrate it into llm_chat
3. Rewrite most of the tools to fit the new Agent
* Update the architecture
* Remove web_chat; it is merged in automatically
* Remove all separate chat modes; everything is now Agent-controlled
* Update the configuration files
* Update the configuration templates and prompts
* Fix the parameter selection bug
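
As an illustration of the new multi-parameter tool interface, the minimal sketch below mirrors the tools_factory/calculate.py and tools_registry.py code added in this patch; the function, args schema, and StructuredTool registration are taken from those files and are shown here only as a sketch of the pattern, not an additional API:

```python
from pydantic import BaseModel, Field
from langchain_core.tools import StructuredTool


class CalculatorInput(BaseModel):
    a: float = Field(description="first number")
    b: float = Field(description="second number")
    operator: str = Field(description="operator to use (e.g., +, -, *, /, ^)")


def calculate(a: float, b: float, operator: str) -> float:
    # Multi-parameter tool: the agent fills several named arguments
    # instead of a single query string.
    if operator == "+":
        return a + b
    elif operator == "-":
        return a - b
    elif operator == "*":
        return a * b
    elif operator == "/":
        return a / b if b != 0 else float("inf")
    elif operator == "^":
        return a ** b
    raise ValueError("Unsupported operator")


# Registered for the agent as a structured tool with an explicit args schema.
calculate_tool = StructuredTool.from_function(
    func=calculate,
    name="calculate",
    description="Useful for when you need to answer questions about simple calculations",
    args_schema=CalculatorInput,
)
```

Each tool now declares its arguments through a pydantic args_schema, which is what allows the agent to pass several named parameters rather than one query string.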
---
chains/llmchain_with_history.py | 22 --
configs/model_config.py.example | 302 +++++++++------
configs/prompt_config.py.example | 126 +++----
configs/server_config.py.example | 2 +-
docs/ES部署指南.md | 29 --
requirements.txt | 12 +-
requirements_api.txt | 12 +-
server/{chat => }/__init__.py | 0
server/agent/__init__.py | 4 -
server/agent/agent_factory/__init__.py | 1 +
.../glm3_agent.py} | 14 +-
server/agent/custom_template.py | 67 ----
server/agent/model_contain.py | 6 -
server/agent/tools/__init__.py | 11 -
server/agent/tools/calculate.py | 76 ----
server/agent/tools/search_internet.py | 37 --
.../tools/search_knowledgebase_complex.py | 287 --------------
.../agent/tools/search_knowledgebase_once.py | 234 ------------
.../tools/search_knowledgebase_simple.py | 32 --
server/agent/tools_factory/__init__.py | 8 +
.../agent/{tools => tools_factory}/arxiv.py | 0
server/agent/tools_factory/calculate.py | 23 ++
server/agent/tools_factory/search_internet.py | 99 +++++
.../search_local_knowledgebase.py | 40 ++
.../search_youtube.py | 2 +-
.../agent/{tools => tools_factory}/shell.py | 2 +-
server/agent/tools_factory/tools_registry.py | 59 +++
.../{tools => tools_factory}/weather_check.py | 12 +-
.../agent/{tools => tools_factory}/wolfram.py | 2 +-
server/agent/tools_select.py | 55 ---
server/api.py | 20 +-
.../agent_callback_handler.py} | 143 +++----
server/chat/agent_chat.py | 178 ---------
server/chat/chat.py | 175 ++++++---
server/chat/completion.py | 10 +-
server/chat/file_chat.py | 11 +-
server/chat/knowledge_base_chat.py | 147 --------
server/chat/search_engine_chat.py | 208 ----------
server/db/models/conversation_model.py | 1 -
server/db/models/message_model.py | 1 -
server/knowledge_base/kb_summary_api.py | 13 +-
server/knowledge_base/utils.py | 9 +-
server/llm_api.py | 14 +-
server/model_workers/base.py | 4 +-
server/model_workers/qwen.py | 5 -
server/utils.py | 28 +-
startup.py | 48 ++-
tests/api/test_server_state_api.py | 10 +-
tests/api/test_stream_chat_api.py | 26 --
tests/test_online_api.py | 13 +-
webui_pages/dialogue/dialogue.py | 354 ++++++------------
webui_pages/utils.py | 161 +-------
52 files changed, 862 insertions(+), 2293 deletions(-)
delete mode 100644 chains/llmchain_with_history.py
delete mode 100644 docs/ES部署指南.md
rename server/{chat => }/__init__.py (100%)
delete mode 100644 server/agent/__init__.py
create mode 100644 server/agent/agent_factory/__init__.py
rename server/agent/{custom_agent/ChatGLM3Agent.py => agent_factory/glm3_agent.py} (94%)
delete mode 100644 server/agent/custom_template.py
delete mode 100644 server/agent/model_contain.py
delete mode 100644 server/agent/tools/__init__.py
delete mode 100644 server/agent/tools/calculate.py
delete mode 100644 server/agent/tools/search_internet.py
delete mode 100644 server/agent/tools/search_knowledgebase_complex.py
delete mode 100644 server/agent/tools/search_knowledgebase_once.py
delete mode 100644 server/agent/tools/search_knowledgebase_simple.py
create mode 100644 server/agent/tools_factory/__init__.py
rename server/agent/{tools => tools_factory}/arxiv.py (100%)
create mode 100644 server/agent/tools_factory/calculate.py
create mode 100644 server/agent/tools_factory/search_internet.py
create mode 100644 server/agent/tools_factory/search_local_knowledgebase.py
rename server/agent/{tools => tools_factory}/search_youtube.py (80%)
rename server/agent/{tools => tools_factory}/shell.py (72%)
create mode 100644 server/agent/tools_factory/tools_registry.py
rename server/agent/{tools => tools_factory}/weather_check.py (76%)
rename server/agent/{tools => tools_factory}/wolfram.py (84%)
delete mode 100644 server/agent/tools_select.py
rename server/{agent/callbacks.py => callback_handler/agent_callback_handler.py} (53%)
delete mode 100644 server/chat/agent_chat.py
delete mode 100644 server/chat/knowledge_base_chat.py
delete mode 100644 server/chat/search_engine_chat.py
diff --git a/chains/llmchain_with_history.py b/chains/llmchain_with_history.py
deleted file mode 100644
index 2a845086..00000000
--- a/chains/llmchain_with_history.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from server.utils import get_ChatOpenAI
-from configs.model_config import LLM_MODELS, TEMPERATURE
-from langchain.chains import LLMChain
-from langchain.prompts.chat import (
- ChatPromptTemplate,
- HumanMessagePromptTemplate,
-)
-
-model = get_ChatOpenAI(model_name=LLM_MODELS[0], temperature=TEMPERATURE)
-
-
-human_prompt = "{input}"
-human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)
-
-chat_prompt = ChatPromptTemplate.from_messages(
- [("human", "我们来玩成语接龙,我先来,生龙活虎"),
- ("ai", "虎头虎脑"),
- ("human", "{input}")])
-
-
-chain = LLMChain(prompt=chat_prompt, llm=model, verbose=True)
-print(chain({"input": "恼羞成怒"}))
\ No newline at end of file
diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index c42066d3..50314e56 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -1,14 +1,6 @@
import os
-
-# 可以指定一个绝对路径,统一存放所有的Embedding和LLM模型。
-# 每个模型可以是一个单独的目录,也可以是某个目录下的二级子目录。
-# 如果模型目录名称和 MODEL_PATH 中的 key 或 value 相同,程序会自动检测加载,无需修改 MODEL_PATH 中的路径。
MODEL_ROOT_PATH = ""
-
-# 选用的 Embedding 名称
-EMBEDDING_MODEL = "bge-large-zh-v1.5"
-
-# Embedding 模型运行设备。设为 "auto" 会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
+EMBEDDING_MODEL = "bge-large-zh-v1.5" # bge-large-zh
EMBEDDING_DEVICE = "auto"
# 选用的reranker模型
@@ -21,65 +13,170 @@ RERANKER_MAX_LENGTH = 1024
EMBEDDING_KEYWORD_FILE = "keywords.txt"
EMBEDDING_MODEL_OUTPUT_PATH = "output"
-# 要运行的 LLM 名称,可以包括本地模型和在线模型。列表中本地模型将在启动项目时全部加载。
-# 列表中第一个模型将作为 API 和 WEBUI 的默认模型。
-# 在这里,我们使用目前主流的两个离线模型,其中,chatglm3-6b 为默认加载模型。
-# 如果你的显存不足,可使用 Qwen-1_8B-Chat, 该模型 FP16 仅需 3.8G显存。
+LLM_MODEL_CONFIG = {
+ # 意图识别不需要输出,模型后台知道就行
+ "preprocess_model": {
+ "zhipu-api": {
+ "temperature": 0.4,
+ "max_tokens": 2048,
+ "history_len": 100,
+ "prompt_name": "default",
+ "callbacks": False
+ },
+ },
+ "llm_model": {
+ "chatglm3-6b": {
+ "temperature": 0.9,
+ "max_tokens": 4096,
+ "history_len": 3,
+ "prompt_name": "default",
+ "callbacks": True
+ },
+ "zhipu-api": {
+ "temperature": 0.9,
+ "max_tokens": 4000,
+ "history_len": 5,
+ "prompt_name": "default",
+ "callbacks": True
+ },
+ "Qwen-1_8B-Chat": {
+ "temperature": 0.4,
+ "max_tokens": 2048,
+ "history_len": 100,
+ "prompt_name": "default",
+ "callbacks": False
+ },
+ },
+ "action_model": {
+ "chatglm3-6b": {
+ "temperature": 0.01,
+ "max_tokens": 4096,
+ "prompt_name": "ChatGLM3",
+ "callbacks": True
+ },
+ "zhipu-api": {
+ "temperature": 0.01,
+ "max_tokens": 2096,
+ "history_len": 5,
+ "prompt_name": "ChatGLM3",
+ "callbacks": True
+ },
+ },
+ "postprocess_model": {
+ "chatglm3-6b": {
+ "temperature": 0.01,
+ "max_tokens": 4096,
+ "prompt_name": "default",
+ "callbacks": True
+ }
+ },
-LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"]
-Agent_MODEL = None
+}
+
+TOOL_CONFIG = {
+ "search_local_knowledgebase": {
+ "use": True,
+ "top_k": 10,
+ "score_threshold": 1,
+ "conclude_prompt": {
+ "with_result":
+ '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题",'
+            '不允许在答案中添加编造成分,答案请使用中文。 </指令>\n'
+            '<已知信息>{{ context }}</已知信息>\n'
+            '<问题>{{ question }}</问题>\n',
+ "without_result":
+ '请你根据我的提问回答我的问题:\n'
+ '{{ question }}\n'
+ '请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n',
+ }
+ },
+ "search_internet": {
+ "use": True,
+ "search_engine_name": "bing",
+ "search_engine_config":
+ {
+ "bing": {
+ "result_len": 3,
+ "bing_search_url": "https://api.bing.microsoft.com/v7.0/search",
+ "bing_key": "",
+ },
+ "metaphor": {
+ "result_len": 3,
+ "metaphor_api_key": "",
+ "split_result": False,
+ "chunk_size": 500,
+ "chunk_overlap": 0,
+ },
+ "duckduckgo": {
+ "result_len": 3
+ }
+ },
+ "top_k": 10,
+ "verbose": "Origin",
+ "conclude_prompt":
+ "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 "
+            "</指令>\n<已知信息>{{ context }}</已知信息>\n"
+ "<问题>\n"
+ "{{ question }}\n"
+            "</问题>\n"
+ },
+ "arxiv": {
+ "use": True,
+ },
+ "shell": {
+ "use": True,
+ },
+ "weather_check": {
+ "use": True,
+ "api-key": "",
+ },
+ "search_youtube": {
+ "use": False,
+ },
+ "wolfram": {
+ "use": False,
+ },
+ "calculate": {
+ "use": False,
+ },
+
+}
# LLM 模型运行设备。设为"auto"会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
LLM_DEVICE = "auto"
-HISTORY_LEN = 3
-
-MAX_TOKENS = 2048
-
-TEMPERATURE = 0.7
-
ONLINE_LLM_MODEL = {
"openai-api": {
- "model_name": "gpt-4",
+ "model_name": "gpt-4-1106-preview",
"api_base_url": "https://api.openai.com/v1",
"api_key": "",
"openai_proxy": "",
},
-
- # 智谱AI API,具体注册及api key获取请前往 http://open.bigmodel.cn
"zhipu-api": {
"api_key": "",
- "version": "glm-4",
+ "version": "chatglm_turbo",
"provider": "ChatGLMWorker",
},
-
- # 具体注册及api key获取请前往 https://api.minimax.chat/
"minimax-api": {
"group_id": "",
"api_key": "",
"is_pro": False,
"provider": "MiniMaxWorker",
},
-
- # 具体注册及api key获取请前往 https://xinghuo.xfyun.cn/
"xinghuo-api": {
"APPID": "",
"APISecret": "",
"api_key": "",
- "version": "v3.5", # 你使用的讯飞星火大模型版本,可选包括 "v3.5","v3.0", "v2.0", "v1.5"
+ "version": "v3.0",
"provider": "XingHuoWorker",
},
-
- # 百度千帆 API,申请方式请参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf
"qianfan-api": {
- "version": "ERNIE-Bot", # 注意大小写。当前支持 "ERNIE-Bot" 或 "ERNIE-Bot-turbo", 更多的见官方文档。
- "version_url": "", # 也可以不填写version,直接填写在千帆申请模型发布的API地址
+ "version": "ernie-bot-4",
+ "version_url": "",
"api_key": "",
"secret_key": "",
"provider": "QianFanWorker",
},
-
- # 火山方舟 API,文档参考 https://www.volcengine.com/docs/82379
"fangzhou-api": {
"version": "chatglm-6b-model",
"version_url": "",
@@ -87,28 +184,22 @@ ONLINE_LLM_MODEL = {
"secret_key": "",
"provider": "FangZhouWorker",
},
-
- # 阿里云通义千问 API,文档参考 https://help.aliyun.com/zh/dashscope/developer-reference/api-details
"qwen-api": {
"version": "qwen-max",
"api_key": "",
"provider": "QwenWorker",
"embed_model": "text-embedding-v1" # embedding 模型名称
},
-
- # 百川 API,申请方式请参考 https://www.baichuan-ai.com/home#api-enter
"baichuan-api": {
"version": "Baichuan2-53B",
"api_key": "",
"secret_key": "",
"provider": "BaiChuanWorker",
},
-
- # Azure API
"azure-api": {
- "deployment_name": "", # 部署容器的名字
- "resource_name": "", # https://{resource_name}.openai.azure.com/openai/ 填写resource_name的部分,其他部分不要填写
- "api_version": "", # API的版本,不是模型版本
+ "deployment_name": "",
+ "resource_name": "",
+ "api_version": "2023-07-01-preview",
"api_key": "",
"provider": "AzureWorker",
},
@@ -153,50 +244,33 @@ MODEL_PATH = {
"bge-small-zh": "BAAI/bge-small-zh",
"bge-base-zh": "BAAI/bge-base-zh",
- "bge-large-zh": "BAAI/bge-large-zh",
+ "bge-large-zh": "/media/zr/Data/Models/Embedding/bge-large-zh",
"bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
"bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5",
- "bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5",
-
- "bge-m3": "BAAI/bge-m3",
-
+ "bge-large-zh-v1.5": "/Models/bge-large-zh-v1.5",
"piccolo-base-zh": "sensenova/piccolo-base-zh",
"piccolo-large-zh": "sensenova/piccolo-large-zh",
- "nlp_gte_sentence-embedding_chinese-large": "damo/nlp_gte_sentence-embedding_chinese-large",
- "text-embedding-ada-002": "your OPENAI_API_KEY",
+ "nlp_gte_sentence-embedding_chinese-large": "/Models/nlp_gte_sentence-embedding_chinese-large",
+    "text-embedding-ada-002": "Just write your OpenAI key like 'sk-o3IGBhC9g8AiFvTGWVKsT*****'",
},
"llm_model": {
"chatglm2-6b": "THUDM/chatglm2-6b",
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
- "chatglm3-6b": "THUDM/chatglm3-6b",
+ "chatglm3-6b": "/Models/chatglm3-6b",
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
- "Orion-14B-Chat": "OrionStarAI/Orion-14B-Chat",
- "Orion-14B-Chat-Plugin": "OrionStarAI/Orion-14B-Chat-Plugin",
- "Orion-14B-LongChat": "OrionStarAI/Orion-14B-LongChat",
+ "Yi-34B-Chat": "/data/share/models/Yi-34B-Chat",
+ "BlueLM-7B-Chat": "/Models/BlueLM-7B-Chat",
- "Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
- "Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
- "Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
+ "baichuan2-13b": "/media/zr/Data/Models/LLM/Baichuan2-13B-Chat",
+ "baichuan2-7b": "/media/zr/Data/Models/LLM/Baichuan2-7B-Chat",
+ "baichuan-7b": "baichuan-inc/Baichuan-7B",
+ "baichuan-13b": "baichuan-inc/Baichuan-13B",
+ 'baichuan-13b-chat': 'baichuan-inc/Baichuan-13B-Chat',
- "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
- "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
- "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
- "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
-
- # Qwen1.5 模型 VLLM可能出现问题
- "Qwen1.5-0.5B-Chat": "Qwen/Qwen1.5-0.5B-Chat",
- "Qwen1.5-1.8B-Chat": "Qwen/Qwen1.5-1.8B-Chat",
- "Qwen1.5-4B-Chat": "Qwen/Qwen1.5-4B-Chat",
- "Qwen1.5-7B-Chat": "Qwen/Qwen1.5-7B-Chat",
- "Qwen1.5-14B-Chat": "Qwen/Qwen1.5-14B-Chat",
- "Qwen1.5-72B-Chat": "Qwen/Qwen1.5-72B-Chat",
-
- "baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
- "baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
- "baichuan2-7b-chat": "baichuan-inc/Baichuan2-7B-Chat",
- "baichuan2-13b-chat": "baichuan-inc/Baichuan2-13B-Chat",
+ "aquila-7b": "BAAI/Aquila-7B",
+ "aquilachat-7b": "BAAI/AquilaChat-7B",
"internlm-7b": "internlm/internlm-7b",
"internlm-chat-7b": "internlm/internlm-chat-7b",
@@ -235,42 +309,43 @@ MODEL_PATH = {
"oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
"dolly-v2-12b": "databricks/dolly-v2-12b",
"stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b",
+
+ "Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf",
+ "Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
+ "open_llama_13b": "openlm-research/open_llama_13b",
+ "vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3",
+ "koala": "young-geng/koala",
+
+ "mpt-7b": "mosaicml/mpt-7b",
+ "mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter",
+ "mpt-30b": "mosaicml/mpt-30b",
+ "opt-66b": "facebook/opt-66b",
+ "opt-iml-max-30b": "facebook/opt-iml-max-30b",
+
+    "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
+ "Qwen-7B": "Qwen/Qwen-7B",
+ "Qwen-14B": "Qwen/Qwen-14B",
+ "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
+ "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
+ "Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8", # 确保已经安装了auto-gptq optimum flash-attn
+ "Qwen-14B-Chat-Int4": "/media/zr/Data/Models/LLM/Qwen-14B-Chat-Int4", # 确保已经安装了auto-gptq optimum flash-attn
},
- "reranker": {
- "bge-reranker-large": "BAAI/bge-reranker-large",
- "bge-reranker-base": "BAAI/bge-reranker-base",
- }
-}
-
-# 通常情况下不需要更改以下内容
-
-# nltk 模型存储路径
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
# 使用VLLM可能导致模型推理能力下降,无法完成Agent任务
VLLM_MODEL_DICT = {
- "chatglm2-6b": "THUDM/chatglm2-6b",
- "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
- "chatglm3-6b": "THUDM/chatglm3-6b",
- "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
+ "aquila-7b": "BAAI/Aquila-7B",
+ "aquilachat-7b": "BAAI/AquilaChat-7B",
- "Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
- "Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
- "Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
+ "baichuan-7b": "baichuan-inc/Baichuan-7B",
+ "baichuan-13b": "baichuan-inc/Baichuan-13B",
+ 'baichuan-13b-chat': 'baichuan-inc/Baichuan-13B-Chat',
- "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
- "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
- "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
- "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
-
- "baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
- "baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
- "baichuan2-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
- "baichuan2-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
-
- "BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat",
- "BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k",
+ 'chatglm2-6b': 'THUDM/chatglm2-6b',
+ 'chatglm2-6b-32k': 'THUDM/chatglm2-6b-32k',
+ 'chatglm3-6b': 'THUDM/chatglm3-6b',
+ 'chatglm3-6b-32k': 'THUDM/chatglm3-6b-32k',
"internlm-7b": "internlm/internlm-7b",
"internlm-chat-7b": "internlm/internlm-chat-7b",
@@ -301,14 +376,13 @@ VLLM_MODEL_DICT = {
"opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
-}
+ "Qwen-7B": "Qwen/Qwen-7B",
+ "Qwen-14B": "Qwen/Qwen-14B",
+ "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
+ "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
-SUPPORT_AGENT_MODEL = [
- "openai-api", # GPT4 模型
- "qwen-api", # Qwen Max模型
- "zhipu-api", # 智谱AI GLM4模型
- "Qwen", # 所有Qwen系列本地模型
- "chatglm3-6b",
- "internlm2-chat-20b",
- "Orion-14B-Chat-Plugin",
-]
+ "agentlm-7b": "THUDM/agentlm-7b",
+ "agentlm-13b": "THUDM/agentlm-13b",
+ "agentlm-70b": "THUDM/agentlm-70b",
+
+}
diff --git a/configs/prompt_config.py.example b/configs/prompt_config.py.example
index 6fb6996c..e7def5a6 100644
--- a/configs/prompt_config.py.example
+++ b/configs/prompt_config.py.example
@@ -1,26 +1,19 @@
-# prompt模板使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号
-# 本配置文件支持热加载,修改prompt模板后无需重启服务。
-
-# LLM对话支持的变量:
-# - input: 用户输入内容
-
-# 知识库和搜索引擎对话支持的变量:
-# - context: 从检索结果拼接的知识文本
-# - question: 用户提出的问题
-
-# Agent对话支持的变量:
-
-# - tools: 可用的工具列表
-# - tool_names: 可用的工具名称列表
-# - history: 用户和Agent的对话历史
-# - input: 用户输入内容
-# - agent_scratchpad: Agent的思维记录
-
PROMPT_TEMPLATES = {
- "llm_chat": {
+ "preprocess_model": {
+ "default":
+ '请你根据我的描述和我们对话的历史,来判断本次跟我交流是否需要使用工具,还是可以直接凭借你的知识或者历史记录跟我对话。你只要回答一个数字。1 或者 0,1代表需要使用工具,0代表不需要使用工具。\n'
+ '以下几种情况要使用工具,请返回1\n'
+ '1. 实时性的问题,例如天气,日期,地点等信息\n'
+ '2. 需要数学计算的问题\n'
+ '3. 需要查询数据,地点等精确数据\n'
+ '4. 需要行业知识的问题\n'
+ ''
+ '{input}'
+ ''
+ },
+ "llm_model": {
"default":
'{{ input }}',
-
"with_history":
'The following is a friendly conversation between a human and an AI. '
'The AI is talkative and provides lots of specific details from its context. '
@@ -29,72 +22,42 @@ PROMPT_TEMPLATES = {
'{history}\n'
'Human: {input}\n'
'AI:',
-
- "py":
- '你是一个聪明的代码助手,请你给我写出简单的py代码。 \n'
- '{{ input }}',
},
-
-
- "knowledge_base_chat": {
- "default":
- '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,'
-        '不允许在答案中添加编造成分,答案请使用中文。 </指令>\n'
-        '<已知信息>{{ context }}</已知信息>\n'
-        '<问题>{{ question }}</问题>\n',
-
- "text":
-        '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>\n'
-        '<已知信息>{{ context }}</已知信息>\n'
-        '<问题>{{ question }}</问题>\n',
-
- "empty": # 搜不到知识库的时候使用
- '请你回答我的问题:\n'
- '{{ question }}\n\n',
- },
-
-
- "search_engine_chat": {
- "default":
- '<指令>这是我搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。'
-        '如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 </指令>\n'
-        '<已知信息>{{ context }}</已知信息>\n'
-        '<问题>{{ question }}</问题>\n',
-
- "search":
-        '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>\n'
-        '<已知信息>{{ context }}</已知信息>\n'
-        '<问题>{{ question }}</问题>\n',
- },
-
-
- "agent_chat": {
- "default":
- 'Answer the following questions as best you can. If it is in order, you can use some tools appropriately. '
- 'You have access to the following tools:\n\n'
- '{tools}\n\n'
- 'Use the following format:\n'
- 'Question: the input question you must answer1\n'
- 'Thought: you should always think about what to do and what tools to use.\n'
- 'Action: the action to take, should be one of [{tool_names}]\n'
- 'Action Input: the input to the action\n'
+ "action_model": {
+ "GPT-4":
+ 'Answer the following questions as best you can. You have access to the following tools:\n'
+ 'The way you use the tools is by specifying a json blob.\n'
+ 'Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here).\n'
+ 'The only values that should be in the "action" field are: {tool_names}\n'
+ 'The $JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB:\n'
+ '```\n\n'
+ '{{{{\n'
+ ' "action": $TOOL_NAME,\n'
+ ' "action_input": $INPUT\n'
+ '}}}}\n'
+ '```\n\n'
+ 'ALWAYS use the following format:\n'
+ 'Question: the input question you must answer\n'
+ 'Thought: you should always think about what to do\n'
+ 'Action:\n'
+ '```\n\n'
+ '$JSON_BLOB'
+ '```\n\n'
'Observation: the result of the action\n'
- '... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n'
+ '... (this Thought/Action/Observation can repeat N times)\n'
'Thought: I now know the final answer\n'
'Final Answer: the final answer to the original input question\n'
- 'Begin!\n\n'
- 'history: {history}\n\n'
- 'Question: {input}\n\n'
- 'Thought: {agent_scratchpad}\n',
-
+ 'Begin! Reminder to always use the exact characters `Final Answer` when responding.\n'
+ 'history: {history}\n'
+ 'Question:{input}\n'
+ 'Thought:{agent_scratchpad}\n',
"ChatGLM3":
- 'You can answer using the tools, or answer directly using your knowledge without using the tools. '
- 'Respond to the human as helpfully and accurately as possible.\n'
+ 'You can answer using the tools.Respond to the human as helpfully and accurately as possible.\n'
'You have access to the following tools:\n'
'{tools}\n'
- 'Use a json blob to specify a tool by providing an action key (tool name) '
+ 'Use a json blob to specify a tool by providing an action key (tool name)\n'
'and an action_input key (tool input).\n'
- 'Valid "action" values: "Final Answer" or [{tool_names}]'
+ 'Valid "action" values: "Final Answer" or [{tool_names}]\n'
'Provide only ONE action per $JSON_BLOB, as shown:\n\n'
'```\n'
'{{{{\n'
@@ -118,10 +81,13 @@ PROMPT_TEMPLATES = {
' "action": "Final Answer",\n'
' "action_input": "Final response to human"\n'
'}}}}\n'
- 'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. '
+ 'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary.\n'
'Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.\n'
'history: {history}\n\n'
'Question: {input}\n\n'
- 'Thought: {agent_scratchpad}',
+ 'Thought: {agent_scratchpad}\n',
+ },
+ "postprocess_model": {
+ "default": "{{input}}",
}
}
diff --git a/configs/server_config.py.example b/configs/server_config.py.example
index 9bbb8b49..3cc51dd5 100644
--- a/configs/server_config.py.example
+++ b/configs/server_config.py.example
@@ -57,7 +57,7 @@ FSCHAT_MODEL_WORKERS = {
# "awq_ckpt": None,
# "awq_wbits": 16,
# "awq_groupsize": -1,
- # "model_names": LLM_MODELS,
+ # "model_names": None,
# "conv_template": None,
# "limit_worker_concurrency": 5,
# "stream_interval": 2,
diff --git a/docs/ES部署指南.md b/docs/ES部署指南.md
deleted file mode 100644
index f4615826..00000000
--- a/docs/ES部署指南.md
+++ /dev/null
@@ -1,29 +0,0 @@
-
-# 实现基于ES的数据插入、检索、删除、更新
-```shell
-author: 唐国梁Tommy
-e-mail: flytang186@qq.com
-
-如果遇到任何问题,可以与我联系,我这边部署后服务是没有问题的。
-```
-
-## 第1步:ES docker部署
-```shell
-docker network create elastic
-docker run -id --name elasticsearch --net elastic -p 9200:9200 -p 9300:9300 -e "discovery.type=single-node" -e "xpack.security.enabled=false" -e "xpack.security.http.ssl.enabled=false" -t docker.elastic.co/elasticsearch/elasticsearch:8.8.2
-```
-
-### 第2步:Kibana docker部署
-**注意:Kibana版本与ES保持一致**
-```shell
-docker pull docker.elastic.co/kibana/kibana:{version}
-docker run --name kibana --net elastic -p 5601:5601 docker.elastic.co/kibana/kibana:{version}
-```
-
-### 第3步:核心代码
-```shell
-1. 核心代码路径
-server/knowledge_base/kb_service/es_kb_service.py
-
-2. 需要在 configs/model_config.py 中 配置 ES参数(IP, PORT)等;
-```
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 30b82248..bb341dd6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,11 +1,7 @@
-torch==2.1.2
-torchvision==0.16.2
-torchaudio==2.1.2
-xformers==0.0.23.post1
-transformers==4.37.2
-sentence_transformers==2.2.2
-langchain==0.0.354
-langchain-experimental==0.0.47
+# API requirements
+
+langchain>=0.0.346
+langchain-experimental>=0.0.42
pydantic==1.10.13
fschat==0.2.35
openai==1.9.0
diff --git a/requirements_api.txt b/requirements_api.txt
index fbb00357..de38487b 100644
--- a/requirements_api.txt
+++ b/requirements_api.txt
@@ -1,11 +1,7 @@
-torch~=2.1.2
-torchvision~=0.16.2
-torchaudio~=2.1.2
-xformers>=0.0.23.post1
-transformers==4.37.2
-sentence_transformers==2.2.2
-langchain==0.0.354
-langchain-experimental==0.0.47
+# API requirements
+
+langchain>=0.0.346
+langchain-experimental>=0.0.42
pydantic==1.10.13
fschat==0.2.35
openai~=1.9.0
diff --git a/server/chat/__init__.py b/server/__init__.py
similarity index 100%
rename from server/chat/__init__.py
rename to server/__init__.py
diff --git a/server/agent/__init__.py b/server/agent/__init__.py
deleted file mode 100644
index 0de21612..00000000
--- a/server/agent/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .model_contain import *
-from .callbacks import *
-from .custom_template import *
-from .tools import *
\ No newline at end of file
diff --git a/server/agent/agent_factory/__init__.py b/server/agent/agent_factory/__init__.py
new file mode 100644
index 00000000..4729de78
--- /dev/null
+++ b/server/agent/agent_factory/__init__.py
@@ -0,0 +1 @@
+from .glm3_agent import initialize_glm3_agent
\ No newline at end of file
diff --git a/server/agent/custom_agent/ChatGLM3Agent.py b/server/agent/agent_factory/glm3_agent.py
similarity index 94%
rename from server/agent/custom_agent/ChatGLM3Agent.py
rename to server/agent/agent_factory/glm3_agent.py
index 65f57567..00c4b5ce 100644
--- a/server/agent/custom_agent/ChatGLM3Agent.py
+++ b/server/agent/agent_factory/glm3_agent.py
@@ -3,6 +3,14 @@ This file is a modified version for ChatGLM3-6B the original glm3_agent.py file
"""
from __future__ import annotations
+import yaml
+from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
+from langchain.memory import ConversationBufferWindowMemory
+from typing import Any, List, Sequence, Tuple, Optional, Union
+import os
+from langchain.agents.agent import Agent
+from langchain.chains.llm import LLMChain
+from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
import json
import logging
from typing import Any, List, Sequence, Tuple, Optional, Union
@@ -22,6 +30,7 @@ from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool
+from pydantic.schema import model_schema
HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
logger = logging.getLogger(__name__)
@@ -42,8 +51,9 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
if "tool_call" in text:
action_end = text.find("```")
action = text[:action_end].strip()
+
params_str_start = text.find("(") + 1
- params_str_end = text.rfind(")")
+ params_str_end = text.find(")")
params_str = text[params_str_start:params_str_end]
params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param]
@@ -225,4 +235,4 @@ def initialize_glm3_agent(
memory=memory,
tags=tags_,
**kwargs,
- )
\ No newline at end of file
+ )
diff --git a/server/agent/custom_template.py b/server/agent/custom_template.py
deleted file mode 100644
index 1431d2a0..00000000
--- a/server/agent/custom_template.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from __future__ import annotations
-from langchain.agents import Tool, AgentOutputParser
-from langchain.prompts import StringPromptTemplate
-from typing import List
-from langchain.schema import AgentAction, AgentFinish
-
-from configs import SUPPORT_AGENT_MODEL
-from server.agent import model_container
-class CustomPromptTemplate(StringPromptTemplate):
- template: str
- tools: List[Tool]
-
- def format(self, **kwargs) -> str:
- intermediate_steps = kwargs.pop("intermediate_steps")
- thoughts = ""
- for action, observation in intermediate_steps:
- thoughts += action.log
- thoughts += f"\nObservation: {observation}\nThought: "
- kwargs["agent_scratchpad"] = thoughts
- kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
- kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
- return self.template.format(**kwargs)
-
-class CustomOutputParser(AgentOutputParser):
- begin: bool = False
- def __init__(self):
- super().__init__()
- self.begin = True
-
- def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
- if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin:
- self.begin = False
- stop_words = ["Observation:"]
- min_index = len(llm_output)
- for stop_word in stop_words:
- index = llm_output.find(stop_word)
- if index != -1 and index < min_index:
- min_index = index
- llm_output = llm_output[:min_index]
-
- if "Final Answer:" in llm_output:
- self.begin = True
- return AgentFinish(
- return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()},
- log=llm_output,
- )
- parts = llm_output.split("Action:")
- if len(parts) < 2:
- return AgentFinish(
- return_values={"output": f"调用agent工具失败,该回答为大模型自身能力的回答:\n\n `{llm_output}`"},
- log=llm_output,
- )
-
- action = parts[1].split("Action Input:")[0].strip()
- action_input = parts[1].split("Action Input:")[1].strip()
- try:
- ans = AgentAction(
- tool=action,
- tool_input=action_input.strip(" ").strip('"'),
- log=llm_output
- )
- return ans
- except:
- return AgentFinish(
- return_values={"output": f"调用agent失败: `{llm_output}`"},
- log=llm_output,
- )
diff --git a/server/agent/model_contain.py b/server/agent/model_contain.py
deleted file mode 100644
index 0141ad03..00000000
--- a/server/agent/model_contain.py
+++ /dev/null
@@ -1,6 +0,0 @@
-class ModelContainer:
- def __init__(self):
- self.MODEL = None
- self.DATABASE = None
-
-model_container = ModelContainer()
diff --git a/server/agent/tools/__init__.py b/server/agent/tools/__init__.py
deleted file mode 100644
index 21dfd973..00000000
--- a/server/agent/tools/__init__.py
+++ /dev/null
@@ -1,11 +0,0 @@
-## 导入所有的工具类
-from .search_knowledgebase_simple import search_knowledgebase_simple
-from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput
-from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput
-from .calculate import calculate, CalculatorInput
-from .weather_check import weathercheck, WeatherInput
-from .shell import shell, ShellInput
-from .search_internet import search_internet, SearchInternetInput
-from .wolfram import wolfram, WolframInput
-from .search_youtube import search_youtube, YoutubeInput
-from .arxiv import arxiv, ArxivInput
diff --git a/server/agent/tools/calculate.py b/server/agent/tools/calculate.py
deleted file mode 100644
index bb0cbcca..00000000
--- a/server/agent/tools/calculate.py
+++ /dev/null
@@ -1,76 +0,0 @@
-from langchain.prompts import PromptTemplate
-from langchain.chains import LLMMathChain
-from server.agent import model_container
-from pydantic import BaseModel, Field
-
-_PROMPT_TEMPLATE = """
-将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。
-问题: ${{包含数学问题的问题。}}
-```text
-${{解决问题的单行数学表达式}}
-```
-...numexpr.evaluate(query)...
-```output
-${{运行代码的输出}}
-```
-答案: ${{答案}}
-
-这是两个例子:
-
-问题: 37593 * 67是多少?
-```text
-37593 * 67
-```
-...numexpr.evaluate("37593 * 67")...
-```output
-2518731
-
-答案: 2518731
-
-问题: 37593的五次方根是多少?
-```text
-37593**(1/5)
-```
-...numexpr.evaluate("37593**(1/5)")...
-```output
-8.222831614237718
-
-答案: 8.222831614237718
-
-
-问题: 2的平方是多少?
-```text
-2 ** 2
-```
-...numexpr.evaluate("2 ** 2")...
-```output
-4
-
-答案: 4
-
-
-现在,这是我的问题:
-问题: {question}
-"""
-
-PROMPT = PromptTemplate(
- input_variables=["question"],
- template=_PROMPT_TEMPLATE,
-)
-
-
-class CalculatorInput(BaseModel):
- query: str = Field()
-
-def calculate(query: str):
- model = model_container.MODEL
- llm_math = LLMMathChain.from_llm(model, verbose=True, prompt=PROMPT)
- ans = llm_math.run(query)
- return ans
-
-if __name__ == "__main__":
- result = calculate("2的三次方")
- print("答案:",result)
-
-
-
diff --git a/server/agent/tools/search_internet.py b/server/agent/tools/search_internet.py
deleted file mode 100644
index 48a8a629..00000000
--- a/server/agent/tools/search_internet.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import json
-from server.chat.search_engine_chat import search_engine_chat
-from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS
-import asyncio
-from server.agent import model_container
-from pydantic import BaseModel, Field
-
-async def search_engine_iter(query: str):
- response = await search_engine_chat(query=query,
- search_engine_name="bing", # 这里切换搜索引擎
- model_name=model_container.MODEL.model_name,
- temperature=0.01, # Agent 搜索互联网的时候,温度设置为0.01
- history=[],
- top_k = VECTOR_SEARCH_TOP_K,
- max_tokens= MAX_TOKENS,
- prompt_name = "default",
- stream=False)
-
- contents = ""
-
- async for data in response.body_iterator: # 这里的data是一个json字符串
- data = json.loads(data)
- contents = data["answer"]
- docs = data["docs"]
-
- return contents
-
-def search_internet(query: str):
- return asyncio.run(search_engine_iter(query))
-
-class SearchInternetInput(BaseModel):
- location: str = Field(description="Query for Internet search")
-
-
-if __name__ == "__main__":
- result = search_internet("今天星期几")
- print("答案:",result)
diff --git a/server/agent/tools/search_knowledgebase_complex.py b/server/agent/tools/search_knowledgebase_complex.py
deleted file mode 100644
index af4d9116..00000000
--- a/server/agent/tools/search_knowledgebase_complex.py
+++ /dev/null
@@ -1,287 +0,0 @@
-from __future__ import annotations
-import json
-import re
-import warnings
-from typing import Dict
-from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun
-from langchain.chains.llm import LLMChain
-from langchain.pydantic_v1 import Extra, root_validator
-from langchain.schema import BasePromptTemplate
-from langchain.schema.language_model import BaseLanguageModel
-from typing import List, Any, Optional
-from langchain.prompts import PromptTemplate
-from server.chat.knowledge_base_chat import knowledge_base_chat
-from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
-import asyncio
-from server.agent import model_container
-from pydantic import BaseModel, Field
-
-async def search_knowledge_base_iter(database: str, query: str) -> str:
- response = await knowledge_base_chat(query=query,
- knowledge_base_name=database,
- model_name=model_container.MODEL.model_name,
- temperature=0.01,
- history=[],
- top_k=VECTOR_SEARCH_TOP_K,
- max_tokens=MAX_TOKENS,
- prompt_name="default",
- score_threshold=SCORE_THRESHOLD,
- stream=False)
-
- contents = ""
- async for data in response.body_iterator: # 这里的data是一个json字符串
- data = json.loads(data)
- contents += data["answer"]
- docs = data["docs"]
- return contents
-
-
-async def search_knowledge_multiple(queries) -> List[str]:
- # queries 应该是一个包含多个 (database, query) 元组的列表
- tasks = [search_knowledge_base_iter(database, query) for database, query in queries]
- results = await asyncio.gather(*tasks)
- # 结合每个查询结果,并在每个查询结果前添加一个自定义的消息
- combined_results = []
- for (database, _), result in zip(queries, results):
- message = f"\n查询到 {database} 知识库的相关信息:\n{result}"
- combined_results.append(message)
-
- return combined_results
-
-
-def search_knowledge(queries) -> str:
- responses = asyncio.run(search_knowledge_multiple(queries))
- # 输出每个整合的查询结果
- contents = ""
- for response in responses:
- contents += response + "\n\n"
- return contents
-
-
-_PROMPT_TEMPLATE = """
-用户会提出一个需要你查询知识库的问题,你应该对问题进行理解和拆解,并在知识库中查询相关的内容。
-
-对于每个知识库,你输出的内容应该是一个一行的字符串,这行字符串包含知识库名称和查询内容,中间用逗号隔开,不要有多余的文字和符号。你可以同时查询多个知识库,下面这个例子就是同时查询两个知识库的内容。
-
-例子:
-
-robotic,机器人男女比例是多少
-bigdata,大数据的就业情况如何
-
-
-这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能,你应该参考他们的功能来帮助你思考
-
-
-{database_names}
-
-你的回答格式应该按照下面的内容,请注意```text 等标记都必须输出,这是我用来提取答案的标记。
-不要输出中文的逗号,不要输出引号。
-
-Question: ${{用户的问题}}
-
-```text
-${{知识库名称,查询问题,不要带有任何除了,之外的符号,比如不要输出中文的逗号,不要输出引号}}
-
-```output
-数据库查询的结果
-
-现在,我们开始作答
-问题: {question}
-"""
-
-PROMPT = PromptTemplate(
- input_variables=["question", "database_names"],
- template=_PROMPT_TEMPLATE,
-)
-
-
-class LLMKnowledgeChain(LLMChain):
- llm_chain: LLMChain
- llm: Optional[BaseLanguageModel] = None
- """[Deprecated] LLM wrapper to use."""
- prompt: BasePromptTemplate = PROMPT
- """[Deprecated] Prompt to use to translate to python if necessary."""
- database_names: Dict[str, str] = None
- input_key: str = "question" #: :meta private:
- output_key: str = "answer" #: :meta private:
-
- class Config:
- """Configuration for this pydantic object."""
-
- extra = Extra.forbid
- arbitrary_types_allowed = True
-
- @root_validator(pre=True)
- def raise_deprecation(cls, values: Dict) -> Dict:
- if "llm" in values:
- warnings.warn(
- "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
- "Please instantiate with llm_chain argument or using the from_llm "
- "class method."
- )
- if "llm_chain" not in values and values["llm"] is not None:
- prompt = values.get("prompt", PROMPT)
- values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
- return values
-
- @property
- def input_keys(self) -> List[str]:
- """Expect input key.
-
- :meta private:
- """
- return [self.input_key]
-
- @property
- def output_keys(self) -> List[str]:
- """Expect output key.
-
- :meta private:
- """
- return [self.output_key]
-
- def _evaluate_expression(self, queries) -> str:
- try:
- output = search_knowledge(queries)
- except Exception as e:
- output = "输入的信息有误或不存在知识库,错误信息如下:\n"
- return output + str(e)
- return output
-
- def _process_llm_result(
- self,
- llm_output: str,
- run_manager: CallbackManagerForChainRun
- ) -> Dict[str, str]:
-
- run_manager.on_text(llm_output, color="green", verbose=self.verbose)
-
- llm_output = llm_output.strip()
- # text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
- text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
- if text_match:
- expression = text_match.group(1).strip()
- cleaned_input_str = (expression.replace("\"", "").replace("“", "").
- replace("”", "").replace("```", "").strip())
- lines = cleaned_input_str.split("\n")
- # 使用逗号分割每一行,然后形成一个(数据库,查询)元组的列表
-
- try:
- queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
- except:
- queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
- run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", verbose=self.verbose)
- output = self._evaluate_expression(queries)
- run_manager.on_text("\nAnswer: ", verbose=self.verbose)
- run_manager.on_text(output, color="yellow", verbose=self.verbose)
- answer = "Answer: " + output
- elif llm_output.startswith("Answer:"):
- answer = llm_output
- elif "Answer:" in llm_output:
- answer = llm_output.split("Answer:")[-1]
- else:
- return {self.output_key: f"输入的格式不对:\n {llm_output}"}
- return {self.output_key: answer}
-
- async def _aprocess_llm_result(
- self,
- llm_output: str,
- run_manager: AsyncCallbackManagerForChainRun,
- ) -> Dict[str, str]:
- await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
- llm_output = llm_output.strip()
- text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
- if text_match:
-
- expression = text_match.group(1).strip()
- cleaned_input_str = (
- expression.replace("\"", "").replace("“", "").replace("”", "").replace("```", "").strip())
- lines = cleaned_input_str.split("\n")
- try:
- queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
- except:
- queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
- await run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue",
- verbose=self.verbose)
-
- output = self._evaluate_expression(queries)
- await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
- await run_manager.on_text(output, color="yellow", verbose=self.verbose)
- answer = "Answer: " + output
- elif llm_output.startswith("Answer:"):
- answer = llm_output
- elif "Answer:" in llm_output:
- answer = "Answer: " + llm_output.split("Answer:")[-1]
- else:
- raise ValueError(f"unknown format from LLM: {llm_output}")
- return {self.output_key: answer}
-
- def _call(
- self,
- inputs: Dict[str, str],
- run_manager: Optional[CallbackManagerForChainRun] = None,
- ) -> Dict[str, str]:
- _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
- _run_manager.on_text(inputs[self.input_key])
- self.database_names = model_container.DATABASE
- data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
- llm_output = self.llm_chain.predict(
- database_names=data_formatted_str,
- question=inputs[self.input_key],
- stop=["```output"],
- callbacks=_run_manager.get_child(),
- )
- return self._process_llm_result(llm_output, _run_manager)
-
- async def _acall(
- self,
- inputs: Dict[str, str],
- run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
- ) -> Dict[str, str]:
- _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
- await _run_manager.on_text(inputs[self.input_key])
- self.database_names = model_container.DATABASE
- data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
- llm_output = await self.llm_chain.apredict(
- database_names=data_formatted_str,
- question=inputs[self.input_key],
- stop=["```output"],
- callbacks=_run_manager.get_child(),
- )
- return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)
-
- @property
- def _chain_type(self) -> str:
- return "llm_knowledge_chain"
-
- @classmethod
- def from_llm(
- cls,
- llm: BaseLanguageModel,
- prompt: BasePromptTemplate = PROMPT,
- **kwargs: Any,
- ) -> LLMKnowledgeChain:
- llm_chain = LLMChain(llm=llm, prompt=prompt)
- return cls(llm_chain=llm_chain, **kwargs)
-
-
-def search_knowledgebase_complex(query: str):
- model = model_container.MODEL
- llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT)
- ans = llm_knowledge.run(query)
- return ans
-
-class KnowledgeSearchInput(BaseModel):
- location: str = Field(description="The query to be searched")
-
-if __name__ == "__main__":
- result = search_knowledgebase_complex("机器人和大数据在代码教学上有什么区别")
- print(result)
-
-# 这是一个正常的切割
-# queries = [
-# ("bigdata", "大数据专业的男女比例"),
-# ("robotic", "机器人专业的优势")
-# ]
-# result = search_knowledge(queries)
-# print(result)
diff --git a/server/agent/tools/search_knowledgebase_once.py b/server/agent/tools/search_knowledgebase_once.py
deleted file mode 100644
index c9a2d7b5..00000000
--- a/server/agent/tools/search_knowledgebase_once.py
+++ /dev/null
@@ -1,234 +0,0 @@
-from __future__ import annotations
-import re
-import warnings
-from typing import Dict
-
-from langchain.callbacks.manager import (
- AsyncCallbackManagerForChainRun,
- CallbackManagerForChainRun,
-)
-from langchain.chains.llm import LLMChain
-from langchain.pydantic_v1 import Extra, root_validator
-from langchain.schema import BasePromptTemplate
-from langchain.schema.language_model import BaseLanguageModel
-from typing import List, Any, Optional
-from langchain.prompts import PromptTemplate
-import sys
-import os
-import json
-
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
-from server.chat.knowledge_base_chat import knowledge_base_chat
-from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
-
-import asyncio
-from server.agent import model_container
-from pydantic import BaseModel, Field
-
-async def search_knowledge_base_iter(database: str, query: str):
- response = await knowledge_base_chat(query=query,
- knowledge_base_name=database,
- model_name=model_container.MODEL.model_name,
- temperature=0.01,
- history=[],
- top_k=VECTOR_SEARCH_TOP_K,
- max_tokens=MAX_TOKENS,
- prompt_name="knowledge_base_chat",
- score_threshold=SCORE_THRESHOLD,
- stream=False)
-
- contents = ""
- async for data in response.body_iterator: # 这里的data是一个json字符串
- data = json.loads(data)
- contents += data["answer"]
- docs = data["docs"]
- return contents
-
-
-_PROMPT_TEMPLATE = """
-用户会提出一个需要你查询知识库的问题,你应该按照我提供的思想进行思考
-Question: ${{用户的问题}}
-这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能:
-
-{database_names}
-
-你的回答格式应该按照下面的内容,请注意,格式内的```text 等标记都必须输出,这是我用来提取答案的标记。
-```text
-${{知识库的名称}}
-```
-```output
-数据库查询的结果
-```
-答案: ${{答案}}
-
-现在,这是我的问题:
-问题: {question}
-
-"""
-PROMPT = PromptTemplate(
- input_variables=["question", "database_names"],
- template=_PROMPT_TEMPLATE,
-)
-
-
-class LLMKnowledgeChain(LLMChain):
- llm_chain: LLMChain
- llm: Optional[BaseLanguageModel] = None
- """[Deprecated] LLM wrapper to use."""
- prompt: BasePromptTemplate = PROMPT
- """[Deprecated] Prompt to use to translate to python if necessary."""
- database_names: Dict[str, str] = model_container.DATABASE
- input_key: str = "question" #: :meta private:
- output_key: str = "answer" #: :meta private:
-
- class Config:
- """Configuration for this pydantic object."""
-
- extra = Extra.forbid
- arbitrary_types_allowed = True
-
- @root_validator(pre=True)
- def raise_deprecation(cls, values: Dict) -> Dict:
- if "llm" in values:
- warnings.warn(
- "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
- "Please instantiate with llm_chain argument or using the from_llm "
- "class method."
- )
- if "llm_chain" not in values and values["llm"] is not None:
- prompt = values.get("prompt", PROMPT)
- values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
- return values
-
- @property
- def input_keys(self) -> List[str]:
- """Expect input key.
-
- :meta private:
- """
- return [self.input_key]
-
- @property
- def output_keys(self) -> List[str]:
- """Expect output key.
-
- :meta private:
- """
- return [self.output_key]
-
- def _evaluate_expression(self, dataset, query) -> str:
- try:
- output = asyncio.run(search_knowledge_base_iter(dataset, query))
- except Exception as e:
- output = "输入的信息有误或不存在知识库"
- return output
- return output
-
- def _process_llm_result(
- self,
- llm_output: str,
- llm_input: str,
- run_manager: CallbackManagerForChainRun
- ) -> Dict[str, str]:
-
- run_manager.on_text(llm_output, color="green", verbose=self.verbose)
-
- llm_output = llm_output.strip()
- text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
- if text_match:
- database = text_match.group(1).strip()
- output = self._evaluate_expression(database, llm_input)
- run_manager.on_text("\nAnswer: ", verbose=self.verbose)
- run_manager.on_text(output, color="yellow", verbose=self.verbose)
- answer = "Answer: " + output
- elif llm_output.startswith("Answer:"):
- answer = llm_output
- elif "Answer:" in llm_output:
- answer = "Answer: " + llm_output.split("Answer:")[-1]
- else:
- return {self.output_key: f"输入的格式不对: {llm_output}"}
- return {self.output_key: answer}
-
- async def _aprocess_llm_result(
- self,
- llm_output: str,
- run_manager: AsyncCallbackManagerForChainRun,
- ) -> Dict[str, str]:
- await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
- llm_output = llm_output.strip()
- text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
- if text_match:
- expression = text_match.group(1)
- output = self._evaluate_expression(expression)
- await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
- await run_manager.on_text(output, color="yellow", verbose=self.verbose)
- answer = "Answer: " + output
- elif llm_output.startswith("Answer:"):
- answer = llm_output
- elif "Answer:" in llm_output:
- answer = "Answer: " + llm_output.split("Answer:")[-1]
- else:
- raise ValueError(f"unknown format from LLM: {llm_output}")
- return {self.output_key: answer}
-
- def _call(
- self,
- inputs: Dict[str, str],
- run_manager: Optional[CallbackManagerForChainRun] = None,
- ) -> Dict[str, str]:
- _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
- _run_manager.on_text(inputs[self.input_key])
- data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
- llm_output = self.llm_chain.predict(
- database_names=data_formatted_str,
- question=inputs[self.input_key],
- stop=["```output"],
- callbacks=_run_manager.get_child(),
- )
- return self._process_llm_result(llm_output, inputs[self.input_key], _run_manager)
-
- async def _acall(
- self,
- inputs: Dict[str, str],
- run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
- ) -> Dict[str, str]:
- _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
- await _run_manager.on_text(inputs[self.input_key])
- data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
- llm_output = await self.llm_chain.apredict(
- database_names=data_formatted_str,
- question=inputs[self.input_key],
- stop=["```output"],
- callbacks=_run_manager.get_child(),
- )
- return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)
-
- @property
- def _chain_type(self) -> str:
- return "llm_knowledge_chain"
-
- @classmethod
- def from_llm(
- cls,
- llm: BaseLanguageModel,
- prompt: BasePromptTemplate = PROMPT,
- **kwargs: Any,
- ) -> LLMKnowledgeChain:
- llm_chain = LLMChain(llm=llm, prompt=prompt)
- return cls(llm_chain=llm_chain, **kwargs)
-
-
-def search_knowledgebase_once(query: str):
- model = model_container.MODEL
- llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT)
- ans = llm_knowledge.run(query)
- return ans
-
-
-class KnowledgeSearchInput(BaseModel):
- location: str = Field(description="The query to be searched")
-
-
-if __name__ == "__main__":
- result = search_knowledgebase_once("大数据的男女比例")
- print(result)
diff --git a/server/agent/tools/search_knowledgebase_simple.py b/server/agent/tools/search_knowledgebase_simple.py
deleted file mode 100644
index 65d7df3b..00000000
--- a/server/agent/tools/search_knowledgebase_simple.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from server.chat.knowledge_base_chat import knowledge_base_chat
-from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
-import json
-import asyncio
-from server.agent import model_container
-
-async def search_knowledge_base_iter(database: str, query: str) -> str:
- response = await knowledge_base_chat(query=query,
- knowledge_base_name=database,
- model_name=model_container.MODEL.model_name,
- temperature=0.01,
- history=[],
- top_k=VECTOR_SEARCH_TOP_K,
- max_tokens=MAX_TOKENS,
- prompt_name="knowledge_base_chat",
- score_threshold=SCORE_THRESHOLD,
- stream=False)
-
- contents = ""
- async for data in response.body_iterator: # 这里的data是一个json字符串
- data = json.loads(data)
- contents = data["answer"]
- docs = data["docs"]
- return contents
-
-def search_knowledgebase_simple(query: str):
- return asyncio.run(search_knowledge_base_iter(query))
-
-
-if __name__ == "__main__":
- result = search_knowledgebase_simple("大数据男女比例")
- print("答案:",result)
\ No newline at end of file
diff --git a/server/agent/tools_factory/__init__.py b/server/agent/tools_factory/__init__.py
new file mode 100644
index 00000000..461ce311
--- /dev/null
+++ b/server/agent/tools_factory/__init__.py
@@ -0,0 +1,8 @@
+from .search_local_knowledgebase import search_local_knowledgebase, SearchKnowledgeInput
+from .calculate import calculate, CalculatorInput
+from .weather_check import weather_check, WeatherInput
+from .shell import shell, ShellInput
+from .search_internet import search_internet, SearchInternetInput
+from .wolfram import wolfram, WolframInput
+from .search_youtube import search_youtube, YoutubeInput
+from .arxiv import arxiv, ArxivInput
diff --git a/server/agent/tools/arxiv.py b/server/agent/tools_factory/arxiv.py
similarity index 100%
rename from server/agent/tools/arxiv.py
rename to server/agent/tools_factory/arxiv.py
diff --git a/server/agent/tools_factory/calculate.py b/server/agent/tools_factory/calculate.py
new file mode 100644
index 00000000..a47d65ca
--- /dev/null
+++ b/server/agent/tools_factory/calculate.py
@@ -0,0 +1,23 @@
+from pydantic import BaseModel, Field
+
+def calculate(a: float, b: float, operator: str) -> float:
+ if operator == "+":
+ return a + b
+ elif operator == "-":
+ return a - b
+ elif operator == "*":
+ return a * b
+ elif operator == "/":
+ if b != 0:
+ return a / b
+ else:
+ return float('inf') # 防止除以零
+ elif operator == "^":
+ return a ** b
+ else:
+ raise ValueError("Unsupported operator")
+
+class CalculatorInput(BaseModel):
+ a: float = Field(description="first number")
+ b: float = Field(description="second number")
+ operator: str = Field(description="operator to use (e.g., +, -, *, /, ^)")
diff --git a/server/agent/tools_factory/search_internet.py b/server/agent/tools_factory/search_internet.py
new file mode 100644
index 00000000..68a5c6b1
--- /dev/null
+++ b/server/agent/tools_factory/search_internet.py
@@ -0,0 +1,99 @@
+from pydantic import BaseModel, Field
+from langchain.utilities.bing_search import BingSearchAPIWrapper
+from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
+from configs import TOOL_CONFIG
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from typing import List, Dict
+from langchain.docstore.document import Document
+from strsimpy.normalized_levenshtein import NormalizedLevenshtein
+from markdownify import markdownify
+
+
+def bing_search(text, config):
+ search = BingSearchAPIWrapper(bing_subscription_key=config["bing_key"],
+ bing_search_url=config["bing_search_url"])
+ return search.results(text, config["result_len"])
+
+
+def duckduckgo_search(text, config):
+ search = DuckDuckGoSearchAPIWrapper()
+ return search.results(text, config["result_len"])
+
+
+def metaphor_search(
+ text: str,
+ config: dict,
+) -> List[Dict]:
+ from metaphor_python import Metaphor
+ client = Metaphor(config["metaphor_api_key"])
+ search = client.search(text, num_results=config["result_len"], use_autoprompt=True)
+ contents = search.get_contents().contents
+ for x in contents:
+ x.extract = markdownify(x.extract)
+ if config["split_result"]:
+ docs = [Document(page_content=x.extract,
+ metadata={"link": x.url, "title": x.title})
+ for x in contents]
+ text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
+ chunk_size=config["chunk_size"],
+ chunk_overlap=config["chunk_overlap"])
+ splitted_docs = text_splitter.split_documents(docs)
+ if len(splitted_docs) > config["result_len"]:
+ normal = NormalizedLevenshtein()
+ for x in splitted_docs:
+ x.metadata["score"] = normal.similarity(text, x.page_content)
+ splitted_docs.sort(key=lambda x: x.metadata["score"], reverse=True)
+ splitted_docs = splitted_docs[:config["result_len"]]
+
+ docs = [{"snippet": x.page_content,
+ "link": x.metadata["link"],
+ "title": x.metadata["title"]}
+ for x in splitted_docs]
+ else:
+ docs = [{"snippet": x.extract,
+ "link": x.url,
+ "title": x.title}
+ for x in contents]
+
+ return docs
+
+
+SEARCH_ENGINES = {"bing": bing_search,
+ "duckduckgo": duckduckgo_search,
+ "metaphor": metaphor_search,
+ }
+
+
+def search_result2docs(search_results):
+ docs = []
+ for result in search_results:
+ doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "",
+ metadata={"source": result["link"] if "link" in result.keys() else "",
+ "filename": result["title"] if "title" in result.keys() else ""})
+ docs.append(doc)
+ return docs
+
+
+def search_engine(query: str,
+ config: dict):
+ search_engine_use = SEARCH_ENGINES[config["search_engine_name"]]
+ results = search_engine_use(text=query,
+ config=config["search_engine_config"][
+ config["search_engine_name"]])
+ docs = search_result2docs(results)
+ context = ""
+ docs = [
+ f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
+ for inum, doc in enumerate(docs)
+ ]
+
+ for doc in docs:
+ context += doc + "\n"
+ return context
+def search_internet(query: str):
+ tool_config = TOOL_CONFIG["search_internet"]
+ return search_engine(query=query, config=tool_config)
+
+class SearchInternetInput(BaseModel):
+ query: str = Field(description="Query for Internet search")
+
diff --git a/server/agent/tools_factory/search_local_knowledgebase.py b/server/agent/tools_factory/search_local_knowledgebase.py
new file mode 100644
index 00000000..0bc7803c
--- /dev/null
+++ b/server/agent/tools_factory/search_local_knowledgebase.py
@@ -0,0 +1,40 @@
+from urllib.parse import urlencode
+
+from pydantic import BaseModel, Field
+
+from server.knowledge_base.kb_doc_api import search_docs
+from configs import TOOL_CONFIG
+
+
+def search_knowledgebase(query: str, database: str, config: dict):
+ docs = search_docs(
+ query=query,
+ knowledge_base_name=database,
+ top_k=config["top_k"],
+ score_threshold=config["score_threshold"])
+ context = ""
+ source_documents = []
+ for inum, doc in enumerate(docs):
+ filename = doc.metadata.get("source")
+ parameters = urlencode({"knowledge_base_name": database, "file_name": filename})
+ url = f"download_doc?" + parameters
+ text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
+ source_documents.append(text)
+
+ if len(source_documents) == 0:
+        context = "没有找到相关文档,请更换关键词重试"
+ else:
+ for doc in source_documents:
+ context += doc + "\n"
+
+ return context
+
+
+class SearchKnowledgeInput(BaseModel):
+ database: str = Field(description="Database for Knowledge Search")
+ query: str = Field(description="Query for Knowledge Search")
+
+
+def search_local_knowledgebase(database: str, query: str):
+ tool_config = TOOL_CONFIG["search_local_knowledgebase"]
+ return search_knowledgebase(query=query, database=database, config=tool_config)
diff --git a/server/agent/tools/search_youtube.py b/server/agent/tools_factory/search_youtube.py
similarity index 80%
rename from server/agent/tools/search_youtube.py
rename to server/agent/tools_factory/search_youtube.py
index f02b3625..57049897 100644
--- a/server/agent/tools/search_youtube.py
+++ b/server/agent/tools_factory/search_youtube.py
@@ -6,4 +6,4 @@ def search_youtube(query: str):
return tool.run(tool_input=query)
class YoutubeInput(BaseModel):
- location: str = Field(description="Query for Videos search")
\ No newline at end of file
+    query: str = Field(description="Query for YouTube video search")
\ No newline at end of file
diff --git a/server/agent/tools/shell.py b/server/agent/tools_factory/shell.py
similarity index 72%
rename from server/agent/tools/shell.py
rename to server/agent/tools_factory/shell.py
index db074154..01046559 100644
--- a/server/agent/tools/shell.py
+++ b/server/agent/tools_factory/shell.py
@@ -6,4 +6,4 @@ def shell(query: str):
return tool.run(tool_input=query)
class ShellInput(BaseModel):
- query: str = Field(description="一个能在Linux命令行运行的Shell命令")
\ No newline at end of file
+ query: str = Field(description="The command to execute")
\ No newline at end of file
diff --git a/server/agent/tools_factory/tools_registry.py b/server/agent/tools_factory/tools_registry.py
new file mode 100644
index 00000000..61c8451f
--- /dev/null
+++ b/server/agent/tools_factory/tools_registry.py
@@ -0,0 +1,59 @@
+from langchain_core.tools import StructuredTool
+from server.agent.tools_factory import *
+from configs import KB_INFO
+
+template = "Use the local knowledge bases listed below to look up information:\n{KB_info}\nOnly use this tool for data stored in these local knowledge bases."
+KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
+template_knowledge = template.format(KB_info=KB_info_str)
+
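+# Registry of all structured tools; chat.py filters this list per request using the tool names in tool_config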
+all_tools = [
+ StructuredTool.from_function(
+ func=calculate,
+ name="calculate",
+ description="Useful for when you need to answer questions about simple calculations",
+ args_schema=CalculatorInput,
+ ),
+ StructuredTool.from_function(
+ func=arxiv,
+ name="arxiv",
+ description="A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.",
+ args_schema=ArxivInput,
+ ),
+ StructuredTool.from_function(
+ func=shell,
+ name="shell",
+ description="Use Shell to execute Linux commands",
+ args_schema=ShellInput,
+ ),
+ StructuredTool.from_function(
+ func=wolfram,
+ name="wolfram",
+ description="Useful for when you need to calculate difficult formulas",
+ args_schema=WolframInput,
+ ),
+ StructuredTool.from_function(
+ func=search_youtube,
+ name="search_youtube",
+        description="Use this tool to search YouTube videos",
+ args_schema=YoutubeInput,
+ ),
+ StructuredTool.from_function(
+ func=weather_check,
+ name="weather_check",
+ description="Use this tool to check the weather at a specific location",
+ args_schema=WeatherInput,
+ ),
+ StructuredTool.from_function(
+ func=search_internet,
+ name="search_internet",
+        description="Use this tool to search the internet with a search engine and get information",
+ args_schema=SearchInternetInput,
+ ),
+ StructuredTool.from_function(
+ func=search_local_knowledgebase,
+ name="search_local_knowledgebase",
+ description=template_knowledge,
+ args_schema=SearchKnowledgeInput,
+ ),
+]
+
diff --git a/server/agent/tools/weather_check.py b/server/agent/tools_factory/weather_check.py
similarity index 76%
rename from server/agent/tools/weather_check.py
rename to server/agent/tools_factory/weather_check.py
index 20f009a5..954d74ee 100644
--- a/server/agent/tools/weather_check.py
+++ b/server/agent/tools_factory/weather_check.py
@@ -1,10 +1,8 @@
"""
-更简单的单参数输入工具实现,用于查询现在天气的情况
+简单的单参数输入工具实现,用于查询现在天气的情况
"""
from pydantic import BaseModel, Field
import requests
 from configs.kb_config import SENIVERSE_API_KEY
 
def weather(location: str, api_key: str):
url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
@@ -21,9 +19,7 @@ def weather(location: str, api_key: str):
f"Failed to retrieve weather: {response.status_code}")
-def weathercheck(location: str):
- return weather(location, SENIVERSE_API_KEY)
-
-
+def weather_check(location: str):
+    # Read the Seniverse API key from configuration instead of hardcoding a secret in source
+    return weather(location, SENIVERSE_API_KEY)
class WeatherInput(BaseModel):
- location: str = Field(description="City name,include city and county")
+    location: str = Field(description="City name, including both city and county, e.g. '厦门'")
diff --git a/server/agent/tools/wolfram.py b/server/agent/tools_factory/wolfram.py
similarity index 84%
rename from server/agent/tools/wolfram.py
rename to server/agent/tools_factory/wolfram.py
index c322da18..6bdbe10e 100644
--- a/server/agent/tools/wolfram.py
+++ b/server/agent/tools_factory/wolfram.py
@@ -8,4 +8,4 @@ def wolfram(query: str):
return ans
class WolframInput(BaseModel):
- location: str = Field(description="需要运算的具体问题")
\ No newline at end of file
+ formula: str = Field(description="The formula to be calculated")
\ No newline at end of file
diff --git a/server/agent/tools_select.py b/server/agent/tools_select.py
deleted file mode 100644
index 237c20b6..00000000
--- a/server/agent/tools_select.py
+++ /dev/null
@@ -1,55 +0,0 @@
-from langchain.tools import Tool
-from server.agent.tools import *
-
-tools = [
- Tool.from_function(
- func=calculate,
- name="calculate",
- description="Useful for when you need to answer questions about simple calculations",
- args_schema=CalculatorInput,
- ),
- Tool.from_function(
- func=arxiv,
- name="arxiv",
- description="A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.",
- args_schema=ArxivInput,
- ),
- Tool.from_function(
- func=weathercheck,
- name="weather_check",
- description="",
- args_schema=WeatherInput,
- ),
- Tool.from_function(
- func=shell,
- name="shell",
- description="Use Shell to execute Linux commands",
- args_schema=ShellInput,
- ),
- Tool.from_function(
- func=search_knowledgebase_complex,
- name="search_knowledgebase_complex",
- description="Use Use this tool to search local knowledgebase and get information",
- args_schema=KnowledgeSearchInput,
- ),
- Tool.from_function(
- func=search_internet,
- name="search_internet",
- description="Use this tool to use bing search engine to search the internet",
- args_schema=SearchInternetInput,
- ),
- Tool.from_function(
- func=wolfram,
- name="Wolfram",
- description="Useful for when you need to calculate difficult formulas",
- args_schema=WolframInput,
- ),
- Tool.from_function(
- func=search_youtube,
- name="search_youtube",
- description="use this tools to search youtube videos",
- args_schema=YoutubeInput,
- ),
-]
-
-tool_names = [tool.name for tool in tools]
diff --git a/server/api.py b/server/api.py
index 6e4aa437..3c4d04df 100644
--- a/server/api.py
+++ b/server/api.py
@@ -13,13 +13,12 @@ from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat.chat import chat
-from server.chat.search_engine_chat import search_engine_chat
from server.chat.completion import completion
from server.chat.feedback import chat_feedback
from server.embeddings_api import embed_texts_endpoint
from server.llm_api import (list_running_models, list_config_models,
change_llm_model, stop_llm_model,
- get_model_config, list_search_engines)
+ get_model_config)
from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
get_server_configs, get_prompt_template)
from typing import List, Literal
@@ -63,11 +62,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
summary="与llm模型对话(通过LLMChain)",
)(chat)
- app.post("/chat/search_engine_chat",
- tags=["Chat"],
- summary="与搜索引擎对话",
- )(search_engine_chat)
-
app.post("/chat/feedback",
tags=["Chat"],
summary="返回llm模型对话评分",
@@ -110,16 +104,12 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
summary="获取服务器原始配置信息",
)(get_server_configs)
- app.post("/server/list_search_engines",
- tags=["Server State"],
- summary="获取服务器支持的搜索引擎",
- )(list_search_engines)
@app.post("/server/get_prompt_template",
tags=["Server State"],
summary="获取服务区配置的 prompt 模板")
def get_server_prompt_template(
- type: Literal["llm_chat", "knowledge_base_chat", "search_engine_chat", "agent_chat"]=Body("llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat,search_engine_chat,agent_chat"),
+ type: Literal["llm_chat", "knowledge_base_chat"]=Body("llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat"),
name: str = Body("default", description="模板名称"),
) -> str:
return get_prompt_template(type=type, name=name)
@@ -139,7 +129,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
def mount_knowledge_routes(app: FastAPI):
from server.chat.knowledge_base_chat import knowledge_base_chat
from server.chat.file_chat import upload_temp_docs, file_chat
- from server.chat.agent_chat import agent_chat
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
update_docs, download_doc, recreate_vector_store,
@@ -155,11 +144,6 @@ def mount_knowledge_routes(app: FastAPI):
summary="文件对话"
)(file_chat)
- app.post("/chat/agent_chat",
- tags=["Chat"],
- summary="与agent对话")(agent_chat)
-
- # Tag: Knowledge Base Management
app.get("/knowledge_base/list_knowledge_bases",
tags=["Knowledge Base Management"],
response_model=ListResponse,
diff --git a/server/agent/callbacks.py b/server/callback_handler/agent_callback_handler.py
similarity index 53%
rename from server/agent/callbacks.py
rename to server/callback_handler/agent_callback_handler.py
index 0935f9dc..dad07a79 100644
--- a/server/agent/callbacks.py
+++ b/server/callback_handler/agent_callback_handler.py
@@ -1,13 +1,11 @@
from __future__ import annotations
from uuid import UUID
-from langchain.callbacks import AsyncIteratorCallbackHandler
import json
-import asyncio
-from typing import Any, Dict, List, Optional
-
from langchain.schema import AgentFinish, AgentAction
-from langchain.schema.output import LLMResult
-
+import asyncio
+from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast, Optional
+from langchain_core.outputs import LLMResult
+from langchain.callbacks.base import AsyncCallbackHandler
def dumps(obj: Dict) -> str:
return json.dumps(obj, ensure_ascii=False)
@@ -23,7 +21,7 @@ class Status:
tool_finish: int = 7
-class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
+class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
def __init__(self):
super().__init__()
self.queue = asyncio.Queue()
@@ -31,40 +29,29 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.cur_tool = {}
self.out = True
- async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
- parent_run_id: UUID | None = None, tags: List[str] | None = None,
- metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
-
- # 对于截断不能自理的大模型,我来帮他截断
- stop_words = ["Observation:", "Thought","\"","(", "\n","\t"]
- for stop_word in stop_words:
- index = input_str.find(stop_word)
- if index != -1:
- input_str = input_str[:index]
- break
-
- self.cur_tool = {
- "tool_name": serialized["name"],
- "input_str": input_str,
- "output_str": "",
- "status": Status.agent_action,
- "run_id": run_id.hex,
- "llm_token": "",
- "final_answer": "",
- "error": "",
- }
- # print("\nInput Str:",self.cur_tool["input_str"])
- self.queue.put_nowait(dumps(self.cur_tool))
-
- async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
- tags: List[str] | None = None, **kwargs: Any) -> None:
- self.out = True ## 重置输出
- self.cur_tool.update(
- status=Status.tool_finish,
- output_str=output.replace("Answer:", ""),
- )
- self.queue.put_nowait(dumps(self.cur_tool))
+ async def on_tool_start(
+ self,
+ serialized: Dict[str, Any],
+ input_str: str,
+ *,
+ run_id: UUID,
+ parent_run_id: Optional[UUID] = None,
+ tags: Optional[List[str]] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
+ ) -> None:
+ print("on_tool_start")
+ async def on_tool_end(
+ self,
+ output: str,
+ *,
+ run_id: UUID,
+ parent_run_id: Optional[UUID] = None,
+ tags: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> None:
+ print("on_tool_end")
async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
self.cur_tool.update(
@@ -73,23 +60,6 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
)
self.queue.put_nowait(dumps(self.cur_tool))
- # async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
- # if "Action" in token: ## 减少重复输出
- # before_action = token.split("Action")[0]
- # self.cur_tool.update(
- # status=Status.running,
- # llm_token=before_action + "\n",
- # )
- # self.queue.put_nowait(dumps(self.cur_tool))
- #
- # self.out = False
- #
- # if token and self.out:
- # self.cur_tool.update(
- # status=Status.running,
- # llm_token=token,
- # )
- # self.queue.put_nowait(dumps(self.cur_tool))
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
special_tokens = ["Action", "<|observation|>"]
for stoken in special_tokens:
@@ -103,7 +73,7 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.out = False
break
- if token and self.out:
+ if token is not None and token != "" and self.out:
self.cur_tool.update(
status=Status.running,
llm_token=token,
@@ -116,16 +86,17 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
llm_token="",
)
self.queue.put_nowait(dumps(self.cur_tool))
+
async def on_chat_model_start(
- self,
- serialized: Dict[str, Any],
- messages: List[List],
- *,
- run_id: UUID,
- parent_run_id: Optional[UUID] = None,
- tags: Optional[List[str]] = None,
- metadata: Optional[Dict[str, Any]] = None,
- **kwargs: Any,
+ self,
+ serialized: Dict[str, Any],
+ messages: List[List],
+ *,
+ run_id: UUID,
+ parent_run_id: Optional[UUID] = None,
+ tags: Optional[List[str]] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
) -> None:
self.cur_tool.update(
status=Status.start,
@@ -136,8 +107,9 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
self.cur_tool.update(
status=Status.complete,
- llm_token="\n",
+ llm_token="",
)
+ self.out = True
self.queue.put_nowait(dumps(self.cur_tool))
async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
@@ -147,15 +119,44 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
)
self.queue.put_nowait(dumps(self.cur_tool))
+ async def on_agent_action(
+ self,
+ action: AgentAction,
+ *,
+ run_id: UUID,
+ parent_run_id: Optional[UUID] = None,
+ tags: Optional[List[str]] = None,
+ **kwargs: Any,
+ ) -> None:
+ self.cur_tool.update(
+ status=Status.agent_action,
+ tool_name=action.tool,
+ tool_input=action.tool_input,
+ )
+ self.queue.put_nowait(dumps(self.cur_tool))
async def on_agent_finish(
self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
- # 返回最终答案
self.cur_tool.update(
status=Status.agent_finish,
- final_answer=finish.return_values["output"],
+ agent_finish=finish.return_values["output"],
)
self.queue.put_nowait(dumps(self.cur_tool))
- self.cur_tool = {}
+
+ async def aiter(self) -> AsyncIterator[str]:
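+        # Yield queued status events until the queue is drained and done is set; cancel the losing wait to avoid leaking a pending task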
+ while not self.queue.empty() or not self.done.is_set():
+ done, other = await asyncio.wait(
+ [
+ asyncio.ensure_future(self.queue.get()),
+ asyncio.ensure_future(self.done.wait()),
+ ],
+ return_when=asyncio.FIRST_COMPLETED,
+ )
+ if other:
+ other.pop().cancel()
+ token_or_done = cast(Union[str, Literal[True]], done.pop().result())
+ if token_or_done is True:
+ break
+ yield token_or_done
diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py
deleted file mode 100644
index 41bf5bab..00000000
--- a/server/chat/agent_chat.py
+++ /dev/null
@@ -1,178 +0,0 @@
-import json
-import asyncio
-
-from fastapi import Body
-from sse_starlette.sse import EventSourceResponse
-from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
-
-from langchain.chains import LLMChain
-from langchain.memory import ConversationBufferWindowMemory
-from langchain.agents import LLMSingleActionAgent, AgentExecutor
-from typing import AsyncIterable, Optional, List
-
-from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
-from server.knowledge_base.kb_service.base import get_kb_details
-from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
-from server.agent.tools_select import tools, tool_names
-from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
-from server.chat.utils import History
-from server.agent import model_container
-from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
-
-
-async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
- history: List[History] = Body([],
- description="历史对话",
- examples=[[
- {"role": "user", "content": "请使用知识库工具查询今天北京天气"},
- {"role": "assistant",
- "content": "使用天气查询工具查询到今天北京多云,10-14摄氏度,东北风2级,易感冒"}]]
- ),
- stream: bool = Body(False, description="流式输出"),
- model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
- temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
- max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
- prompt_name: str = Body("default",
- description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
- ):
- history = [History.from_data(h) for h in history]
-
- async def agent_chat_iterator(
- query: str,
- history: Optional[List[History]],
- model_name: str = LLM_MODELS[0],
- prompt_name: str = prompt_name,
- ) -> AsyncIterable[str]:
- nonlocal max_tokens
- callback = CustomAsyncIteratorCallbackHandler()
- if isinstance(max_tokens, int) and max_tokens <= 0:
- max_tokens = None
-
- model = get_ChatOpenAI(
- model_name=model_name,
- temperature=temperature,
- max_tokens=max_tokens,
- callbacks=[callback],
- )
-
- kb_list = {x["kb_name"]: x for x in get_kb_details()}
- model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}
-
- if Agent_MODEL:
- model_agent = get_ChatOpenAI(
- model_name=Agent_MODEL,
- temperature=temperature,
- max_tokens=max_tokens,
- callbacks=[callback],
- )
- model_container.MODEL = model_agent
- else:
- model_container.MODEL = model
-
- prompt_template = get_prompt_template("agent_chat", prompt_name)
- prompt_template_agent = CustomPromptTemplate(
- template=prompt_template,
- tools=tools,
- input_variables=["input", "intermediate_steps", "history"]
- )
- output_parser = CustomOutputParser()
- llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)
- memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
- for message in history:
- if message.role == 'user':
- memory.chat_memory.add_user_message(message.content)
- else:
- memory.chat_memory.add_ai_message(message.content)
- if "chatglm3" in model_container.MODEL.model_name or "zhipu-api" in model_container.MODEL.model_name:
- agent_executor = initialize_glm3_agent(
- llm=model,
- tools=tools,
- callback_manager=None,
- prompt=prompt_template,
- input_variables=["input", "intermediate_steps", "history"],
- memory=memory,
- verbose=True,
- )
- else:
- agent = LLMSingleActionAgent(
- llm_chain=llm_chain,
- output_parser=output_parser,
- stop=["\nObservation:", "Observation"],
- allowed_tools=tool_names,
- )
- agent_executor = AgentExecutor.from_agent_and_tools(agent=agent,
- tools=tools,
- verbose=True,
- memory=memory,
- )
- while True:
- try:
- task = asyncio.create_task(wrap_done(
- agent_executor.acall(query, callbacks=[callback], include_run_info=True),
- callback.done))
- break
- except:
- pass
-
- if stream:
- async for chunk in callback.aiter():
- tools_use = []
- # Use server-sent-events to stream the response
- data = json.loads(chunk)
- if data["status"] == Status.start or data["status"] == Status.complete:
- continue
- elif data["status"] == Status.error:
- tools_use.append("\n```\n")
- tools_use.append("工具名称: " + data["tool_name"])
- tools_use.append("工具状态: " + "调用失败")
- tools_use.append("错误信息: " + data["error"])
- tools_use.append("重新开始尝试")
- tools_use.append("\n```\n")
- yield json.dumps({"tools": tools_use}, ensure_ascii=False)
- elif data["status"] == Status.tool_finish:
- tools_use.append("\n```\n")
- tools_use.append("工具名称: " + data["tool_name"])
- tools_use.append("工具状态: " + "调用成功")
- tools_use.append("工具输入: " + data["input_str"])
- tools_use.append("工具输出: " + data["output_str"])
- tools_use.append("\n```\n")
- yield json.dumps({"tools": tools_use}, ensure_ascii=False)
- elif data["status"] == Status.agent_finish:
- yield json.dumps({"final_answer": data["final_answer"]}, ensure_ascii=False)
- else:
- yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)
-
-
- else:
- answer = ""
- final_answer = ""
- async for chunk in callback.aiter():
- data = json.loads(chunk)
- if data["status"] == Status.start or data["status"] == Status.complete:
- continue
- if data["status"] == Status.error:
- answer += "\n```\n"
- answer += "工具名称: " + data["tool_name"] + "\n"
- answer += "工具状态: " + "调用失败" + "\n"
- answer += "错误信息: " + data["error"] + "\n"
- answer += "\n```\n"
- if data["status"] == Status.tool_finish:
- answer += "\n```\n"
- answer += "工具名称: " + data["tool_name"] + "\n"
- answer += "工具状态: " + "调用成功" + "\n"
- answer += "工具输入: " + data["input_str"] + "\n"
- answer += "工具输出: " + data["output_str"] + "\n"
- answer += "\n```\n"
- if data["status"] == Status.agent_finish:
- final_answer = data["final_answer"]
- else:
- answer += data["llm_token"]
-
- yield json.dumps({"answer": answer, "final_answer": final_answer}, ensure_ascii=False)
- await task
-
- return EventSourceResponse(agent_chat_iterator(query=query,
- history=history,
- model_name=model_name,
- prompt_name=prompt_name),
- )
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 5783829c..ccf0982b 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -1,20 +1,43 @@
from fastapi import Body
 from sse_starlette.sse import EventSourceResponse
-from configs import LLM_MODELS, TEMPERATURE
+from fastapi.responses import StreamingResponse
+from langchain.agents import initialize_agent, AgentType
+from langchain_core.callbacks import AsyncCallbackManager, BaseCallbackManager
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnableBranch
+from server.agent.agent_factory import initialize_glm3_agent
+from server.agent.tools_factory.tools_registry import all_tools
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain
-from langchain.callbacks import AsyncIteratorCallbackHandler
-from typing import AsyncIterable
+from typing import AsyncIterable, Dict
import asyncio
import json
from langchain.prompts.chat import ChatPromptTemplate
-from typing import List, Optional, Union
+from typing import List, Union
from server.chat.utils import History
from langchain.prompts import PromptTemplate
from server.utils import get_prompt_template
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
-from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
+from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
+
+
+def create_models_from_config(configs: dict = {}, callbacks: list = []):
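+    # Build a ChatOpenAI instance and prompt template for each configured model role (e.g. llm_model, preprocess_model, action_model)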
+ models = {}
+ prompts = {}
+ for model_type, model_configs in configs.items():
+ for model_name, params in model_configs.items():
+ callback = callbacks if params.get('callbacks', False) else None
+ model_instance = get_ChatOpenAI(
+ model_name=model_name,
+ temperature=params.get('temperature', 0.5),
+ max_tokens=params.get('max_tokens', 1000),
+ callbacks=callback
+ )
+ models[model_type] = model_instance
+ prompt_name = params.get('prompt_name', 'default')
+ prompt_template = get_prompt_template(type=model_type, name=prompt_name)
+ prompts[model_type] = prompt_template
+ return models, prompts
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
@@ -28,76 +51,106 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
{"role": "assistant", "content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
- model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
- temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
- max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
- # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
- prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
+ model_config: Dict = Body({}, description="LLM 模型配置。"),
+ tool_config: Dict = Body({}, description="工具配置"),
):
async def chat_iterator() -> AsyncIterable[str]:
- nonlocal history, max_tokens
- callback = AsyncIteratorCallbackHandler()
- callbacks = [callback]
+ nonlocal history
memory = None
+ message_id = None
+ chat_prompt = None
- # 负责保存llm response到message db
- message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
- conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
- chat_type="llm_chat",
- query=query)
- callbacks.append(conversation_callback)
+ callback = CustomAsyncIteratorCallbackHandler()
+ callbacks = [callback]
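+        # The custom handler streams JSON status events (LLM tokens, agent actions, final answer) through an asyncio queue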
+ models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
- if isinstance(max_tokens, int) and max_tokens <= 0:
- max_tokens = None
+ if conversation_id:
+ message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
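+        # Keep only the tools that are enabled in the per-request tool_config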
+ tools = [tool for tool in all_tools if tool.name in tool_config]
- model = get_ChatOpenAI(
- model_name=model_name,
- temperature=temperature,
- max_tokens=max_tokens,
- callbacks=callbacks,
- )
-
- if history: # 优先使用前端传入的历史消息
+ if history:
history = [History.from_data(h) for h in history]
- prompt_template = get_prompt_template("llm_chat", prompt_name)
- input_msg = History(role="user", content=prompt_template).to_msg_template(False)
+ input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
- elif conversation_id and history_len > 0: # 前端要求从数据库取历史消息
- # 使用memory 时必须 prompt 必须含有memory.memory_key 对应的变量
- prompt = get_prompt_template("llm_chat", "with_history")
- chat_prompt = PromptTemplate.from_template(prompt)
- # 根据conversation_id 获取message 列表进而拼凑 memory
- memory = ConversationBufferDBMemory(conversation_id=conversation_id,
- llm=model,
+        elif conversation_id and history_len > 0:
+            # Reuse the with_history prompt from the previous implementation (assuming it is still defined in prompt_config) so chat_prompt is not None when memory is loaded from the DB
+            prompt = get_prompt_template(type="llm_chat", name="with_history")
+            chat_prompt = PromptTemplate.from_template(prompt)
+            memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=models["llm_model"],
message_limit=history_len)
else:
- prompt_template = get_prompt_template("llm_chat", prompt_name)
- input_msg = History(role="user", content=prompt_template).to_msg_template(False)
+ input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages([input_msg])
- chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
-
- # Begin a task that runs in the background.
- task = asyncio.create_task(wrap_done(
- chain.acall({"input": query}),
- callback.done),
+ chain = LLMChain(prompt=chat_prompt, llm=models["llm_model"], memory=memory)
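+        # preprocess_model classifies the query; its output drives the branch below between the agent and the plain LLM chain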
+ classifier_chain = (
+ PromptTemplate.from_template(prompts["preprocess_model"])
+ | models["preprocess_model"]
+ | StrOutputParser()
)
-
- if stream:
- async for token in callback.aiter():
- # Use server-sent-events to stream the response
- yield json.dumps(
- {"text": token, "message_id": message_id},
- ensure_ascii=False)
+ if "chatglm3" in models["action_model"].model_name.lower():
+ agent_executor = initialize_glm3_agent(
+ llm=models["action_model"],
+ tools=tools,
+ prompt=prompts["action_model"],
+ input_variables=["input", "intermediate_steps", "history"],
+ memory=memory,
+ callback_manager=BaseCallbackManager(handlers=callbacks),
+ verbose=True,
+ )
else:
- answer = ""
- async for token in callback.aiter():
- answer += token
- yield json.dumps(
- {"text": answer, "message_id": message_id},
- ensure_ascii=False)
-
+ agent_executor = initialize_agent(
+ llm=models["action_model"],
+ tools=tools,
+ callbacks=callbacks,
+ agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
+ memory=memory,
+ verbose=True,
+ )
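+        # Route to the agent executor when the classifier outputs "1"; otherwise fall back to the plain LLM chain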
+ branch = RunnableBranch(
+ (lambda x: "1" in x["topic"].lower(), agent_executor),
+ chain
+ )
+ full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
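+        # Run the composed chain in the background; tokens and agent events stream back through the callback queue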
+ task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
+ if stream:
+ async for chunk in callback.aiter():
+ data = json.loads(chunk)
+ if data["status"] == Status.start:
+ continue
+ elif data["status"] == Status.agent_action:
+ tool_info = {
+ "tool_name": data["tool_name"],
+ "tool_input": data["tool_input"]
+ }
+ yield json.dumps({"agent_action": tool_info, "message_id": message_id}, ensure_ascii=False)
+ elif data["status"] == Status.agent_finish:
+ yield json.dumps({"agent_finish": data["agent_finish"], "message_id": message_id},
+ ensure_ascii=False)
+ else:
+ yield json.dumps({"text": data["llm_token"], "message_id": message_id}, ensure_ascii=False)
+ else:
+ text = ""
+ agent_finish = ""
+ tool_info = None
+ async for chunk in callback.aiter():
+ # Use server-sent-events to stream the response
+ data = json.loads(chunk)
+ if data["status"] == Status.agent_action:
+ tool_info = {
+ "tool_name": data["tool_name"],
+ "tool_input": data["tool_input"]
+ }
+                elif data["status"] == Status.agent_finish:
+                    agent_finish = data["agent_finish"]
+                else:
+                    text += data["llm_token"]
+ if tool_info:
+ yield json.dumps(
+ {"text": text, "agent_action": tool_info, "agent_finish": agent_finish, "message_id": message_id},
+ ensure_ascii=False)
+ else:
+ yield json.dumps(
+ {"text": text, "message_id": message_id},
+ ensure_ascii=False)
await task
return EventSourceResponse(chat_iterator())
diff --git a/server/chat/completion.py b/server/chat/completion.py
index acf17366..f93abce8 100644
--- a/server/chat/completion.py
+++ b/server/chat/completion.py
@@ -1,6 +1,6 @@
from fastapi import Body
-from sse_starlette.sse import EventSourceResponse
-from configs import LLM_MODELS, TEMPERATURE
+from fastapi.responses import StreamingResponse
+from configs import LLM_MODEL_CONFIG
from server.utils import wrap_done, get_OpenAI
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
@@ -14,8 +14,8 @@ from server.utils import get_prompt_template
async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
stream: bool = Body(False, description="流式输出"),
echo: bool = Body(False, description="除了输出之外,还回显输入"),
- model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
- temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+ model_name: str = Body(None, description="LLM 模型名称。"),
+ temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
prompt_name: str = Body("default",
@@ -24,7 +24,7 @@ async def completion(query: str = Body(..., description="用户输入", examples
#todo 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理
async def completion_iterator(query: str,
- model_name: str = LLM_MODELS[0],
+ model_name: str = None,
prompt_name: str = prompt_name,
echo: bool = echo,
) -> AsyncIterable[str]:
diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py
index a3bbdfc6..275371b5 100644
--- a/server/chat/file_chat.py
+++ b/server/chat/file_chat.py
@@ -1,7 +1,6 @@
from fastapi import Body, File, Form, UploadFile
-from sse_starlette.sse import EventSourceResponse
-from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE,
- CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
+from fastapi.responses import StreamingResponse
+from configs import (VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
from server.utils import (wrap_done, get_ChatOpenAI,
BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool)
from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
@@ -15,8 +14,6 @@ from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
from server.knowledge_base.utils import KnowledgeFile
import json
import os
-from pathlib import Path
-
def _parse_files_in_thread(
files: List[UploadFile],
@@ -102,8 +99,8 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
- model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
- temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+ model_name: str = Body(None, description="LLM 模型名称。"),
+ temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
deleted file mode 100644
index 60956b44..00000000
--- a/server/chat/knowledge_base_chat.py
+++ /dev/null
@@ -1,147 +0,0 @@
-from fastapi import Body, Request
-from sse_starlette.sse import EventSourceResponse
-from fastapi.concurrency import run_in_threadpool
-from configs import (LLM_MODELS,
- VECTOR_SEARCH_TOP_K,
- SCORE_THRESHOLD,
- TEMPERATURE,
- USE_RERANKER,
- RERANKER_MODEL,
- RERANKER_MAX_LENGTH,
- MODEL_PATH)
-from server.utils import wrap_done, get_ChatOpenAI
-from server.utils import BaseResponse, get_prompt_template
-from langchain.chains import LLMChain
-from langchain.callbacks import AsyncIteratorCallbackHandler
-from typing import AsyncIterable, List, Optional
-import asyncio
-from langchain.prompts.chat import ChatPromptTemplate
-from server.chat.utils import History
-from server.knowledge_base.kb_service.base import KBServiceFactory
-import json
-from urllib.parse import urlencode
-from server.knowledge_base.kb_doc_api import search_docs
-from server.reranker.reranker import LangchainReranker
-from server.utils import embedding_device
-async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
- knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
- top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
- score_threshold: float = Body(
- SCORE_THRESHOLD,
- description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右",
- ge=0,
- le=2
- ),
- history: List[History] = Body(
- [],
- description="历史对话",
- examples=[[
- {"role": "user",
- "content": "我们来玩成语接龙,我先来,生龙活虎"},
- {"role": "assistant",
- "content": "虎头虎脑"}]]
- ),
- stream: bool = Body(False, description="流式输出"),
- model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
- temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
- max_tokens: Optional[int] = Body(
- None,
- description="限制LLM生成Token数量,默认None代表模型最大值"
- ),
- prompt_name: str = Body(
- "default",
- description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
- ),
- request: Request = None,
- ):
- kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
- if kb is None:
- return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
-
- history = [History.from_data(h) for h in history]
-
- async def knowledge_base_chat_iterator(
- query: str,
- top_k: int,
- history: Optional[List[History]],
- model_name: str = model_name,
- prompt_name: str = prompt_name,
- ) -> AsyncIterable[str]:
- nonlocal max_tokens
- callback = AsyncIteratorCallbackHandler()
- if isinstance(max_tokens, int) and max_tokens <= 0:
- max_tokens = None
-
- model = get_ChatOpenAI(
- model_name=model_name,
- temperature=temperature,
- max_tokens=max_tokens,
- callbacks=[callback],
- )
- docs = await run_in_threadpool(search_docs,
- query=query,
- knowledge_base_name=knowledge_base_name,
- top_k=top_k,
- score_threshold=score_threshold)
-
- # 加入reranker
- if USE_RERANKER:
- reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
- print("-----------------model path------------------")
- print(reranker_model_path)
- reranker_model = LangchainReranker(top_n=top_k,
- device=embedding_device(),
- max_length=RERANKER_MAX_LENGTH,
- model_name_or_path=reranker_model_path
- )
- print(docs)
- docs = reranker_model.compress_documents(documents=docs,
- query=query)
- print("---------after rerank------------------")
- print(docs)
- context = "\n".join([doc.page_content for doc in docs])
-
- if len(docs) == 0: # 如果没有找到相关文档,使用empty模板
- prompt_template = get_prompt_template("knowledge_base_chat", "empty")
- else:
- prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
- input_msg = History(role="user", content=prompt_template).to_msg_template(False)
- chat_prompt = ChatPromptTemplate.from_messages(
- [i.to_msg_template() for i in history] + [input_msg])
-
- chain = LLMChain(prompt=chat_prompt, llm=model)
-
- # Begin a task that runs in the background.
- task = asyncio.create_task(wrap_done(
- chain.acall({"context": context, "question": query}),
- callback.done),
- )
-
- source_documents = []
- for inum, doc in enumerate(docs):
- filename = doc.metadata.get("source")
- parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
- base_url = request.base_url
- url = f"{base_url}knowledge_base/download_doc?" + parameters
- text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
- source_documents.append(text)
-
- if len(source_documents) == 0: # 没有找到相关文档
- source_documents.append(f"未找到相关文档,该回答为大模型自身能力解答!")
-
- if stream:
- async for token in callback.aiter():
- # Use server-sent-events to stream the response
- yield json.dumps({"answer": token}, ensure_ascii=False)
- yield json.dumps({"docs": source_documents}, ensure_ascii=False)
- else:
- answer = ""
- async for token in callback.aiter():
- answer += token
- yield json.dumps({"answer": answer,
- "docs": source_documents},
- ensure_ascii=False)
- await task
-
- return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name))
-
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
deleted file mode 100644
index 42bef3c2..00000000
--- a/server/chat/search_engine_chat.py
+++ /dev/null
@@ -1,208 +0,0 @@
-from langchain.utilities.bing_search import BingSearchAPIWrapper
-from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
-from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
- LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE, OVERLAP_SIZE)
-from langchain.chains import LLMChain
-from langchain.callbacks import AsyncIteratorCallbackHandler
-
-from langchain.prompts.chat import ChatPromptTemplate
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.docstore.document import Document
-from fastapi import Body
-from fastapi.concurrency import run_in_threadpool
-from sse_starlette import EventSourceResponse
-from server.utils import wrap_done, get_ChatOpenAI
-from server.utils import BaseResponse, get_prompt_template
-from server.chat.utils import History
-from typing import AsyncIterable
-import asyncio
-import json
-from typing import List, Optional, Dict
-from strsimpy.normalized_levenshtein import NormalizedLevenshtein
-from markdownify import markdownify
-
-
-def bing_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
- if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
- return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
- "title": "env info is not found",
- "link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
- search = BingSearchAPIWrapper(bing_subscription_key=BING_SUBSCRIPTION_KEY,
- bing_search_url=BING_SEARCH_URL)
- return search.results(text, result_len)
-
-
-def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
- search = DuckDuckGoSearchAPIWrapper()
- return search.results(text, result_len)
-
-
-def metaphor_search(
- text: str,
- result_len: int = SEARCH_ENGINE_TOP_K,
- split_result: bool = False,
- chunk_size: int = 500,
- chunk_overlap: int = OVERLAP_SIZE,
-) -> List[Dict]:
- from metaphor_python import Metaphor
-
- if not METAPHOR_API_KEY:
- return []
-
- client = Metaphor(METAPHOR_API_KEY)
- search = client.search(text, num_results=result_len, use_autoprompt=True)
- contents = search.get_contents().contents
- for x in contents:
- x.extract = markdownify(x.extract)
-
- # metaphor 返回的内容都是长文本,需要分词再检索
- if split_result:
- docs = [Document(page_content=x.extract,
- metadata={"link": x.url, "title": x.title})
- for x in contents]
- text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
- chunk_size=chunk_size,
- chunk_overlap=chunk_overlap)
- splitted_docs = text_splitter.split_documents(docs)
-
- # 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档
- if len(splitted_docs) > result_len:
- normal = NormalizedLevenshtein()
- for x in splitted_docs:
- x.metadata["score"] = normal.similarity(text, x.page_content)
- splitted_docs.sort(key=lambda x: x.metadata["score"], reverse=True)
- splitted_docs = splitted_docs[:result_len]
-
- docs = [{"snippet": x.page_content,
- "link": x.metadata["link"],
- "title": x.metadata["title"]}
- for x in splitted_docs]
- else:
- docs = [{"snippet": x.extract,
- "link": x.url,
- "title": x.title}
- for x in contents]
-
- return docs
-
-
-SEARCH_ENGINES = {"bing": bing_search,
- "duckduckgo": duckduckgo_search,
- "metaphor": metaphor_search,
- }
-
-
-def search_result2docs(search_results):
- docs = []
- for result in search_results:
- doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "",
- metadata={"source": result["link"] if "link" in result.keys() else "",
- "filename": result["title"] if "title" in result.keys() else ""})
- docs.append(doc)
- return docs
-
-
-async def lookup_search_engine(
- query: str,
- search_engine_name: str,
- top_k: int = SEARCH_ENGINE_TOP_K,
- split_result: bool = False,
-):
- search_engine = SEARCH_ENGINES[search_engine_name]
- results = await run_in_threadpool(search_engine, query, result_len=top_k, split_result=split_result)
- docs = search_result2docs(results)
- return docs
-
-async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
- search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
- top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
- history: List[History] = Body([],
- description="历史对话",
- examples=[[
- {"role": "user",
- "content": "我们来玩成语接龙,我先来,生龙活虎"},
- {"role": "assistant",
- "content": "虎头虎脑"}]]
- ),
- stream: bool = Body(False, description="流式输出"),
- model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
- temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
- max_tokens: Optional[int] = Body(None,
- description="限制LLM生成Token数量,默认None代表模型最大值"),
- prompt_name: str = Body("default",
- description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
- split_result: bool = Body(False,
- description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)")
- ):
- if search_engine_name not in SEARCH_ENGINES.keys():
- return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
-
- if search_engine_name == "bing" and not BING_SUBSCRIPTION_KEY:
- return BaseResponse(code=404, msg=f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`")
-
- history = [History.from_data(h) for h in history]
-
- async def search_engine_chat_iterator(query: str,
- search_engine_name: str,
- top_k: int,
- history: Optional[List[History]],
- model_name: str = LLM_MODELS[0],
- prompt_name: str = prompt_name,
- ) -> AsyncIterable[str]:
- nonlocal max_tokens
- callback = AsyncIteratorCallbackHandler()
- if isinstance(max_tokens, int) and max_tokens <= 0:
- max_tokens = None
-
- model = get_ChatOpenAI(
- model_name=model_name,
- temperature=temperature,
- max_tokens=max_tokens,
- callbacks=[callback],
- )
-
- docs = await lookup_search_engine(query, search_engine_name, top_k, split_result=split_result)
- context = "\n".join([doc.page_content for doc in docs])
-
- prompt_template = get_prompt_template("search_engine_chat", prompt_name)
- input_msg = History(role="user", content=prompt_template).to_msg_template(False)
- chat_prompt = ChatPromptTemplate.from_messages(
- [i.to_msg_template() for i in history] + [input_msg])
-
- chain = LLMChain(prompt=chat_prompt, llm=model)
-
- # Begin a task that runs in the background.
- task = asyncio.create_task(wrap_done(
- chain.acall({"context": context, "question": query}),
- callback.done),
- )
-
- source_documents = [
- f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
- for inum, doc in enumerate(docs)
- ]
-
- if len(source_documents) == 0: # 没有找到相关资料(不太可能)
- source_documents.append(f"""未找到相关文档,该回答为大模型自身能力解答!""")
-
- if stream:
- async for token in callback.aiter():
- # Use server-sent-events to stream the response
- yield json.dumps({"answer": token}, ensure_ascii=False)
- yield json.dumps({"docs": source_documents}, ensure_ascii=False)
- else:
- answer = ""
- async for token in callback.aiter():
- answer += token
- yield json.dumps({"answer": answer,
- "docs": source_documents},
- ensure_ascii=False)
- await task
-
- return EventSourceResponse(search_engine_chat_iterator(query=query,
- search_engine_name=search_engine_name,
- top_k=top_k,
- history=history,
- model_name=model_name,
- prompt_name=prompt_name),
- )
diff --git a/server/db/models/conversation_model.py b/server/db/models/conversation_model.py
index 9cc6d5b6..c9a53bbc 100644
--- a/server/db/models/conversation_model.py
+++ b/server/db/models/conversation_model.py
@@ -9,7 +9,6 @@ class ConversationModel(Base):
__tablename__ = 'conversation'
id = Column(String(32), primary_key=True, comment='对话框ID')
name = Column(String(50), comment='对话框名称')
- # chat/agent_chat等
chat_type = Column(String(50), comment='聊天类型')
create_time = Column(DateTime, default=func.now(), comment='创建时间')
diff --git a/server/db/models/message_model.py b/server/db/models/message_model.py
index 7b76df19..de0bc340 100644
--- a/server/db/models/message_model.py
+++ b/server/db/models/message_model.py
@@ -10,7 +10,6 @@ class MessageModel(Base):
__tablename__ = 'message'
id = Column(String(32), primary_key=True, comment='聊天记录ID')
conversation_id = Column(String(32), default=None, index=True, comment='对话框ID')
- # chat/agent_chat等
chat_type = Column(String(50), comment='聊天类型')
query = Column(String(4096), comment='用户问题')
response = Column(String(4096), comment='模型回答')
diff --git a/server/knowledge_base/kb_summary_api.py b/server/knowledge_base/kb_summary_api.py
index 6558f877..d0d49280 100644
--- a/server/knowledge_base/kb_summary_api.py
+++ b/server/knowledge_base/kb_summary_api.py
@@ -10,7 +10,6 @@ from typing import List, Optional
from server.knowledge_base.kb_summary.base import KBSummaryService
from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
from server.utils import wrap_done, get_ChatOpenAI, BaseResponse
-from configs import LLM_MODELS, TEMPERATURE
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
def recreate_summary_vector_store(
@@ -19,8 +18,8 @@ def recreate_summary_vector_store(
vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL),
file_description: str = Body(''),
- model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
- temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+ model_name: str = Body(None, description="LLM 模型名称。"),
+ temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
):
"""
@@ -100,8 +99,8 @@ def summary_file_to_vector_store(
vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL),
file_description: str = Body(''),
- model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
- temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+ model_name: str = Body(None, description="LLM 模型名称。"),
+ temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
):
"""
@@ -172,8 +171,8 @@ def summary_doc_ids_to_vector_store(
vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL),
file_description: str = Body(''),
- model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
- temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+ model_name: str = Body(None, description="LLM 模型名称。"),
+ temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
) -> BaseResponse:
"""
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index f2ddbfd0..61b3625c 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -7,7 +7,6 @@ from configs import (
logger,
log_verbose,
text_splitter_dict,
- LLM_MODELS,
TEXT_SPLITTER_NAME,
)
import importlib
@@ -187,10 +186,10 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
def make_text_splitter(
- splitter_name: str = TEXT_SPLITTER_NAME,
- chunk_size: int = CHUNK_SIZE,
- chunk_overlap: int = OVERLAP_SIZE,
- llm_model: str = LLM_MODELS[0],
+ splitter_name,
+ chunk_size,
+ chunk_overlap,
+ llm_model,
):
"""
根据参数获取特定的分词器
diff --git a/server/llm_api.py b/server/llm_api.py
index fbac4937..d642eeac 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -1,5 +1,5 @@
from fastapi import Body
-from configs import logger, log_verbose, LLM_MODELS, HTTPX_DEFAULT_TIMEOUT
+from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
get_httpx_client, get_model_worker_config)
from typing import List
@@ -62,7 +62,7 @@ def get_model_config(
def stop_llm_model(
- model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODELS[0]]),
+ model_name: str = Body(..., description="要停止的LLM模型名称", examples=[]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
@@ -86,8 +86,8 @@ def stop_llm_model(
def change_llm_model(
- model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODELS[0]]),
- new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODELS[0]]),
+ model_name: str = Body(..., description="当前运行模型"),
+ new_model_name: str = Body(..., description="要切换的新模型"),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
):
'''
@@ -108,9 +108,3 @@ def change_llm_model(
return BaseResponse(
code=500,
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
-
-
-def list_search_engines() -> BaseResponse:
- from server.chat.search_engine_chat import SEARCH_ENGINES
-
- return BaseResponse(data=list(SEARCH_ENGINES))
diff --git a/server/model_workers/base.py b/server/model_workers/base.py
index 234ab47a..b6e88d31 100644
--- a/server/model_workers/base.py
+++ b/server/model_workers/base.py
@@ -1,5 +1,5 @@
from fastchat.conversation import Conversation
-from configs import LOG_PATH, TEMPERATURE
+from configs import LOG_PATH
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.base_model_worker import BaseModelWorker
@@ -63,7 +63,7 @@ class ApiModelParams(ApiConfigParams):
deployment_name: Optional[str] = None # for azure
resource_name: Optional[str] = None # for azure
- temperature: float = TEMPERATURE
+ temperature: float = 0.9
max_tokens: Optional[int] = None
top_p: Optional[float] = 1.0
diff --git a/server/model_workers/qwen.py b/server/model_workers/qwen.py
index 2741b74d..f9ae6cb2 100644
--- a/server/model_workers/qwen.py
+++ b/server/model_workers/qwen.py
@@ -1,11 +1,6 @@
-import json
import sys
-
from fastchat.conversation import Conversation
-from configs import TEMPERATURE
-from http import HTTPStatus
from typing import List, Literal, Dict
-
from fastchat import conversation as conv
from server.model_workers.base import *
from server.model_workers.base import ApiEmbeddingsParams
diff --git a/server/utils.py b/server/utils.py
index 7fed5f8c..dac08a04 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -4,7 +4,7 @@ from typing import List
from fastapi import FastAPI
from pathlib import Path
import asyncio
-from configs import (LLM_MODELS, LLM_DEVICE, EMBEDDING_DEVICE,
+from configs import (LLM_MODEL_CONFIG, LLM_DEVICE, EMBEDDING_DEVICE,
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose,
FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
import os
@@ -402,8 +402,8 @@ def fschat_controller_address() -> str:
return f"http://{host}:{port}"
-def fschat_model_worker_address(model_name: str = LLM_MODELS[0]) -> str:
- if model := get_model_worker_config(model_name):
+def fschat_model_worker_address(model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model']))) -> str:
+ if model := get_model_worker_config(model_name): # TODO: depends fastchat
host = model["host"]
if host == "0.0.0.0":
host = "127.0.0.1"
@@ -443,7 +443,7 @@ def webui_address() -> str:
def get_prompt_template(type: str, name: str) -> Optional[str]:
'''
从prompt_config中加载模板内容
- type: "llm_chat","agent_chat","knowledge_base_chat","search_engine_chat"的其中一种,如果有新功能,应该进行加入。
+    type: "llm_chat","knowledge_base_chat"的其中一种,如果有新功能,应该进行加入。
'''
from configs import prompt_config
@@ -617,26 +617,6 @@ def get_server_configs() -> Dict:
'''
获取configs中的原始配置项,供前端使用
'''
- from configs.kb_config import (
- DEFAULT_KNOWLEDGE_BASE,
- DEFAULT_SEARCH_ENGINE,
- DEFAULT_VS_TYPE,
- CHUNK_SIZE,
- OVERLAP_SIZE,
- SCORE_THRESHOLD,
- VECTOR_SEARCH_TOP_K,
- SEARCH_ENGINE_TOP_K,
- ZH_TITLE_ENHANCE,
- text_splitter_dict,
- TEXT_SPLITTER_NAME,
- )
- from configs.model_config import (
- LLM_MODELS,
- HISTORY_LEN,
- TEMPERATURE,
- )
- from configs.prompt_config import PROMPT_TEMPLATES
-
_custom = {
"controller_address": fschat_controller_address(),
"openai_api_address": fschat_openai_api_address(),
diff --git a/startup.py b/startup.py
index 70b5eccc..60e4b071 100644
--- a/startup.py
+++ b/startup.py
@@ -8,6 +8,7 @@ from datetime import datetime
from pprint import pprint
from langchain_core._api import deprecated
+# 设置numexpr最大线程数,默认为CPU核心数
try:
import numexpr
@@ -21,7 +22,7 @@ from configs import (
LOG_PATH,
log_verbose,
logger,
- LLM_MODELS,
+ LLM_MODEL_CONFIG,
EMBEDDING_MODEL,
TEXT_SPLITTER_NAME,
FSCHAT_CONTROLLER,
@@ -32,13 +33,21 @@ from configs import (
HTTPX_DEFAULT_TIMEOUT,
)
from server.utils import (fschat_controller_address, fschat_model_worker_address,
- fschat_openai_api_address, get_httpx_client, get_model_worker_config,
+ fschat_openai_api_address, get_httpx_client,
+ get_model_worker_config,
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
from server.knowledge_base.migrate import create_tables
import argparse
from typing import List, Dict
from configs import VERSION
+all_model_names = set()
+for model_category in LLM_MODEL_CONFIG.values():
+    all_model_names.update(model_category.keys())
+
+all_model_names_list = list(all_model_names)
@deprecated(
since="0.3.0",
@@ -109,9 +118,9 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
import fastchat.serve.vllm_worker
from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
from vllm import AsyncLLMEngine
- from vllm.engine.arg_utils import AsyncEngineArgs
+ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
- args.tokenizer = args.model_path
+    args.tokenizer = args.model_path  # set the tokenizer here if it differs from model_path
args.tokenizer_mode = 'auto'
args.trust_remote_code = True
args.download_dir = None
@@ -130,7 +139,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
args.conv_template = None
args.limit_worker_concurrency = 5
args.no_register = False
- args.num_gpus = 1 # vllm worker的切分是tensor并行,这里填写显卡的数量
+    args.num_gpus = 4  # the vllm worker shards the model with tensor parallelism; set the number of GPUs here
args.engine_use_ray = False
args.disable_log_requests = False
@@ -365,7 +374,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
def run_model_worker(
- model_name: str = LLM_MODELS[0],
+ model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model'])),
controller_address: str = "",
log_level: str = "INFO",
q: mp.Queue = None,
@@ -502,7 +511,7 @@ def parse_args() -> argparse.ArgumentParser:
"--model-worker",
action="store_true",
help="run fastchat's model_worker server with specified model name. "
- "specify --model-name if not using default LLM_MODELS",
+ "specify --model-name if not using default llm models",
dest="model_worker",
)
parser.add_argument(
@@ -510,7 +519,7 @@ def parse_args() -> argparse.ArgumentParser:
"--model-name",
type=str,
nargs="+",
- default=LLM_MODELS,
+ default=all_model_names_list,
help="specify model name for model worker. "
"add addition names with space seperated to start multiple model workers.",
dest="model_name",
@@ -574,7 +583,7 @@ def dump_server_info(after_start=False, args=None):
print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
print("\n")
- models = LLM_MODELS
+ models = list(LLM_MODEL_CONFIG['llm_model'].keys())
if args and args.model_name:
models = args.model_name
@@ -769,17 +778,17 @@ async def start_main_server():
if p := processes.get("api"):
p.start()
p.name = f"{p.name} ({p.pid})"
- api_started.wait()
+        api_started.wait()  # wait for api.py to finish starting
if p := processes.get("webui"):
p.start()
p.name = f"{p.name} ({p.pid})"
- webui_started.wait()
+        webui_started.wait()  # wait for webui.py to finish starting
dump_server_info(after_start=True, args=args)
while True:
- cmd = queue.get()
+        cmd = queue.get()  # received a request to switch models
e = manager.Event()
if isinstance(cmd, list):
model_name, cmd, new_model_name = cmd
@@ -877,20 +886,5 @@ if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
-
loop.run_until_complete(start_main_server())
-# 服务启动后接口调用示例:
-# import openai
-# openai.api_key = "EMPTY" # Not support yet
-# openai.api_base = "http://localhost:8888/v1"
-
-# model = "chatglm3-6b"
-
-# # create a chat completion
-# completion = openai.ChatCompletion.create(
-# model=model,
-# messages=[{"role": "user", "content": "Hello! What is your name?"}]
-# )
-# # print the completion
-# print(completion.choices[0].message.content)
diff --git a/tests/api/test_server_state_api.py b/tests/api/test_server_state_api.py
index 59c0985f..2edfb496 100644
--- a/tests/api/test_server_state_api.py
+++ b/tests/api/test_server_state_api.py
@@ -29,15 +29,7 @@ def test_server_configs():
assert len(configs) > 0
-def test_list_search_engines():
- engines = api.list_search_engines()
- pprint(engines)
-
- assert isinstance(engines, list)
- assert len(engines) > 0
-
-
-@pytest.mark.parametrize("type", ["llm_chat", "agent_chat"])
+@pytest.mark.parametrize("type", ["llm_chat"])
def test_get_prompt_template(type):
print(f"prompt template for: {type}")
template = api.get_prompt_template(type=type)
diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py
index bf7b0571..daacc560 100644
--- a/tests/api/test_stream_chat_api.py
+++ b/tests/api/test_stream_chat_api.py
@@ -85,29 +85,3 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"):
assert "docs" in data and len(data["docs"]) > 0
assert response.status_code == 200
-
-def test_search_engine_chat(api="/chat/search_engine_chat"):
- global data
-
- data["query"] = "室温超导最新进展是什么样?"
-
- url = f"{api_base_url}{api}"
- for se in ["bing", "duckduckgo"]:
- data["search_engine_name"] = se
- dump_input(data, api + f" by {se}")
- response = requests.post(url, json=data, stream=True)
- if se == "bing" and not BING_SUBSCRIPTION_KEY:
- data = response.json()
- assert data["code"] == 404
- assert data["msg"] == f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`"
-
- print("\n")
- print("=" * 30 + api + f" by {se} output" + "="*30)
- for line in response.iter_content(None, decode_unicode=True):
- data = json.loads(line[6:])
- if "answer" in data:
- print(data["answer"], end="", flush=True)
- assert "docs" in data and len(data["docs"]) > 0
- pprint(data["docs"])
- assert response.status_code == 200
-
diff --git a/tests/test_online_api.py b/tests/test_online_api.py
index b33d1344..372fad4c 100644
--- a/tests/test_online_api.py
+++ b/tests/test_online_api.py
@@ -56,15 +56,4 @@ def test_embeddings(worker):
assert isinstance(embeddings, list) and len(embeddings) > 0
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
assert isinstance(embeddings[0][0], float)
- print("向量长度:", len(embeddings[0]))
-
-
-# @pytest.mark.parametrize("worker", workers)
-# def test_completion(worker):
-# params = ApiCompletionParams(prompt="五十六个民族")
-
-# print(f"\completion with {worker} \n")
-
-# worker_class = get_model_worker_config(worker)["worker_class"]
-# resp = worker_class().do_completion(params)
-# pprint(resp)
+ print("向量长度:", len(embeddings[0]))
\ No newline at end of file
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index bf4a6e69..be4934c7 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -6,8 +6,7 @@ from datetime import datetime
import os
import re
import time
-from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES, LLM_MODELS,
- DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
+from configs import (TOOL_CONFIG, LLM_MODEL_CONFIG)
from server.knowledge_base.utils import LOADER_DICT
import uuid
from typing import List, Dict
@@ -55,6 +54,7 @@ def parse_command(text: str, modal: Modal) -> bool:
/new {session_name}。如果未提供名称,默认为“会话X”
/del {session_name}。如果未提供名称,在会话数量>1的情况下,删除当前会话。
/clear {session_name}。如果未提供名称,默认清除当前会话
+        /stop {session_name}. If no name is given, stop the current conversation by default.
/help。查看命令帮助
返回值:输入的是命令返回True,否则返回False
'''
@@ -117,36 +117,38 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
st.write("\n\n".join(cmds))
with st.sidebar:
- # 多会话
conv_names = list(st.session_state["conversation_ids"].keys())
index = 0
+
+ tools = list(TOOL_CONFIG.keys())
+ selected_tool_configs = {}
+
+ with st.expander("工具栏"):
+ for tool in tools:
+ is_selected = st.checkbox(tool, value=TOOL_CONFIG[tool]["use"], key=tool)
+ if is_selected:
+ selected_tool_configs[tool] = TOOL_CONFIG[tool]
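
The sidebar checkboxes above are driven by TOOL_CONFIG: each entry is read as having a boolean "use" flag (the default checkbox state) plus tool-specific settings that are forwarded untouched when the tool is selected. A minimal sketch under that assumption (tool names and extra keys are illustrative; only the "use" flag is taken from this patch):

```python
# Hypothetical TOOL_CONFIG shape, inferred from how the sidebar reads it.
TOOL_CONFIG = {
    "search_internet": {"use": False, "top_k": 5},
    "calculate": {"use": True},
}

# What the loop above produces when the user keeps the default checkbox states.
selected_tool_configs = {
    name: cfg for name, cfg in TOOL_CONFIG.items() if cfg["use"]
}
```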
+
if st.session_state.get("cur_conv_name") in conv_names:
index = conv_names.index(st.session_state.get("cur_conv_name"))
- conversation_name = st.selectbox("当前会话:", conv_names, index=index)
+ conversation_name = st.selectbox("当前会话", conv_names, index=index)
chat_box.use_chat_name(conversation_name)
conversation_id = st.session_state["conversation_ids"][conversation_name]
- def on_mode_change():
- mode = st.session_state.dialogue_mode
- text = f"已切换到 {mode} 模式。"
- if mode == "知识库问答":
- cur_kb = st.session_state.get("selected_kb")
- if cur_kb:
- text = f"{text} 当前知识库: `{cur_kb}`。"
- st.toast(text)
+ # def on_mode_change():
+ # mode = st.session_state.dialogue_mode
+ # text = f"已切换到 {mode} 模式。"
+ # st.toast(text)
- dialogue_modes = ["LLM 对话",
- "知识库问答",
- "文件对话",
- "搜索引擎问答",
- "自定义Agent问答",
- ]
- dialogue_mode = st.selectbox("请选择对话模式:",
- dialogue_modes,
- index=0,
- on_change=on_mode_change,
- key="dialogue_mode",
- )
+ # dialogue_modes = ["智能对话",
+ # "文件对话",
+ # ]
+ # dialogue_mode = st.selectbox("请选择对话模式:",
+ # dialogue_modes,
+ # index=0,
+ # on_change=on_mode_change,
+ # key="dialogue_mode",
+ # )
def on_llm_change():
if llm_model:
@@ -164,7 +166,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
available_models = []
config_models = api.list_config_models()
if not is_lite:
- for k, v in config_models.get("local", {}).items():
+            for k, v in config_models.get("local", {}).items():  # list models configured with a valid local path
if (v.get("model_path_exists")
and k not in running_models):
available_models.append(k)
@@ -177,103 +179,47 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
index = llm_models.index(cur_llm_model)
else:
index = 0
- llm_model = st.selectbox("选择LLM模型:",
+ llm_model = st.selectbox("选择LLM模型",
llm_models,
index,
format_func=llm_model_format_func,
on_change=on_llm_change,
key="llm_model",
)
- if (st.session_state.get("prev_llm_model") != llm_model
- and not is_lite
- and not llm_model in config_models.get("online", {})
- and not llm_model in config_models.get("langchain", {})
- and llm_model not in running_models):
- with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
- prev_model = st.session_state.get("prev_llm_model")
- r = api.change_llm_model(prev_model, llm_model)
- if msg := check_error_msg(r):
- st.error(msg)
- elif msg := check_success_msg(r):
- st.success(msg)
- st.session_state["prev_llm_model"] = llm_model
- index_prompt = {
- "LLM 对话": "llm_chat",
- "自定义Agent问答": "agent_chat",
- "搜索引擎问答": "search_engine_chat",
- "知识库问答": "knowledge_base_chat",
- "文件对话": "knowledge_base_chat",
- }
- prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys())
- prompt_template_name = prompt_templates_kb_list[0]
- if "prompt_template_select" not in st.session_state:
- st.session_state.prompt_template_select = prompt_templates_kb_list[0]
- def prompt_change():
- text = f"已切换为 {prompt_template_name} 模板。"
- st.toast(text)
+        # Settings passed to the backend
+ model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
- prompt_template_select = st.selectbox(
- "请选择Prompt模板:",
- prompt_templates_kb_list,
- index=0,
- on_change=prompt_change,
- key="prompt_template_select",
- )
- prompt_template_name = st.session_state.prompt_template_select
- temperature = st.slider("Temperature:", 0.0, 2.0, TEMPERATURE, 0.05)
- history_len = st.number_input("历史对话轮数:", 0, 20, HISTORY_LEN)
+ for key in LLM_MODEL_CONFIG:
+ if key == 'llm_model':
+ continue
+ if LLM_MODEL_CONFIG[key]:
+ first_key = next(iter(LLM_MODEL_CONFIG[key]))
+ model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
- def on_kb_change():
- st.toast(f"已加载知识库: {st.session_state.selected_kb}")
+ if llm_model is not None:
+ model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]
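
Taken together, the loop and the assignment above build a model_config with at most one entry per category: the first configured model of every non-LLM category plus the model picked in the selector. A sketch of the resulting payload for an assumed configuration (names and values are illustrative):

```python
# Illustrative result only; the actual keys and settings come from LLM_MODEL_CONFIG.
model_config = {
    "llm_model": {
        "chatglm3-6b": {"temperature": 0.8, "history_len": 10},  # the model chosen in the sidebar
    },
    "action_model": {},  # a category with no configured models stays empty
}
```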
- if dialogue_mode == "知识库问答":
- with st.expander("知识库配置", True):
- kb_list = api.list_knowledge_bases()
- index = 0
- if DEFAULT_KNOWLEDGE_BASE in kb_list:
- index = kb_list.index(DEFAULT_KNOWLEDGE_BASE)
- selected_kb = st.selectbox(
- "请选择知识库:",
- kb_list,
- index=index,
- on_change=on_kb_change,
- key="selected_kb",
- )
- kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
-
- ## Bge 模型会超过1
- score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
- elif dialogue_mode == "文件对话":
- with st.expander("文件对话配置", True):
- files = st.file_uploader("上传知识文件:",
- [i for ls in LOADER_DICT.values() for i in ls],
- accept_multiple_files=True,
- )
- kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
-
- ## Bge 模型会超过1
- score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
- if st.button("开始上传", disabled=len(files) == 0):
- st.session_state["file_chat_id"] = upload_temp_docs(files, api)
- elif dialogue_mode == "搜索引擎问答":
- search_engine_list = api.list_search_engines()
- if DEFAULT_SEARCH_ENGINE in search_engine_list:
- index = search_engine_list.index(DEFAULT_SEARCH_ENGINE)
- else:
- index = search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0
- with st.expander("搜索引擎配置", True):
- search_engine = st.selectbox(
- label="请选择搜索引擎",
- options=search_engine_list,
- index=index,
- )
- se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)
+ print(model_config)
+ files = st.file_uploader("上传附件",
+ type=[i for ls in LOADER_DICT.values() for i in ls],
+ accept_multiple_files=True)
+ # if dialogue_mode == "文件对话":
+ # with st.expander("文件对话配置", True):
+ # files = st.file_uploader("上传知识文件:",
+ # [i for ls in LOADER_DICT.values() for i in ls],
+ # accept_multiple_files=True,
+ # )
+ # kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
+ # score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
+ # if st.button("开始上传", disabled=len(files) == 0):
+ # st.session_state["file_chat_id"] = upload_temp_docs(files, api)
# Display chat messages from history on app rerun
- chat_box.output_messages()
+
+ chat_box.output_messages()
chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "
def on_feedback(
@@ -297,140 +243,76 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
if parse_command(text=prompt, modal=modal): # 用户输入自定义命令
st.rerun()
else:
- history = get_messages_history(history_len)
+ history = get_messages_history(
+ model_config["llm_model"][next(iter(model_config["llm_model"]))]["history_len"])
chat_box.user_say(prompt)
- if dialogue_mode == "LLM 对话":
- chat_box.ai_say("正在思考...")
- text = ""
- message_id = ""
- r = api.chat_chat(prompt,
- history=history,
- conversation_id=conversation_id,
- model=llm_model,
- prompt_name=prompt_template_name,
- temperature=temperature)
- for t in r:
- if error_msg := check_error_msg(t): # check whether error occured
- st.error(error_msg)
- break
- text += t.get("text", "")
- chat_box.update_msg(text)
- message_id = t.get("message_id", "")
+ chat_box.ai_say("正在思考...")
+ text = ""
+ message_id = ""
+ element_index = 0
+ for d in api.chat_chat(query=prompt,
+ history=history,
+ model_config=model_config,
+ conversation_id=conversation_id,
+ tool_config=selected_tool_configs,
+ ):
+ try:
+ d = json.loads(d)
+ except:
+ pass
+ message_id = d.get("message_id", "")
metadata = {
"message_id": message_id,
}
- chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标
- chat_box.show_feedback(**feedback_kwargs,
- key=message_id,
- on_submit=on_feedback,
- kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
-
- elif dialogue_mode == "自定义Agent问答":
- if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
- chat_box.ai_say([
- f"正在思考... \n\n 该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!\n\n\n",
- Markdown("...", in_expander=True, title="思考过程", state="complete"),
-
- ])
- else:
- chat_box.ai_say([
- f"正在思考...",
- Markdown("...", in_expander=True, title="思考过程", state="complete"),
-
- ])
- text = ""
- ans = ""
- for d in api.agent_chat(prompt,
- history=history,
- model=llm_model,
- prompt_name=prompt_template_name,
- temperature=temperature,
- ):
- try:
- d = json.loads(d)
- except:
- pass
- if error_msg := check_error_msg(d): # check whether error occured
- st.error(error_msg)
- if chunk := d.get("answer"):
- text += chunk
- chat_box.update_msg(text, element_index=1)
- if chunk := d.get("final_answer"):
- ans += chunk
- chat_box.update_msg(ans, element_index=0)
- if chunk := d.get("tools"):
- text += "\n\n".join(d.get("tools", []))
- chat_box.update_msg(text, element_index=1)
- chat_box.update_msg(ans, element_index=0, streaming=False)
- chat_box.update_msg(text, element_index=1, streaming=False)
- elif dialogue_mode == "知识库问答":
- chat_box.ai_say([
- f"正在查询知识库 `{selected_kb}` ...",
- Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
- ])
- text = ""
- for d in api.knowledge_base_chat(prompt,
- knowledge_base_name=selected_kb,
- top_k=kb_top_k,
- score_threshold=score_threshold,
- history=history,
- model=llm_model,
- prompt_name=prompt_template_name,
- temperature=temperature):
- if error_msg := check_error_msg(d): # check whether error occured
- st.error(error_msg)
- elif chunk := d.get("answer"):
- text += chunk
- chat_box.update_msg(text, element_index=0)
- chat_box.update_msg(text, element_index=0, streaming=False)
- chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
- elif dialogue_mode == "文件对话":
- if st.session_state["file_chat_id"] is None:
- st.error("请先上传文件再进行对话")
- st.stop()
- chat_box.ai_say([
- f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
- Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
- ])
- text = ""
- for d in api.file_chat(prompt,
- knowledge_id=st.session_state["file_chat_id"],
- top_k=kb_top_k,
- score_threshold=score_threshold,
- history=history,
- model=llm_model,
- prompt_name=prompt_template_name,
- temperature=temperature):
- if error_msg := check_error_msg(d): # check whether error occured
- st.error(error_msg)
- elif chunk := d.get("answer"):
- text += chunk
- chat_box.update_msg(text, element_index=0)
- chat_box.update_msg(text, element_index=0, streaming=False)
- chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
- elif dialogue_mode == "搜索引擎问答":
- chat_box.ai_say([
- f"正在执行 `{search_engine}` 搜索...",
- Markdown("...", in_expander=True, title="网络搜索结果", state="complete"),
- ])
- text = ""
- for d in api.search_engine_chat(prompt,
- search_engine_name=search_engine,
- top_k=se_top_k,
- history=history,
- model=llm_model,
- prompt_name=prompt_template_name,
- temperature=temperature,
- split_result=se_top_k > 1):
- if error_msg := check_error_msg(d): # check whether error occured
- st.error(error_msg)
- elif chunk := d.get("answer"):
- text += chunk
- chat_box.update_msg(text, element_index=0)
- chat_box.update_msg(text, element_index=0, streaming=False)
- chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
+ if error_msg := check_error_msg(d):
+ st.error(error_msg)
+ if chunk := d.get("agent_action"):
+ chat_box.insert_msg(Markdown("...", in_expander=True, title="Tools", state="complete"))
+ element_index = 1
+ formatted_data = {
+ "action": chunk["tool_name"],
+ "action_input": chunk["tool_input"]
+ }
+ formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
+ text += f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"
+ chat_box.update_msg(text, element_index=element_index, metadata=metadata)
+ if chunk := d.get("text"):
+ text += chunk
+ chat_box.update_msg(text, element_index=element_index, metadata=metadata)
+ if chunk := d.get("agent_finish"):
+ element_index = 0
+ text = chunk
+ chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
+ chat_box.show_feedback(**feedback_kwargs,
+ key=message_id,
+ on_submit=on_feedback,
+ kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
+ # elif dialogue_mode == "文件对话":
+ # if st.session_state["file_chat_id"] is None:
+ # st.error("请先上传文件再进行对话")
+ # st.stop()
+ # chat_box.ai_say([
+ # f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
+ # Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
+ # ])
+ # text = ""
+ # for d in api.file_chat(prompt,
+ # knowledge_id=st.session_state["file_chat_id"],
+ # top_k=kb_top_k,
+ # score_threshold=score_threshold,
+ # history=history,
+ # model=llm_model,
+ # prompt_name=prompt_template_name,
+ # temperature=temperature):
+ # if error_msg := check_error_msg(d):
+ # st.error(error_msg)
+ # elif chunk := d.get("answer"):
+ # text += chunk
+ # chat_box.update_msg(text, element_index=0)
+ # chat_box.update_msg(text, element_index=0, streaming=False)
+ # chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
if st.session_state.get("need_rerun"):
st.session_state["need_rerun"] = False
st.rerun()
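
The streaming loop above expects each chunk from api.chat_chat to carry at most one of three keys: "text" for incremental tokens, "agent_action" when a tool is invoked, and "agent_finish" with the final answer. A hedged sketch of the event shapes the handler is written against (field names are inferred from this loop, values are illustrative):

```python
# Events as consumed by the dialogue loop; "tool_name"/"tool_input" match the keys read above.
events = [
    {"message_id": "example-id", "text": "Let me check the weather"},
    {"message_id": "example-id", "agent_action": {"tool_name": "weather_check",
                                                  "tool_input": {"city": "Beijing"}}},
    {"message_id": "example-id", "text": "The forecast says..."},
    {"message_id": "example-id", "agent_finish": "It should be sunny in Beijing today."},
]
```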
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 818748d6..834873d4 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -3,18 +3,15 @@
from typing import *
from pathlib import Path
-# 此处导入的配置为发起请求(如WEBUI)机器上的配置,主要用于为前端设置默认值。分布式部署时可以与服务器上的不同
from configs import (
EMBEDDING_MODEL,
DEFAULT_VS_TYPE,
- LLM_MODELS,
- TEMPERATURE,
+ LLM_MODEL_CONFIG,
SCORE_THRESHOLD,
CHUNK_SIZE,
OVERLAP_SIZE,
ZH_TITLE_ENHANCE,
VECTOR_SEARCH_TOP_K,
- SEARCH_ENGINE_TOP_K,
HTTPX_DEFAULT_TIMEOUT,
logger, log_verbose,
)
@@ -26,7 +23,6 @@ from io import BytesIO
from server.utils import set_httpx_config, api_address, get_httpx_client
from pprint import pprint
-from langchain_core._api import deprecated
set_httpx_config()
@@ -247,10 +243,6 @@ class ApiRequest:
response = self.post("/server/configs", **kwargs)
return self._get_response_value(response, as_json=True)
- def list_search_engines(self, **kwargs) -> List:
- response = self.post("/server/list_search_engines", **kwargs)
- return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"])
-
def get_prompt_template(
self,
type: str = "llm_chat",
@@ -272,10 +264,8 @@ class ApiRequest:
history_len: int = -1,
history: List[Dict] = [],
stream: bool = True,
- model: str = LLM_MODELS[0],
- temperature: float = TEMPERATURE,
- max_tokens: int = None,
- prompt_name: str = "default",
+ model_config: Dict = None,
+ tool_config: Dict = None,
**kwargs,
):
'''
@@ -287,10 +277,8 @@ class ApiRequest:
"history_len": history_len,
"history": history,
"stream": stream,
- "model_name": model,
- "temperature": temperature,
- "max_tokens": max_tokens,
- "prompt_name": prompt_name,
+ "model_config": model_config,
+ "tool_config": tool_config,
}
# print(f"received input message:")
@@ -299,78 +287,6 @@ class ApiRequest:
response = self.post("/chat/chat", json=data, stream=True, **kwargs)
return self._httpx_stream2generator(response, as_json=True)
- @deprecated(
- since="0.3.0",
- message="自定义Agent问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
- removal="0.3.0")
- def agent_chat(
- self,
- query: str,
- history: List[Dict] = [],
- stream: bool = True,
- model: str = LLM_MODELS[0],
- temperature: float = TEMPERATURE,
- max_tokens: int = None,
- prompt_name: str = "default",
- ):
- '''
- 对应api.py/chat/agent_chat 接口
- '''
- data = {
- "query": query,
- "history": history,
- "stream": stream,
- "model_name": model,
- "temperature": temperature,
- "max_tokens": max_tokens,
- "prompt_name": prompt_name,
- }
-
- # print(f"received input message:")
- # pprint(data)
-
- response = self.post("/chat/agent_chat", json=data, stream=True)
- return self._httpx_stream2generator(response, as_json=True)
-
- def knowledge_base_chat(
- self,
- query: str,
- knowledge_base_name: str,
- top_k: int = VECTOR_SEARCH_TOP_K,
- score_threshold: float = SCORE_THRESHOLD,
- history: List[Dict] = [],
- stream: bool = True,
- model: str = LLM_MODELS[0],
- temperature: float = TEMPERATURE,
- max_tokens: int = None,
- prompt_name: str = "default",
- ):
- '''
- 对应api.py/chat/knowledge_base_chat接口
- '''
- data = {
- "query": query,
- "knowledge_base_name": knowledge_base_name,
- "top_k": top_k,
- "score_threshold": score_threshold,
- "history": history,
- "stream": stream,
- "model_name": model,
- "temperature": temperature,
- "max_tokens": max_tokens,
- "prompt_name": prompt_name,
- }
-
- # print(f"received input message:")
- # pprint(data)
-
- response = self.post(
- "/chat/knowledge_base_chat",
- json=data,
- stream=True,
- )
- return self._httpx_stream2generator(response, as_json=True)
-
def upload_temp_docs(
self,
files: List[Union[str, Path, bytes]],
@@ -416,8 +332,8 @@ class ApiRequest:
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
- model: str = LLM_MODELS[0],
- temperature: float = TEMPERATURE,
+ model: str = None,
+ temperature: float = 0.9,
max_tokens: int = None,
prompt_name: str = "default",
):
@@ -444,50 +360,6 @@ class ApiRequest:
)
return self._httpx_stream2generator(response, as_json=True)
- @deprecated(
- since="0.3.0",
- message="搜索引擎问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
- removal="0.3.0"
- )
- def search_engine_chat(
- self,
- query: str,
- search_engine_name: str,
- top_k: int = SEARCH_ENGINE_TOP_K,
- history: List[Dict] = [],
- stream: bool = True,
- model: str = LLM_MODELS[0],
- temperature: float = TEMPERATURE,
- max_tokens: int = None,
- prompt_name: str = "default",
- split_result: bool = False,
- ):
- '''
- 对应api.py/chat/search_engine_chat接口
- '''
- data = {
- "query": query,
- "search_engine_name": search_engine_name,
- "top_k": top_k,
- "history": history,
- "stream": stream,
- "model_name": model,
- "temperature": temperature,
- "max_tokens": max_tokens,
- "prompt_name": prompt_name,
- "split_result": split_result,
- }
-
- # print(f"received input message:")
- # pprint(data)
-
- response = self.post(
- "/chat/search_engine_chat",
- json=data,
- stream=True,
- )
- return self._httpx_stream2generator(response, as_json=True)
-
# 知识库相关操作
def list_knowledge_bases(
@@ -769,7 +641,7 @@ class ApiRequest:
def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]:
'''
从服务器上获取当前运行的LLM模型。
- 当 local_first=True 时,优先返回运行中的本地模型,否则优先按LLM_MODELS配置顺序返回。
+ 当 local_first=True 时,优先返回运行中的本地模型,否则优先按 LLM_MODEL_CONFIG['llm_model']配置顺序返回。
返回类型为(model_name, is_local_model)
'''
@@ -779,7 +651,7 @@ class ApiRequest:
return "", False
model = ""
- for m in LLM_MODELS:
+ for m in LLM_MODEL_CONFIG['llm_model']:
if m not in running_models:
continue
is_local = not running_models[m].get("online_api")
@@ -789,7 +661,7 @@ class ApiRequest:
model = m
break
- if not model: # LLM_MODELS中配置的模型都不在running_models里
+ if not model: # LLM_MODEL_CONFIG['llm_model']中配置的模型都不在running_models里
model = list(running_models)[0]
is_local = not running_models[model].get("online_api")
return model, is_local
@@ -800,7 +672,7 @@ class ApiRequest:
return "", False
model = ""
- for m in LLM_MODELS:
+ for m in LLM_MODEL_CONFIG['llm_model']:
if m not in running_models:
continue
is_local = not running_models[m].get("online_api")
@@ -810,7 +682,7 @@ class ApiRequest:
model = m
break
- if not model: # LLM_MODELS中配置的模型都不在running_models里
+ if not model: # LLM_MODEL_CONFIG['llm_model']中配置的模型都不在running_models里
model = list(running_models)[0]
is_local = not running_models[model].get("online_api")
return model, is_local
@@ -852,15 +724,6 @@ class ApiRequest:
)
return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
- def list_search_engines(self) -> List[str]:
- '''
- 获取服务器支持的搜索引擎
- '''
- response = self.post(
- "/server/list_search_engines",
- )
- return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
-
def stop_llm_model(
self,
model_name: str,