From 253168a18785717d8168a5a4b1ea417cb190e502 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Tue, 5 Dec 2023 17:17:53 +0800 Subject: [PATCH] Dev (#2280) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修复Azure 不设置Max token的bug * 重写agent 1. 修改Agent实现方式,支持多参数,仅剩 ChatGLM3-6b和 OpenAI GPT4 支持,剩余模型将在暂时缺席Agent功能 2. 删除agent_chat 集成到llm_chat中 3. 重写大部分工具,适应新Agent * 更新架构 * 删除web_chat,自动融合 * 移除所有聊天,都变成Agent控制 * 更新配置文件 * 更新配置模板和提示词 * 更改参数选择bug --- chains/llmchain_with_history.py | 22 -- configs/model_config.py.example | 302 +++++++++------ configs/prompt_config.py.example | 126 +++---- configs/server_config.py.example | 2 +- docs/ES部署指南.md | 29 -- requirements.txt | 12 +- requirements_api.txt | 12 +- server/{chat => }/__init__.py | 0 server/agent/__init__.py | 4 - server/agent/agent_factory/__init__.py | 1 + .../glm3_agent.py} | 14 +- server/agent/custom_template.py | 67 ---- server/agent/model_contain.py | 6 - server/agent/tools/__init__.py | 11 - server/agent/tools/calculate.py | 76 ---- server/agent/tools/search_internet.py | 37 -- .../tools/search_knowledgebase_complex.py | 287 -------------- .../agent/tools/search_knowledgebase_once.py | 234 ------------ .../tools/search_knowledgebase_simple.py | 32 -- server/agent/tools_factory/__init__.py | 8 + .../agent/{tools => tools_factory}/arxiv.py | 0 server/agent/tools_factory/calculate.py | 23 ++ server/agent/tools_factory/search_internet.py | 99 +++++ .../search_local_knowledgebase.py | 40 ++ .../search_youtube.py | 2 +- .../agent/{tools => tools_factory}/shell.py | 2 +- server/agent/tools_factory/tools_registry.py | 59 +++ .../{tools => tools_factory}/weather_check.py | 12 +- .../agent/{tools => tools_factory}/wolfram.py | 2 +- server/agent/tools_select.py | 55 --- server/api.py | 20 +- .../agent_callback_handler.py} | 143 +++---- server/chat/agent_chat.py | 178 --------- server/chat/chat.py | 175 ++++++--- server/chat/completion.py | 10 +- server/chat/file_chat.py | 11 +- server/chat/knowledge_base_chat.py | 147 -------- server/chat/search_engine_chat.py | 208 ---------- server/db/models/conversation_model.py | 1 - server/db/models/message_model.py | 1 - server/knowledge_base/kb_summary_api.py | 13 +- server/knowledge_base/utils.py | 9 +- server/llm_api.py | 14 +- server/model_workers/base.py | 4 +- server/model_workers/qwen.py | 5 - server/utils.py | 28 +- startup.py | 48 ++- tests/api/test_server_state_api.py | 10 +- tests/api/test_stream_chat_api.py | 26 -- tests/test_online_api.py | 13 +- webui_pages/dialogue/dialogue.py | 354 ++++++------------ webui_pages/utils.py | 161 +------- 52 files changed, 862 insertions(+), 2293 deletions(-) delete mode 100644 chains/llmchain_with_history.py delete mode 100644 docs/ES部署指南.md rename server/{chat => }/__init__.py (100%) delete mode 100644 server/agent/__init__.py create mode 100644 server/agent/agent_factory/__init__.py rename server/agent/{custom_agent/ChatGLM3Agent.py => agent_factory/glm3_agent.py} (94%) delete mode 100644 server/agent/custom_template.py delete mode 100644 server/agent/model_contain.py delete mode 100644 server/agent/tools/__init__.py delete mode 100644 server/agent/tools/calculate.py delete mode 100644 server/agent/tools/search_internet.py delete mode 100644 server/agent/tools/search_knowledgebase_complex.py delete mode 100644 server/agent/tools/search_knowledgebase_once.py delete mode 100644 server/agent/tools/search_knowledgebase_simple.py create mode 100644 server/agent/tools_factory/__init__.py rename server/agent/{tools => 
tools_factory}/arxiv.py (100%) create mode 100644 server/agent/tools_factory/calculate.py create mode 100644 server/agent/tools_factory/search_internet.py create mode 100644 server/agent/tools_factory/search_local_knowledgebase.py rename server/agent/{tools => tools_factory}/search_youtube.py (80%) rename server/agent/{tools => tools_factory}/shell.py (72%) create mode 100644 server/agent/tools_factory/tools_registry.py rename server/agent/{tools => tools_factory}/weather_check.py (76%) rename server/agent/{tools => tools_factory}/wolfram.py (84%) delete mode 100644 server/agent/tools_select.py rename server/{agent/callbacks.py => callback_handler/agent_callback_handler.py} (53%) delete mode 100644 server/chat/agent_chat.py delete mode 100644 server/chat/knowledge_base_chat.py delete mode 100644 server/chat/search_engine_chat.py diff --git a/chains/llmchain_with_history.py b/chains/llmchain_with_history.py deleted file mode 100644 index 2a845086..00000000 --- a/chains/llmchain_with_history.py +++ /dev/null @@ -1,22 +0,0 @@ -from server.utils import get_ChatOpenAI -from configs.model_config import LLM_MODELS, TEMPERATURE -from langchain.chains import LLMChain -from langchain.prompts.chat import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, -) - -model = get_ChatOpenAI(model_name=LLM_MODELS[0], temperature=TEMPERATURE) - - -human_prompt = "{input}" -human_message_template = HumanMessagePromptTemplate.from_template(human_prompt) - -chat_prompt = ChatPromptTemplate.from_messages( - [("human", "我们来玩成语接龙,我先来,生龙活虎"), - ("ai", "虎头虎脑"), - ("human", "{input}")]) - - -chain = LLMChain(prompt=chat_prompt, llm=model, verbose=True) -print(chain({"input": "恼羞成怒"})) \ No newline at end of file diff --git a/configs/model_config.py.example b/configs/model_config.py.example index c42066d3..50314e56 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -1,14 +1,6 @@ import os - -# 可以指定一个绝对路径,统一存放所有的Embedding和LLM模型。 -# 每个模型可以是一个单独的目录,也可以是某个目录下的二级子目录。 -# 如果模型目录名称和 MODEL_PATH 中的 key 或 value 相同,程序会自动检测加载,无需修改 MODEL_PATH 中的路径。 MODEL_ROOT_PATH = "" - -# 选用的 Embedding 名称 -EMBEDDING_MODEL = "bge-large-zh-v1.5" - -# Embedding 模型运行设备。设为 "auto" 会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。 +EMBEDDING_MODEL = "bge-large-zh-v1.5" # bge-large-zh EMBEDDING_DEVICE = "auto" # 选用的reranker模型 @@ -21,65 +13,170 @@ RERANKER_MAX_LENGTH = 1024 EMBEDDING_KEYWORD_FILE = "keywords.txt" EMBEDDING_MODEL_OUTPUT_PATH = "output" -# 要运行的 LLM 名称,可以包括本地模型和在线模型。列表中本地模型将在启动项目时全部加载。 -# 列表中第一个模型将作为 API 和 WEBUI 的默认模型。 -# 在这里,我们使用目前主流的两个离线模型,其中,chatglm3-6b 为默认加载模型。 -# 如果你的显存不足,可使用 Qwen-1_8B-Chat, 该模型 FP16 仅需 3.8G显存。 +LLM_MODEL_CONFIG = { + # 意图识别不需要输出,模型后台知道就行 + "preprocess_model": { + "zhipu-api": { + "temperature": 0.4, + "max_tokens": 2048, + "history_len": 100, + "prompt_name": "default", + "callbacks": False + }, + }, + "llm_model": { + "chatglm3-6b": { + "temperature": 0.9, + "max_tokens": 4096, + "history_len": 3, + "prompt_name": "default", + "callbacks": True + }, + "zhipu-api": { + "temperature": 0.9, + "max_tokens": 4000, + "history_len": 5, + "prompt_name": "default", + "callbacks": True + }, + "Qwen-1_8B-Chat": { + "temperature": 0.4, + "max_tokens": 2048, + "history_len": 100, + "prompt_name": "default", + "callbacks": False + }, + }, + "action_model": { + "chatglm3-6b": { + "temperature": 0.01, + "max_tokens": 4096, + "prompt_name": "ChatGLM3", + "callbacks": True + }, + "zhipu-api": { + "temperature": 0.01, + "max_tokens": 2096, + "history_len": 5, + "prompt_name": "ChatGLM3", + 
"callbacks": True + }, + }, + "postprocess_model": { + "chatglm3-6b": { + "temperature": 0.01, + "max_tokens": 4096, + "prompt_name": "default", + "callbacks": True + } + }, -LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"] -Agent_MODEL = None +} + +TOOL_CONFIG = { + "search_local_knowledgebase": { + "use": True, + "top_k": 10, + "score_threshold": 1, + "conclude_prompt": { + "with_result": + '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题",' + '不允许在答案中添加编造成分,答案请使用中文。 \n' + '<已知信息>{{ context }}\n' + '<问题>{{ question }}\n', + "without_result": + '请你根据我的提问回答我的问题:\n' + '{{ question }}\n' + '请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n', + } + }, + "search_internet": { + "use": True, + "search_engine_name": "bing", + "search_engine_config": + { + "bing": { + "result_len": 3, + "bing_search_url": "https://api.bing.microsoft.com/v7.0/search", + "bing_key": "", + }, + "metaphor": { + "result_len": 3, + "metaphor_api_key": "", + "split_result": False, + "chunk_size": 500, + "chunk_overlap": 0, + }, + "duckduckgo": { + "result_len": 3 + } + }, + "top_k": 10, + "verbose": "Origin", + "conclude_prompt": + "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 " + "\n<已知信息>{{ context }}\n" + "<问题>\n" + "{{ question }}\n" + "\n" + }, + "arxiv": { + "use": True, + }, + "shell": { + "use": True, + }, + "weather_check": { + "use": True, + "api-key": "", + }, + "search_youtube": { + "use": False, + }, + "wolfram": { + "use": False, + }, + "calculate": { + "use": False, + }, + +} # LLM 模型运行设备。设为"auto"会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。 LLM_DEVICE = "auto" -HISTORY_LEN = 3 - -MAX_TOKENS = 2048 - -TEMPERATURE = 0.7 - ONLINE_LLM_MODEL = { "openai-api": { - "model_name": "gpt-4", + "model_name": "gpt-4-1106-preview", "api_base_url": "https://api.openai.com/v1", "api_key": "", "openai_proxy": "", }, - - # 智谱AI API,具体注册及api key获取请前往 http://open.bigmodel.cn "zhipu-api": { "api_key": "", - "version": "glm-4", + "version": "chatglm_turbo", "provider": "ChatGLMWorker", }, - - # 具体注册及api key获取请前往 https://api.minimax.chat/ "minimax-api": { "group_id": "", "api_key": "", "is_pro": False, "provider": "MiniMaxWorker", }, - - # 具体注册及api key获取请前往 https://xinghuo.xfyun.cn/ "xinghuo-api": { "APPID": "", "APISecret": "", "api_key": "", - "version": "v3.5", # 你使用的讯飞星火大模型版本,可选包括 "v3.5","v3.0", "v2.0", "v1.5" + "version": "v3.0", "provider": "XingHuoWorker", }, - - # 百度千帆 API,申请方式请参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf "qianfan-api": { - "version": "ERNIE-Bot", # 注意大小写。当前支持 "ERNIE-Bot" 或 "ERNIE-Bot-turbo", 更多的见官方文档。 - "version_url": "", # 也可以不填写version,直接填写在千帆申请模型发布的API地址 + "version": "ernie-bot-4", + "version_url": "", "api_key": "", "secret_key": "", "provider": "QianFanWorker", }, - - # 火山方舟 API,文档参考 https://www.volcengine.com/docs/82379 "fangzhou-api": { "version": "chatglm-6b-model", "version_url": "", @@ -87,28 +184,22 @@ ONLINE_LLM_MODEL = { "secret_key": "", "provider": "FangZhouWorker", }, - - # 阿里云通义千问 API,文档参考 https://help.aliyun.com/zh/dashscope/developer-reference/api-details "qwen-api": { "version": "qwen-max", "api_key": "", "provider": "QwenWorker", "embed_model": "text-embedding-v1" # embedding 模型名称 }, - - # 百川 API,申请方式请参考 https://www.baichuan-ai.com/home#api-enter "baichuan-api": { "version": "Baichuan2-53B", "api_key": "", "secret_key": "", "provider": "BaiChuanWorker", }, - - # Azure API "azure-api": { - "deployment_name": "", # 部署容器的名字 - "resource_name": "", # https://{resource_name}.openai.azure.com/openai/ 填写resource_name的部分,其他部分不要填写 - 
"api_version": "", # API的版本,不是模型版本 + "deployment_name": "", + "resource_name": "", + "api_version": "2023-07-01-preview", "api_key": "", "provider": "AzureWorker", }, @@ -153,50 +244,33 @@ MODEL_PATH = { "bge-small-zh": "BAAI/bge-small-zh", "bge-base-zh": "BAAI/bge-base-zh", - "bge-large-zh": "BAAI/bge-large-zh", + "bge-large-zh": "/media/zr/Data/Models/Embedding/bge-large-zh", "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct", "bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5", - "bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5", - - "bge-m3": "BAAI/bge-m3", - + "bge-large-zh-v1.5": "/Models/bge-large-zh-v1.5", "piccolo-base-zh": "sensenova/piccolo-base-zh", "piccolo-large-zh": "sensenova/piccolo-large-zh", - "nlp_gte_sentence-embedding_chinese-large": "damo/nlp_gte_sentence-embedding_chinese-large", - "text-embedding-ada-002": "your OPENAI_API_KEY", + "nlp_gte_sentence-embedding_chinese-large": "/Models/nlp_gte_sentence-embedding_chinese-large", + "text-embedding-ada-002": "Just write your OpenAI key like "sk-o3IGBhC9g8AiFvTGWVKsT*****" ", }, "llm_model": { "chatglm2-6b": "THUDM/chatglm2-6b", "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k", - "chatglm3-6b": "THUDM/chatglm3-6b", + "chatglm3-6b": "/Models/chatglm3-6b", "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k", - "Orion-14B-Chat": "OrionStarAI/Orion-14B-Chat", - "Orion-14B-Chat-Plugin": "OrionStarAI/Orion-14B-Chat-Plugin", - "Orion-14B-LongChat": "OrionStarAI/Orion-14B-LongChat", + "Yi-34B-Chat": "/data/share/models/Yi-34B-Chat", + "BlueLM-7B-Chat": "/Models/BlueLM-7B-Chat", - "Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf", - "Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf", - "Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf", + "baichuan2-13b": "/media/zr/Data/Models/LLM/Baichuan2-13B-Chat", + "baichuan2-7b": "/media/zr/Data/Models/LLM/Baichuan2-7B-Chat", + "baichuan-7b": "baichuan-inc/Baichuan-7B", + "baichuan-13b": "baichuan-inc/Baichuan-13B", + 'baichuan-13b-chat': 'baichuan-inc/Baichuan-13B-Chat', - "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat", - "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", - "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat", - "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat", - - # Qwen1.5 模型 VLLM可能出现问题 - "Qwen1.5-0.5B-Chat": "Qwen/Qwen1.5-0.5B-Chat", - "Qwen1.5-1.8B-Chat": "Qwen/Qwen1.5-1.8B-Chat", - "Qwen1.5-4B-Chat": "Qwen/Qwen1.5-4B-Chat", - "Qwen1.5-7B-Chat": "Qwen/Qwen1.5-7B-Chat", - "Qwen1.5-14B-Chat": "Qwen/Qwen1.5-14B-Chat", - "Qwen1.5-72B-Chat": "Qwen/Qwen1.5-72B-Chat", - - "baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat", - "baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat", - "baichuan2-7b-chat": "baichuan-inc/Baichuan2-7B-Chat", - "baichuan2-13b-chat": "baichuan-inc/Baichuan2-13B-Chat", + "aquila-7b": "BAAI/Aquila-7B", + "aquilachat-7b": "BAAI/AquilaChat-7B", "internlm-7b": "internlm/internlm-7b", "internlm-chat-7b": "internlm/internlm-chat-7b", @@ -235,42 +309,43 @@ MODEL_PATH = { "oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5", "dolly-v2-12b": "databricks/dolly-v2-12b", "stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b", + + "Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf", + "Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf", + "open_llama_13b": "openlm-research/open_llama_13b", + "vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3", + "koala": "young-geng/koala", + + "mpt-7b": "mosaicml/mpt-7b", + "mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter", + "mpt-30b": "mosaicml/mpt-30b", + "opt-66b": "facebook/opt-66b", + "opt-iml-max-30b": "facebook/opt-iml-max-30b", + + 
"Qwen-1_8B-Chat":"Qwen/Qwen-1_8B-Chat" + "Qwen-7B": "Qwen/Qwen-7B", + "Qwen-14B": "Qwen/Qwen-14B", + "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", + "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat", + "Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8", # 确保已经安装了auto-gptq optimum flash-attn + "Qwen-14B-Chat-Int4": "/media/zr/Data/Models/LLM/Qwen-14B-Chat-Int4", # 确保已经安装了auto-gptq optimum flash-attn }, - "reranker": { - "bge-reranker-large": "BAAI/bge-reranker-large", - "bge-reranker-base": "BAAI/bge-reranker-base", - } -} - -# 通常情况下不需要更改以下内容 - -# nltk 模型存储路径 NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data") # 使用VLLM可能导致模型推理能力下降,无法完成Agent任务 VLLM_MODEL_DICT = { - "chatglm2-6b": "THUDM/chatglm2-6b", - "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k", - "chatglm3-6b": "THUDM/chatglm3-6b", - "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k", + "aquila-7b": "BAAI/Aquila-7B", + "aquilachat-7b": "BAAI/AquilaChat-7B", - "Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf", - "Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf", - "Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf", + "baichuan-7b": "baichuan-inc/Baichuan-7B", + "baichuan-13b": "baichuan-inc/Baichuan-13B", + 'baichuan-13b-chat': 'baichuan-inc/Baichuan-13B-Chat', - "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat", - "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", - "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat", - "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat", - - "baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat", - "baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat", - "baichuan2-7b-chat": "baichuan-inc/Baichuan-7B-Chat", - "baichuan2-13b-chat": "baichuan-inc/Baichuan-13B-Chat", - - "BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat", - "BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k", + 'chatglm2-6b': 'THUDM/chatglm2-6b', + 'chatglm2-6b-32k': 'THUDM/chatglm2-6b-32k', + 'chatglm3-6b': 'THUDM/chatglm3-6b', + 'chatglm3-6b-32k': 'THUDM/chatglm3-6b-32k', "internlm-7b": "internlm/internlm-7b", "internlm-chat-7b": "internlm/internlm-chat-7b", @@ -301,14 +376,13 @@ VLLM_MODEL_DICT = { "opt-66b": "facebook/opt-66b", "opt-iml-max-30b": "facebook/opt-iml-max-30b", -} + "Qwen-7B": "Qwen/Qwen-7B", + "Qwen-14B": "Qwen/Qwen-14B", + "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", + "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat", -SUPPORT_AGENT_MODEL = [ - "openai-api", # GPT4 模型 - "qwen-api", # Qwen Max模型 - "zhipu-api", # 智谱AI GLM4模型 - "Qwen", # 所有Qwen系列本地模型 - "chatglm3-6b", - "internlm2-chat-20b", - "Orion-14B-Chat-Plugin", -] + "agentlm-7b": "THUDM/agentlm-7b", + "agentlm-13b": "THUDM/agentlm-13b", + "agentlm-70b": "THUDM/agentlm-70b", + +} diff --git a/configs/prompt_config.py.example b/configs/prompt_config.py.example index 6fb6996c..e7def5a6 100644 --- a/configs/prompt_config.py.example +++ b/configs/prompt_config.py.example @@ -1,26 +1,19 @@ -# prompt模板使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号 -# 本配置文件支持热加载,修改prompt模板后无需重启服务。 - -# LLM对话支持的变量: -# - input: 用户输入内容 - -# 知识库和搜索引擎对话支持的变量: -# - context: 从检索结果拼接的知识文本 -# - question: 用户提出的问题 - -# Agent对话支持的变量: - -# - tools: 可用的工具列表 -# - tool_names: 可用的工具名称列表 -# - history: 用户和Agent的对话历史 -# - input: 用户输入内容 -# - agent_scratchpad: Agent的思维记录 - PROMPT_TEMPLATES = { - "llm_chat": { + "preprocess_model": { + "default": + '请你根据我的描述和我们对话的历史,来判断本次跟我交流是否需要使用工具,还是可以直接凭借你的知识或者历史记录跟我对话。你只要回答一个数字。1 或者 0,1代表需要使用工具,0代表不需要使用工具。\n' + '以下几种情况要使用工具,请返回1\n' + '1. 实时性的问题,例如天气,日期,地点等信息\n' + '2. 需要数学计算的问题\n' + '3. 需要查询数据,地点等精确数据\n' + '4. 
需要行业知识的问题\n' + '' + '{input}' + '' + }, + "llm_model": { "default": '{{ input }}', - "with_history": 'The following is a friendly conversation between a human and an AI. ' 'The AI is talkative and provides lots of specific details from its context. ' @@ -29,72 +22,42 @@ PROMPT_TEMPLATES = { '{history}\n' 'Human: {input}\n' 'AI:', - - "py": - '你是一个聪明的代码助手,请你给我写出简单的py代码。 \n' - '{{ input }}', }, - - - "knowledge_base_chat": { - "default": - '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,' - '不允许在答案中添加编造成分,答案请使用中文。 \n' - '<已知信息>{{ context }}\n' - '<问题>{{ question }}\n', - - "text": - '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 \n' - '<已知信息>{{ context }}\n' - '<问题>{{ question }}\n', - - "empty": # 搜不到知识库的时候使用 - '请你回答我的问题:\n' - '{{ question }}\n\n', - }, - - - "search_engine_chat": { - "default": - '<指令>这是我搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。' - '如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 \n' - '<已知信息>{{ context }}\n' - '<问题>{{ question }}\n', - - "search": - '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 \n' - '<已知信息>{{ context }}\n' - '<问题>{{ question }}\n', - }, - - - "agent_chat": { - "default": - 'Answer the following questions as best you can. If it is in order, you can use some tools appropriately. ' - 'You have access to the following tools:\n\n' - '{tools}\n\n' - 'Use the following format:\n' - 'Question: the input question you must answer1\n' - 'Thought: you should always think about what to do and what tools to use.\n' - 'Action: the action to take, should be one of [{tool_names}]\n' - 'Action Input: the input to the action\n' + "action_model": { + "GPT-4": + 'Answer the following questions as best you can. You have access to the following tools:\n' + 'The way you use the tools is by specifying a json blob.\n' + 'Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here).\n' + 'The only values that should be in the "action" field are: {tool_names}\n' + 'The $JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB:\n' + '```\n\n' + '{{{{\n' + ' "action": $TOOL_NAME,\n' + ' "action_input": $INPUT\n' + '}}}}\n' + '```\n\n' + 'ALWAYS use the following format:\n' + 'Question: the input question you must answer\n' + 'Thought: you should always think about what to do\n' + 'Action:\n' + '```\n\n' + '$JSON_BLOB' + '```\n\n' 'Observation: the result of the action\n' - '... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n' + '... (this Thought/Action/Observation can repeat N times)\n' 'Thought: I now know the final answer\n' 'Final Answer: the final answer to the original input question\n' - 'Begin!\n\n' - 'history: {history}\n\n' - 'Question: {input}\n\n' - 'Thought: {agent_scratchpad}\n', - + 'Begin! Reminder to always use the exact characters `Final Answer` when responding.\n' + 'history: {history}\n' + 'Question:{input}\n' + 'Thought:{agent_scratchpad}\n', "ChatGLM3": - 'You can answer using the tools, or answer directly using your knowledge without using the tools. 
' - 'Respond to the human as helpfully and accurately as possible.\n' + 'You can answer using the tools.Respond to the human as helpfully and accurately as possible.\n' 'You have access to the following tools:\n' '{tools}\n' - 'Use a json blob to specify a tool by providing an action key (tool name) ' + 'Use a json blob to specify a tool by providing an action key (tool name)\n' 'and an action_input key (tool input).\n' - 'Valid "action" values: "Final Answer" or [{tool_names}]' + 'Valid "action" values: "Final Answer" or [{tool_names}]\n' 'Provide only ONE action per $JSON_BLOB, as shown:\n\n' '```\n' '{{{{\n' @@ -118,10 +81,13 @@ PROMPT_TEMPLATES = { ' "action": "Final Answer",\n' ' "action_input": "Final response to human"\n' '}}}}\n' - 'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. ' + 'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary.\n' 'Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.\n' 'history: {history}\n\n' 'Question: {input}\n\n' - 'Thought: {agent_scratchpad}', + 'Thought: {agent_scratchpad}\n', + }, + "postprocess_model": { + "default": "{{input}}", } } diff --git a/configs/server_config.py.example b/configs/server_config.py.example index 9bbb8b49..3cc51dd5 100644 --- a/configs/server_config.py.example +++ b/configs/server_config.py.example @@ -57,7 +57,7 @@ FSCHAT_MODEL_WORKERS = { # "awq_ckpt": None, # "awq_wbits": 16, # "awq_groupsize": -1, - # "model_names": LLM_MODELS, + # "model_names": None, # "conv_template": None, # "limit_worker_concurrency": 5, # "stream_interval": 2, diff --git a/docs/ES部署指南.md b/docs/ES部署指南.md deleted file mode 100644 index f4615826..00000000 --- a/docs/ES部署指南.md +++ /dev/null @@ -1,29 +0,0 @@ - -# 实现基于ES的数据插入、检索、删除、更新 -```shell -author: 唐国梁Tommy -e-mail: flytang186@qq.com - -如果遇到任何问题,可以与我联系,我这边部署后服务是没有问题的。 -``` - -## 第1步:ES docker部署 -```shell -docker network create elastic -docker run -id --name elasticsearch --net elastic -p 9200:9200 -p 9300:9300 -e "discovery.type=single-node" -e "xpack.security.enabled=false" -e "xpack.security.http.ssl.enabled=false" -t docker.elastic.co/elasticsearch/elasticsearch:8.8.2 -``` - -### 第2步:Kibana docker部署 -**注意:Kibana版本与ES保持一致** -```shell -docker pull docker.elastic.co/kibana/kibana:{version} -docker run --name kibana --net elastic -p 5601:5601 docker.elastic.co/kibana/kibana:{version} -``` - -### 第3步:核心代码 -```shell -1. 核心代码路径 -server/knowledge_base/kb_service/es_kb_service.py - -2. 
需要在 configs/model_config.py 中 配置 ES参数(IP, PORT)等; -``` \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 30b82248..bb341dd6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,7 @@ -torch==2.1.2 -torchvision==0.16.2 -torchaudio==2.1.2 -xformers==0.0.23.post1 -transformers==4.37.2 -sentence_transformers==2.2.2 -langchain==0.0.354 -langchain-experimental==0.0.47 +# API requirements + +langchain>=0.0.346 +langchain-experimental>=0.0.42 pydantic==1.10.13 fschat==0.2.35 openai==1.9.0 diff --git a/requirements_api.txt b/requirements_api.txt index fbb00357..de38487b 100644 --- a/requirements_api.txt +++ b/requirements_api.txt @@ -1,11 +1,7 @@ -torch~=2.1.2 -torchvision~=0.16.2 -torchaudio~=2.1.2 -xformers>=0.0.23.post1 -transformers==4.37.2 -sentence_transformers==2.2.2 -langchain==0.0.354 -langchain-experimental==0.0.47 +# API requirements + +langchain>=0.0.346 +langchain-experimental>=0.0.42 pydantic==1.10.13 fschat==0.2.35 openai~=1.9.0 diff --git a/server/chat/__init__.py b/server/__init__.py similarity index 100% rename from server/chat/__init__.py rename to server/__init__.py diff --git a/server/agent/__init__.py b/server/agent/__init__.py deleted file mode 100644 index 0de21612..00000000 --- a/server/agent/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .model_contain import * -from .callbacks import * -from .custom_template import * -from .tools import * \ No newline at end of file diff --git a/server/agent/agent_factory/__init__.py b/server/agent/agent_factory/__init__.py new file mode 100644 index 00000000..4729de78 --- /dev/null +++ b/server/agent/agent_factory/__init__.py @@ -0,0 +1 @@ +from .glm3_agent import initialize_glm3_agent \ No newline at end of file diff --git a/server/agent/custom_agent/ChatGLM3Agent.py b/server/agent/agent_factory/glm3_agent.py similarity index 94% rename from server/agent/custom_agent/ChatGLM3Agent.py rename to server/agent/agent_factory/glm3_agent.py index 65f57567..00c4b5ce 100644 --- a/server/agent/custom_agent/ChatGLM3Agent.py +++ b/server/agent/agent_factory/glm3_agent.py @@ -3,6 +3,14 @@ This file is a modified version for ChatGLM3-6B the original glm3_agent.py file """ from __future__ import annotations +import yaml +from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser +from langchain.memory import ConversationBufferWindowMemory +from typing import Any, List, Sequence, Tuple, Optional, Union +import os +from langchain.agents.agent import Agent +from langchain.chains.llm import LLMChain +from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate import json import logging from typing import Any, List, Sequence, Tuple, Optional, Union @@ -22,6 +30,7 @@ from langchain.agents.agent import AgentExecutor from langchain.callbacks.base import BaseCallbackManager from langchain.schema.language_model import BaseLanguageModel from langchain.tools.base import BaseTool +from pydantic.schema import model_schema HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}" logger = logging.getLogger(__name__) @@ -42,8 +51,9 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser): if "tool_call" in text: action_end = text.find("```") action = text[:action_end].strip() + params_str_start = text.find("(") + 1 - params_str_end = text.rfind(")") + params_str_end = text.find(")") params_str = text[params_str_start:params_str_end] params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param] @@ -225,4 +235,4 @@ def initialize_glm3_agent( 
memory=memory, tags=tags_, **kwargs, - ) \ No newline at end of file + ) diff --git a/server/agent/custom_template.py b/server/agent/custom_template.py deleted file mode 100644 index 1431d2a0..00000000 --- a/server/agent/custom_template.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations -from langchain.agents import Tool, AgentOutputParser -from langchain.prompts import StringPromptTemplate -from typing import List -from langchain.schema import AgentAction, AgentFinish - -from configs import SUPPORT_AGENT_MODEL -from server.agent import model_container -class CustomPromptTemplate(StringPromptTemplate): - template: str - tools: List[Tool] - - def format(self, **kwargs) -> str: - intermediate_steps = kwargs.pop("intermediate_steps") - thoughts = "" - for action, observation in intermediate_steps: - thoughts += action.log - thoughts += f"\nObservation: {observation}\nThought: " - kwargs["agent_scratchpad"] = thoughts - kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools]) - kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools]) - return self.template.format(**kwargs) - -class CustomOutputParser(AgentOutputParser): - begin: bool = False - def __init__(self): - super().__init__() - self.begin = True - - def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction: - if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin: - self.begin = False - stop_words = ["Observation:"] - min_index = len(llm_output) - for stop_word in stop_words: - index = llm_output.find(stop_word) - if index != -1 and index < min_index: - min_index = index - llm_output = llm_output[:min_index] - - if "Final Answer:" in llm_output: - self.begin = True - return AgentFinish( - return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()}, - log=llm_output, - ) - parts = llm_output.split("Action:") - if len(parts) < 2: - return AgentFinish( - return_values={"output": f"调用agent工具失败,该回答为大模型自身能力的回答:\n\n `{llm_output}`"}, - log=llm_output, - ) - - action = parts[1].split("Action Input:")[0].strip() - action_input = parts[1].split("Action Input:")[1].strip() - try: - ans = AgentAction( - tool=action, - tool_input=action_input.strip(" ").strip('"'), - log=llm_output - ) - return ans - except: - return AgentFinish( - return_values={"output": f"调用agent失败: `{llm_output}`"}, - log=llm_output, - ) diff --git a/server/agent/model_contain.py b/server/agent/model_contain.py deleted file mode 100644 index 0141ad03..00000000 --- a/server/agent/model_contain.py +++ /dev/null @@ -1,6 +0,0 @@ -class ModelContainer: - def __init__(self): - self.MODEL = None - self.DATABASE = None - -model_container = ModelContainer() diff --git a/server/agent/tools/__init__.py b/server/agent/tools/__init__.py deleted file mode 100644 index 21dfd973..00000000 --- a/server/agent/tools/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -## 导入所有的工具类 -from .search_knowledgebase_simple import search_knowledgebase_simple -from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput -from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput -from .calculate import calculate, CalculatorInput -from .weather_check import weathercheck, WeatherInput -from .shell import shell, ShellInput -from .search_internet import search_internet, SearchInternetInput -from .wolfram import wolfram, WolframInput -from .search_youtube import search_youtube, YoutubeInput -from .arxiv import arxiv, 
ArxivInput diff --git a/server/agent/tools/calculate.py b/server/agent/tools/calculate.py deleted file mode 100644 index bb0cbcca..00000000 --- a/server/agent/tools/calculate.py +++ /dev/null @@ -1,76 +0,0 @@ -from langchain.prompts import PromptTemplate -from langchain.chains import LLMMathChain -from server.agent import model_container -from pydantic import BaseModel, Field - -_PROMPT_TEMPLATE = """ -将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。 -问题: ${{包含数学问题的问题。}} -```text -${{解决问题的单行数学表达式}} -``` -...numexpr.evaluate(query)... -```output -${{运行代码的输出}} -``` -答案: ${{答案}} - -这是两个例子: - -问题: 37593 * 67是多少? -```text -37593 * 67 -``` -...numexpr.evaluate("37593 * 67")... -```output -2518731 - -答案: 2518731 - -问题: 37593的五次方根是多少? -```text -37593**(1/5) -``` -...numexpr.evaluate("37593**(1/5)")... -```output -8.222831614237718 - -答案: 8.222831614237718 - - -问题: 2的平方是多少? -```text -2 ** 2 -``` -...numexpr.evaluate("2 ** 2")... -```output -4 - -答案: 4 - - -现在,这是我的问题: -问题: {question} -""" - -PROMPT = PromptTemplate( - input_variables=["question"], - template=_PROMPT_TEMPLATE, -) - - -class CalculatorInput(BaseModel): - query: str = Field() - -def calculate(query: str): - model = model_container.MODEL - llm_math = LLMMathChain.from_llm(model, verbose=True, prompt=PROMPT) - ans = llm_math.run(query) - return ans - -if __name__ == "__main__": - result = calculate("2的三次方") - print("答案:",result) - - - diff --git a/server/agent/tools/search_internet.py b/server/agent/tools/search_internet.py deleted file mode 100644 index 48a8a629..00000000 --- a/server/agent/tools/search_internet.py +++ /dev/null @@ -1,37 +0,0 @@ -import json -from server.chat.search_engine_chat import search_engine_chat -from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS -import asyncio -from server.agent import model_container -from pydantic import BaseModel, Field - -async def search_engine_iter(query: str): - response = await search_engine_chat(query=query, - search_engine_name="bing", # 这里切换搜索引擎 - model_name=model_container.MODEL.model_name, - temperature=0.01, # Agent 搜索互联网的时候,温度设置为0.01 - history=[], - top_k = VECTOR_SEARCH_TOP_K, - max_tokens= MAX_TOKENS, - prompt_name = "default", - stream=False) - - contents = "" - - async for data in response.body_iterator: # 这里的data是一个json字符串 - data = json.loads(data) - contents = data["answer"] - docs = data["docs"] - - return contents - -def search_internet(query: str): - return asyncio.run(search_engine_iter(query)) - -class SearchInternetInput(BaseModel): - location: str = Field(description="Query for Internet search") - - -if __name__ == "__main__": - result = search_internet("今天星期几") - print("答案:",result) diff --git a/server/agent/tools/search_knowledgebase_complex.py b/server/agent/tools/search_knowledgebase_complex.py deleted file mode 100644 index af4d9116..00000000 --- a/server/agent/tools/search_knowledgebase_complex.py +++ /dev/null @@ -1,287 +0,0 @@ -from __future__ import annotations -import json -import re -import warnings -from typing import Dict -from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun -from langchain.chains.llm import LLMChain -from langchain.pydantic_v1 import Extra, root_validator -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel -from typing import List, Any, Optional -from langchain.prompts import PromptTemplate -from server.chat.knowledge_base_chat import knowledge_base_chat -from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS 
-import asyncio -from server.agent import model_container -from pydantic import BaseModel, Field - -async def search_knowledge_base_iter(database: str, query: str) -> str: - response = await knowledge_base_chat(query=query, - knowledge_base_name=database, - model_name=model_container.MODEL.model_name, - temperature=0.01, - history=[], - top_k=VECTOR_SEARCH_TOP_K, - max_tokens=MAX_TOKENS, - prompt_name="default", - score_threshold=SCORE_THRESHOLD, - stream=False) - - contents = "" - async for data in response.body_iterator: # 这里的data是一个json字符串 - data = json.loads(data) - contents += data["answer"] - docs = data["docs"] - return contents - - -async def search_knowledge_multiple(queries) -> List[str]: - # queries 应该是一个包含多个 (database, query) 元组的列表 - tasks = [search_knowledge_base_iter(database, query) for database, query in queries] - results = await asyncio.gather(*tasks) - # 结合每个查询结果,并在每个查询结果前添加一个自定义的消息 - combined_results = [] - for (database, _), result in zip(queries, results): - message = f"\n查询到 {database} 知识库的相关信息:\n{result}" - combined_results.append(message) - - return combined_results - - -def search_knowledge(queries) -> str: - responses = asyncio.run(search_knowledge_multiple(queries)) - # 输出每个整合的查询结果 - contents = "" - for response in responses: - contents += response + "\n\n" - return contents - - -_PROMPT_TEMPLATE = """ -用户会提出一个需要你查询知识库的问题,你应该对问题进行理解和拆解,并在知识库中查询相关的内容。 - -对于每个知识库,你输出的内容应该是一个一行的字符串,这行字符串包含知识库名称和查询内容,中间用逗号隔开,不要有多余的文字和符号。你可以同时查询多个知识库,下面这个例子就是同时查询两个知识库的内容。 - -例子: - -robotic,机器人男女比例是多少 -bigdata,大数据的就业情况如何 - - -这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能,你应该参考他们的功能来帮助你思考 - - -{database_names} - -你的回答格式应该按照下面的内容,请注意```text 等标记都必须输出,这是我用来提取答案的标记。 -不要输出中文的逗号,不要输出引号。 - -Question: ${{用户的问题}} - -```text -${{知识库名称,查询问题,不要带有任何除了,之外的符号,比如不要输出中文的逗号,不要输出引号}} - -```output -数据库查询的结果 - -现在,我们开始作答 -问题: {question} -""" - -PROMPT = PromptTemplate( - input_variables=["question", "database_names"], - template=_PROMPT_TEMPLATE, -) - - -class LLMKnowledgeChain(LLMChain): - llm_chain: LLMChain - llm: Optional[BaseLanguageModel] = None - """[Deprecated] LLM wrapper to use.""" - prompt: BasePromptTemplate = PROMPT - """[Deprecated] Prompt to use to translate to python if necessary.""" - database_names: Dict[str, str] = None - input_key: str = "question" #: :meta private: - output_key: str = "answer" #: :meta private: - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: - if "llm" in values: - warnings.warn( - "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. " - "Please instantiate with llm_chain argument or using the from_llm " - "class method." - ) - if "llm_chain" not in values and values["llm"] is not None: - prompt = values.get("prompt", PROMPT) - values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt) - return values - - @property - def input_keys(self) -> List[str]: - """Expect input key. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Expect output key. 
- - :meta private: - """ - return [self.output_key] - - def _evaluate_expression(self, queries) -> str: - try: - output = search_knowledge(queries) - except Exception as e: - output = "输入的信息有误或不存在知识库,错误信息如下:\n" - return output + str(e) - return output - - def _process_llm_result( - self, - llm_output: str, - run_manager: CallbackManagerForChainRun - ) -> Dict[str, str]: - - run_manager.on_text(llm_output, color="green", verbose=self.verbose) - - llm_output = llm_output.strip() - # text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) - text_match = re.search(r"```text(.*)", llm_output, re.DOTALL) - if text_match: - expression = text_match.group(1).strip() - cleaned_input_str = (expression.replace("\"", "").replace("“", ""). - replace("”", "").replace("```", "").strip()) - lines = cleaned_input_str.split("\n") - # 使用逗号分割每一行,然后形成一个(数据库,查询)元组的列表 - - try: - queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines] - except: - queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines] - run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", verbose=self.verbose) - output = self._evaluate_expression(queries) - run_manager.on_text("\nAnswer: ", verbose=self.verbose) - run_manager.on_text(output, color="yellow", verbose=self.verbose) - answer = "Answer: " + output - elif llm_output.startswith("Answer:"): - answer = llm_output - elif "Answer:" in llm_output: - answer = llm_output.split("Answer:")[-1] - else: - return {self.output_key: f"输入的格式不对:\n {llm_output}"} - return {self.output_key: answer} - - async def _aprocess_llm_result( - self, - llm_output: str, - run_manager: AsyncCallbackManagerForChainRun, - ) -> Dict[str, str]: - await run_manager.on_text(llm_output, color="green", verbose=self.verbose) - llm_output = llm_output.strip() - text_match = re.search(r"```text(.*)", llm_output, re.DOTALL) - if text_match: - - expression = text_match.group(1).strip() - cleaned_input_str = ( - expression.replace("\"", "").replace("“", "").replace("”", "").replace("```", "").strip()) - lines = cleaned_input_str.split("\n") - try: - queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines] - except: - queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines] - await run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", - verbose=self.verbose) - - output = self._evaluate_expression(queries) - await run_manager.on_text("\nAnswer: ", verbose=self.verbose) - await run_manager.on_text(output, color="yellow", verbose=self.verbose) - answer = "Answer: " + output - elif llm_output.startswith("Answer:"): - answer = llm_output - elif "Answer:" in llm_output: - answer = "Answer: " + llm_output.split("Answer:")[-1] - else: - raise ValueError(f"unknown format from LLM: {llm_output}") - return {self.output_key: answer} - - def _call( - self, - inputs: Dict[str, str], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - _run_manager.on_text(inputs[self.input_key]) - self.database_names = model_container.DATABASE - data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()]) - llm_output = self.llm_chain.predict( - database_names=data_formatted_str, - question=inputs[self.input_key], - stop=["```output"], - callbacks=_run_manager.get_child(), - ) - return self._process_llm_result(llm_output, _run_manager) - - async def 
_acall( - self, - inputs: Dict[str, str], - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() - await _run_manager.on_text(inputs[self.input_key]) - self.database_names = model_container.DATABASE - data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()]) - llm_output = await self.llm_chain.apredict( - database_names=data_formatted_str, - question=inputs[self.input_key], - stop=["```output"], - callbacks=_run_manager.get_child(), - ) - return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager) - - @property - def _chain_type(self) -> str: - return "llm_knowledge_chain" - - @classmethod - def from_llm( - cls, - llm: BaseLanguageModel, - prompt: BasePromptTemplate = PROMPT, - **kwargs: Any, - ) -> LLMKnowledgeChain: - llm_chain = LLMChain(llm=llm, prompt=prompt) - return cls(llm_chain=llm_chain, **kwargs) - - -def search_knowledgebase_complex(query: str): - model = model_container.MODEL - llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT) - ans = llm_knowledge.run(query) - return ans - -class KnowledgeSearchInput(BaseModel): - location: str = Field(description="The query to be searched") - -if __name__ == "__main__": - result = search_knowledgebase_complex("机器人和大数据在代码教学上有什么区别") - print(result) - -# 这是一个正常的切割 -# queries = [ -# ("bigdata", "大数据专业的男女比例"), -# ("robotic", "机器人专业的优势") -# ] -# result = search_knowledge(queries) -# print(result) diff --git a/server/agent/tools/search_knowledgebase_once.py b/server/agent/tools/search_knowledgebase_once.py deleted file mode 100644 index c9a2d7b5..00000000 --- a/server/agent/tools/search_knowledgebase_once.py +++ /dev/null @@ -1,234 +0,0 @@ -from __future__ import annotations -import re -import warnings -from typing import Dict - -from langchain.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, -) -from langchain.chains.llm import LLMChain -from langchain.pydantic_v1 import Extra, root_validator -from langchain.schema import BasePromptTemplate -from langchain.schema.language_model import BaseLanguageModel -from typing import List, Any, Optional -from langchain.prompts import PromptTemplate -import sys -import os -import json - -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) -from server.chat.knowledge_base_chat import knowledge_base_chat -from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS - -import asyncio -from server.agent import model_container -from pydantic import BaseModel, Field - -async def search_knowledge_base_iter(database: str, query: str): - response = await knowledge_base_chat(query=query, - knowledge_base_name=database, - model_name=model_container.MODEL.model_name, - temperature=0.01, - history=[], - top_k=VECTOR_SEARCH_TOP_K, - max_tokens=MAX_TOKENS, - prompt_name="knowledge_base_chat", - score_threshold=SCORE_THRESHOLD, - stream=False) - - contents = "" - async for data in response.body_iterator: # 这里的data是一个json字符串 - data = json.loads(data) - contents += data["answer"] - docs = data["docs"] - return contents - - -_PROMPT_TEMPLATE = """ -用户会提出一个需要你查询知识库的问题,你应该按照我提供的思想进行思考 -Question: ${{用户的问题}} -这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能: - -{database_names} - -你的回答格式应该按照下面的内容,请注意,格式内的```text 等标记都必须输出,这是我用来提取答案的标记。 -```text -${{知识库的名称}} -``` -```output -数据库查询的结果 -``` -答案: ${{答案}} - -现在,这是我的问题: -问题: {question} - -""" -PROMPT = 
PromptTemplate( - input_variables=["question", "database_names"], - template=_PROMPT_TEMPLATE, -) - - -class LLMKnowledgeChain(LLMChain): - llm_chain: LLMChain - llm: Optional[BaseLanguageModel] = None - """[Deprecated] LLM wrapper to use.""" - prompt: BasePromptTemplate = PROMPT - """[Deprecated] Prompt to use to translate to python if necessary.""" - database_names: Dict[str, str] = model_container.DATABASE - input_key: str = "question" #: :meta private: - output_key: str = "answer" #: :meta private: - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - @root_validator(pre=True) - def raise_deprecation(cls, values: Dict) -> Dict: - if "llm" in values: - warnings.warn( - "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. " - "Please instantiate with llm_chain argument or using the from_llm " - "class method." - ) - if "llm_chain" not in values and values["llm"] is not None: - prompt = values.get("prompt", PROMPT) - values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt) - return values - - @property - def input_keys(self) -> List[str]: - """Expect input key. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Expect output key. - - :meta private: - """ - return [self.output_key] - - def _evaluate_expression(self, dataset, query) -> str: - try: - output = asyncio.run(search_knowledge_base_iter(dataset, query)) - except Exception as e: - output = "输入的信息有误或不存在知识库" - return output - return output - - def _process_llm_result( - self, - llm_output: str, - llm_input: str, - run_manager: CallbackManagerForChainRun - ) -> Dict[str, str]: - - run_manager.on_text(llm_output, color="green", verbose=self.verbose) - - llm_output = llm_output.strip() - text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) - if text_match: - database = text_match.group(1).strip() - output = self._evaluate_expression(database, llm_input) - run_manager.on_text("\nAnswer: ", verbose=self.verbose) - run_manager.on_text(output, color="yellow", verbose=self.verbose) - answer = "Answer: " + output - elif llm_output.startswith("Answer:"): - answer = llm_output - elif "Answer:" in llm_output: - answer = "Answer: " + llm_output.split("Answer:")[-1] - else: - return {self.output_key: f"输入的格式不对: {llm_output}"} - return {self.output_key: answer} - - async def _aprocess_llm_result( - self, - llm_output: str, - run_manager: AsyncCallbackManagerForChainRun, - ) -> Dict[str, str]: - await run_manager.on_text(llm_output, color="green", verbose=self.verbose) - llm_output = llm_output.strip() - text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) - if text_match: - expression = text_match.group(1) - output = self._evaluate_expression(expression) - await run_manager.on_text("\nAnswer: ", verbose=self.verbose) - await run_manager.on_text(output, color="yellow", verbose=self.verbose) - answer = "Answer: " + output - elif llm_output.startswith("Answer:"): - answer = llm_output - elif "Answer:" in llm_output: - answer = "Answer: " + llm_output.split("Answer:")[-1] - else: - raise ValueError(f"unknown format from LLM: {llm_output}") - return {self.output_key: answer} - - def _call( - self, - inputs: Dict[str, str], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - _run_manager.on_text(inputs[self.input_key]) - data_formatted_str = ',\n'.join([f' 
"{k}":"{v}"' for k, v in self.database_names.items()]) - llm_output = self.llm_chain.predict( - database_names=data_formatted_str, - question=inputs[self.input_key], - stop=["```output"], - callbacks=_run_manager.get_child(), - ) - return self._process_llm_result(llm_output, inputs[self.input_key], _run_manager) - - async def _acall( - self, - inputs: Dict[str, str], - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() - await _run_manager.on_text(inputs[self.input_key]) - data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()]) - llm_output = await self.llm_chain.apredict( - database_names=data_formatted_str, - question=inputs[self.input_key], - stop=["```output"], - callbacks=_run_manager.get_child(), - ) - return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager) - - @property - def _chain_type(self) -> str: - return "llm_knowledge_chain" - - @classmethod - def from_llm( - cls, - llm: BaseLanguageModel, - prompt: BasePromptTemplate = PROMPT, - **kwargs: Any, - ) -> LLMKnowledgeChain: - llm_chain = LLMChain(llm=llm, prompt=prompt) - return cls(llm_chain=llm_chain, **kwargs) - - -def search_knowledgebase_once(query: str): - model = model_container.MODEL - llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT) - ans = llm_knowledge.run(query) - return ans - - -class KnowledgeSearchInput(BaseModel): - location: str = Field(description="The query to be searched") - - -if __name__ == "__main__": - result = search_knowledgebase_once("大数据的男女比例") - print(result) diff --git a/server/agent/tools/search_knowledgebase_simple.py b/server/agent/tools/search_knowledgebase_simple.py deleted file mode 100644 index 65d7df3b..00000000 --- a/server/agent/tools/search_knowledgebase_simple.py +++ /dev/null @@ -1,32 +0,0 @@ -from server.chat.knowledge_base_chat import knowledge_base_chat -from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS -import json -import asyncio -from server.agent import model_container - -async def search_knowledge_base_iter(database: str, query: str) -> str: - response = await knowledge_base_chat(query=query, - knowledge_base_name=database, - model_name=model_container.MODEL.model_name, - temperature=0.01, - history=[], - top_k=VECTOR_SEARCH_TOP_K, - max_tokens=MAX_TOKENS, - prompt_name="knowledge_base_chat", - score_threshold=SCORE_THRESHOLD, - stream=False) - - contents = "" - async for data in response.body_iterator: # 这里的data是一个json字符串 - data = json.loads(data) - contents = data["answer"] - docs = data["docs"] - return contents - -def search_knowledgebase_simple(query: str): - return asyncio.run(search_knowledge_base_iter(query)) - - -if __name__ == "__main__": - result = search_knowledgebase_simple("大数据男女比例") - print("答案:",result) \ No newline at end of file diff --git a/server/agent/tools_factory/__init__.py b/server/agent/tools_factory/__init__.py new file mode 100644 index 00000000..461ce311 --- /dev/null +++ b/server/agent/tools_factory/__init__.py @@ -0,0 +1,8 @@ +from .search_local_knowledgebase import search_local_knowledgebase, SearchKnowledgeInput +from .calculate import calculate, CalculatorInput +from .weather_check import weather_check, WeatherInput +from .shell import shell, ShellInput +from .search_internet import search_internet, SearchInternetInput +from .wolfram import wolfram, WolframInput +from .search_youtube import search_youtube, 
YoutubeInput +from .arxiv import arxiv, ArxivInput diff --git a/server/agent/tools/arxiv.py b/server/agent/tools_factory/arxiv.py similarity index 100% rename from server/agent/tools/arxiv.py rename to server/agent/tools_factory/arxiv.py diff --git a/server/agent/tools_factory/calculate.py b/server/agent/tools_factory/calculate.py new file mode 100644 index 00000000..a47d65ca --- /dev/null +++ b/server/agent/tools_factory/calculate.py @@ -0,0 +1,23 @@ +from pydantic import BaseModel, Field + +def calculate(a: float, b: float, operator: str) -> float: + if operator == "+": + return a + b + elif operator == "-": + return a - b + elif operator == "*": + return a * b + elif operator == "/": + if b != 0: + return a / b + else: + return float('inf') # 防止除以零 + elif operator == "^": + return a ** b + else: + raise ValueError("Unsupported operator") + +class CalculatorInput(BaseModel): + a: float = Field(description="first number") + b: float = Field(description="second number") + operator: str = Field(description="operator to use (e.g., +, -, *, /, ^)") diff --git a/server/agent/tools_factory/search_internet.py b/server/agent/tools_factory/search_internet.py new file mode 100644 index 00000000..68a5c6b1 --- /dev/null +++ b/server/agent/tools_factory/search_internet.py @@ -0,0 +1,99 @@ +from pydantic import BaseModel, Field +from langchain.utilities.bing_search import BingSearchAPIWrapper +from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper +from configs import TOOL_CONFIG +from langchain.text_splitter import RecursiveCharacterTextSplitter +from typing import List, Dict +from langchain.docstore.document import Document +from strsimpy.normalized_levenshtein import NormalizedLevenshtein +from markdownify import markdownify + + +def bing_search(text, config): + search = BingSearchAPIWrapper(bing_subscription_key=config["bing_key"], + bing_search_url=config["bing_search_url"]) + return search.results(text, config["result_len"]) + + +def duckduckgo_search(text, config): + search = DuckDuckGoSearchAPIWrapper() + return search.results(text, config["result_len"]) + + +def metaphor_search( + text: str, + config: dict, +) -> List[Dict]: + from metaphor_python import Metaphor + client = Metaphor(config["metaphor_api_key"]) + search = client.search(text, num_results=config["result_len"], use_autoprompt=True) + contents = search.get_contents().contents + for x in contents: + x.extract = markdownify(x.extract) + if config["split_result"]: + docs = [Document(page_content=x.extract, + metadata={"link": x.url, "title": x.title}) + for x in contents] + text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "], + chunk_size=config["chunk_size"], + chunk_overlap=config["chunk_overlap"]) + splitted_docs = text_splitter.split_documents(docs) + if len(splitted_docs) > config["result_len"]: + normal = NormalizedLevenshtein() + for x in splitted_docs: + x.metadata["score"] = normal.similarity(text, x.page_content) + splitted_docs.sort(key=lambda x: x.metadata["score"], reverse=True) + splitted_docs = splitted_docs[:config["result_len"]] + + docs = [{"snippet": x.page_content, + "link": x.metadata["link"], + "title": x.metadata["title"]} + for x in splitted_docs] + else: + docs = [{"snippet": x.extract, + "link": x.url, + "title": x.title} + for x in contents] + + return docs + + +SEARCH_ENGINES = {"bing": bing_search, + "duckduckgo": duckduckgo_search, + "metaphor": metaphor_search, + } + + +def search_result2docs(search_results): + docs = [] + for result in search_results: + doc = 
Document(page_content=result["snippet"] if "snippet" in result.keys() else "", + metadata={"source": result["link"] if "link" in result.keys() else "", + "filename": result["title"] if "title" in result.keys() else ""}) + docs.append(doc) + return docs + + +def search_engine(query: str, + config: dict): + search_engine_use = SEARCH_ENGINES[config["search_engine_name"]] + results = search_engine_use(text=query, + config=config["search_engine_config"][ + config["search_engine_name"]]) + docs = search_result2docs(results) + context = "" + docs = [ + f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n""" + for inum, doc in enumerate(docs) + ] + + for doc in docs: + context += doc + "\n" + return context +def search_internet(query: str): + tool_config = TOOL_CONFIG["search_internet"] + return search_engine(query=query, config=tool_config) + +class SearchInternetInput(BaseModel): + query: str = Field(description="Query for Internet search") + diff --git a/server/agent/tools_factory/search_local_knowledgebase.py b/server/agent/tools_factory/search_local_knowledgebase.py new file mode 100644 index 00000000..0bc7803c --- /dev/null +++ b/server/agent/tools_factory/search_local_knowledgebase.py @@ -0,0 +1,40 @@ +from urllib.parse import urlencode + +from pydantic import BaseModel, Field + +from server.knowledge_base.kb_doc_api import search_docs +from configs import TOOL_CONFIG + + +def search_knowledgebase(query: str, database: str, config: dict): + docs = search_docs( + query=query, + knowledge_base_name=database, + top_k=config["top_k"], + score_threshold=config["score_threshold"]) + context = "" + source_documents = [] + for inum, doc in enumerate(docs): + filename = doc.metadata.get("source") + parameters = urlencode({"knowledge_base_name": database, "file_name": filename}) + url = f"download_doc?" 
+ parameters + text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n""" + source_documents.append(text) + + if len(source_documents) == 0: + context= "没有找到相关文档,请更换关键词重试" + else: + for doc in source_documents: + context += doc + "\n" + + return context + + +class SearchKnowledgeInput(BaseModel): + database: str = Field(description="Database for Knowledge Search") + query: str = Field(description="Query for Knowledge Search") + + +def search_local_knowledgebase(database: str, query: str): + tool_config = TOOL_CONFIG["search_local_knowledgebase"] + return search_knowledgebase(query=query, database=database, config=tool_config) diff --git a/server/agent/tools/search_youtube.py b/server/agent/tools_factory/search_youtube.py similarity index 80% rename from server/agent/tools/search_youtube.py rename to server/agent/tools_factory/search_youtube.py index f02b3625..57049897 100644 --- a/server/agent/tools/search_youtube.py +++ b/server/agent/tools_factory/search_youtube.py @@ -6,4 +6,4 @@ def search_youtube(query: str): return tool.run(tool_input=query) class YoutubeInput(BaseModel): - location: str = Field(description="Query for Videos search") \ No newline at end of file + query: str = Field(description="Query for Videos search") \ No newline at end of file diff --git a/server/agent/tools/shell.py b/server/agent/tools_factory/shell.py similarity index 72% rename from server/agent/tools/shell.py rename to server/agent/tools_factory/shell.py index db074154..01046559 100644 --- a/server/agent/tools/shell.py +++ b/server/agent/tools_factory/shell.py @@ -6,4 +6,4 @@ def shell(query: str): return tool.run(tool_input=query) class ShellInput(BaseModel): - query: str = Field(description="一个能在Linux命令行运行的Shell命令") \ No newline at end of file + query: str = Field(description="The command to execute") \ No newline at end of file diff --git a/server/agent/tools_factory/tools_registry.py b/server/agent/tools_factory/tools_registry.py new file mode 100644 index 00000000..61c8451f --- /dev/null +++ b/server/agent/tools_factory/tools_registry.py @@ -0,0 +1,59 @@ +from langchain_core.tools import StructuredTool +from server.agent.tools_factory import * +from configs import KB_INFO + +template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool." 
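[Editor's aside, not part of the patch] The rest of this new tools_registry.py (continued below) wraps each rewritten tool into an `all_tools` list of multi-argument `StructuredTool`s, while the `use` flags in the `TOOL_CONFIG` dict added to configs/model_config.py.example decide which of them are actually offered to the agent. A minimal sketch of how the two might be combined is shown here; the real selection logic lives in server/chat/chat.py, which is not included in this excerpt, so the filtering below is an assumption for illustration only.

```python
# Illustrative sketch, not taken from the patch: pick the tools whose "use"
# flag is enabled in TOOL_CONFIG and list their names.
from configs import TOOL_CONFIG
from server.agent.tools_factory.tools_registry import all_tools

# Tools with no TOOL_CONFIG entry, or with "use": False, are treated as disabled.
enabled_tools = [
    tool for tool in all_tools
    if TOOL_CONFIG.get(tool.name, {}).get("use", False)
]
print([tool.name for tool in enabled_tools])
```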
+KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
+template_knowledge = template.format(KB_info=KB_info_str)
+
+all_tools = [
+    StructuredTool.from_function(
+        func=calculate,
+        name="calculate",
+        description="Useful for when you need to answer questions about simple calculations",
+        args_schema=CalculatorInput,
+    ),
+    StructuredTool.from_function(
+        func=arxiv,
+        name="arxiv",
+        description="A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.",
+        args_schema=ArxivInput,
+    ),
+    StructuredTool.from_function(
+        func=shell,
+        name="shell",
+        description="Use Shell to execute Linux commands",
+        args_schema=ShellInput,
+    ),
+    StructuredTool.from_function(
+        func=wolfram,
+        name="wolfram",
+        description="Useful for when you need to calculate difficult formulas",
+        args_schema=WolframInput,
+    ),
+    StructuredTool.from_function(
+        func=search_youtube,
+        name="search_youtube",
+        description="Use this tool to search YouTube videos",
+        args_schema=YoutubeInput,
+    ),
+    StructuredTool.from_function(
+        func=weather_check,
+        name="weather_check",
+        description="Use this tool to check the weather at a specific location",
+        args_schema=WeatherInput,
+    ),
+    StructuredTool.from_function(
+        func=search_internet,
+        name="search_internet",
+        description="Use this tool to search the internet with the Bing search engine and get information",
+        args_schema=SearchInternetInput,
+    ),
+    StructuredTool.from_function(
+        func=search_local_knowledgebase,
+        name="search_local_knowledgebase",
+        description=template_knowledge,
+        args_schema=SearchKnowledgeInput,
+    ),
+]
+
diff --git a/server/agent/tools/weather_check.py b/server/agent/tools_factory/weather_check.py
similarity index 76%
rename from server/agent/tools/weather_check.py
rename to server/agent/tools_factory/weather_check.py
index 20f009a5..954d74ee 100644
--- a/server/agent/tools/weather_check.py
+++ b/server/agent/tools_factory/weather_check.py
@@ -1,10 +1,8 @@
 """
-更简单的单参数输入工具实现,用于查询现在天气的情况
+简单的单参数输入工具实现,用于查询现在天气的情况
 """
 from pydantic import BaseModel, Field
 import requests
-from configs.kb_config import SENIVERSE_API_KEY
-

 def weather(location: str, api_key: str):
     url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
@@ -21,9 +19,7 @@ def weather(location: str, api_key: str):
         f"Failed to retrieve weather: {response.status_code}")


-def weathercheck(location: str):
-    return weather(location, SENIVERSE_API_KEY)
-
-
+def weather_check(location: str):
+    return weather(location, "S8vrB4U_-c5mvAMiK")
 class WeatherInput(BaseModel):
-    location: str = Field(description="City name,include city and county")
+    location: str = Field(description="City name, include city and county, like '厦门'")
diff --git a/server/agent/tools/wolfram.py b/server/agent/tools_factory/wolfram.py
similarity index 84%
rename from server/agent/tools/wolfram.py
rename to server/agent/tools_factory/wolfram.py
index c322da18..6bdbe10e 100644
--- a/server/agent/tools/wolfram.py
+++ b/server/agent/tools_factory/wolfram.py
@@ -8,4 +8,4 @@ def wolfram(query: str):
     return ans

 class WolframInput(BaseModel):
-    location: str = Field(description="需要运算的具体问题")
\ No newline at end of file
+    formula: str = Field(description="The formula to be calculated")
\ No newline at end of file
diff --git a/server/agent/tools_select.py b/server/agent/tools_select.py
deleted file mode 100644
index 237c20b6..00000000
--- a/server/agent/tools_select.py
+++ /dev/null
@@ -1,55 +0,0 @@
-from
langchain.tools import Tool -from server.agent.tools import * - -tools = [ - Tool.from_function( - func=calculate, - name="calculate", - description="Useful for when you need to answer questions about simple calculations", - args_schema=CalculatorInput, - ), - Tool.from_function( - func=arxiv, - name="arxiv", - description="A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.", - args_schema=ArxivInput, - ), - Tool.from_function( - func=weathercheck, - name="weather_check", - description="", - args_schema=WeatherInput, - ), - Tool.from_function( - func=shell, - name="shell", - description="Use Shell to execute Linux commands", - args_schema=ShellInput, - ), - Tool.from_function( - func=search_knowledgebase_complex, - name="search_knowledgebase_complex", - description="Use Use this tool to search local knowledgebase and get information", - args_schema=KnowledgeSearchInput, - ), - Tool.from_function( - func=search_internet, - name="search_internet", - description="Use this tool to use bing search engine to search the internet", - args_schema=SearchInternetInput, - ), - Tool.from_function( - func=wolfram, - name="Wolfram", - description="Useful for when you need to calculate difficult formulas", - args_schema=WolframInput, - ), - Tool.from_function( - func=search_youtube, - name="search_youtube", - description="use this tools to search youtube videos", - args_schema=YoutubeInput, - ), -] - -tool_names = [tool.name for tool in tools] diff --git a/server/api.py b/server/api.py index 6e4aa437..3c4d04df 100644 --- a/server/api.py +++ b/server/api.py @@ -13,13 +13,12 @@ from fastapi import Body from fastapi.middleware.cors import CORSMiddleware from starlette.responses import RedirectResponse from server.chat.chat import chat -from server.chat.search_engine_chat import search_engine_chat from server.chat.completion import completion from server.chat.feedback import chat_feedback from server.embeddings_api import embed_texts_endpoint from server.llm_api import (list_running_models, list_config_models, change_llm_model, stop_llm_model, - get_model_config, list_search_engines) + get_model_config) from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, get_server_configs, get_prompt_template) from typing import List, Literal @@ -63,11 +62,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None): summary="与llm模型对话(通过LLMChain)", )(chat) - app.post("/chat/search_engine_chat", - tags=["Chat"], - summary="与搜索引擎对话", - )(search_engine_chat) - app.post("/chat/feedback", tags=["Chat"], summary="返回llm模型对话评分", @@ -110,16 +104,12 @@ def mount_app_routes(app: FastAPI, run_mode: str = None): summary="获取服务器原始配置信息", )(get_server_configs) - app.post("/server/list_search_engines", - tags=["Server State"], - summary="获取服务器支持的搜索引擎", - )(list_search_engines) @app.post("/server/get_prompt_template", tags=["Server State"], summary="获取服务区配置的 prompt 模板") def get_server_prompt_template( - type: Literal["llm_chat", "knowledge_base_chat", "search_engine_chat", "agent_chat"]=Body("llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat,search_engine_chat,agent_chat"), + type: Literal["llm_chat", "knowledge_base_chat"]=Body("llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat"), name: str = Body("default", description="模板名称"), ) -> str: return get_prompt_template(type=type, name=name) @@ -139,7 +129,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None): def mount_knowledge_routes(app: FastAPI): from server.chat.knowledge_base_chat 
import knowledge_base_chat from server.chat.file_chat import upload_temp_docs, file_chat - from server.chat.agent_chat import agent_chat from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs, update_docs, download_doc, recreate_vector_store, @@ -155,11 +144,6 @@ def mount_knowledge_routes(app: FastAPI): summary="文件对话" )(file_chat) - app.post("/chat/agent_chat", - tags=["Chat"], - summary="与agent对话")(agent_chat) - - # Tag: Knowledge Base Management app.get("/knowledge_base/list_knowledge_bases", tags=["Knowledge Base Management"], response_model=ListResponse, diff --git a/server/agent/callbacks.py b/server/callback_handler/agent_callback_handler.py similarity index 53% rename from server/agent/callbacks.py rename to server/callback_handler/agent_callback_handler.py index 0935f9dc..dad07a79 100644 --- a/server/agent/callbacks.py +++ b/server/callback_handler/agent_callback_handler.py @@ -1,13 +1,11 @@ from __future__ import annotations from uuid import UUID -from langchain.callbacks import AsyncIteratorCallbackHandler import json -import asyncio -from typing import Any, Dict, List, Optional - from langchain.schema import AgentFinish, AgentAction -from langchain.schema.output import LLMResult - +import asyncio +from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast, Optional +from langchain_core.outputs import LLMResult +from langchain.callbacks.base import AsyncCallbackHandler def dumps(obj: Dict) -> str: return json.dumps(obj, ensure_ascii=False) @@ -23,7 +21,7 @@ class Status: tool_finish: int = 7 -class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): +class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler): def __init__(self): super().__init__() self.queue = asyncio.Queue() @@ -31,40 +29,29 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.cur_tool = {} self.out = True - async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID, - parent_run_id: UUID | None = None, tags: List[str] | None = None, - metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None: - - # 对于截断不能自理的大模型,我来帮他截断 - stop_words = ["Observation:", "Thought","\"","(", "\n","\t"] - for stop_word in stop_words: - index = input_str.find(stop_word) - if index != -1: - input_str = input_str[:index] - break - - self.cur_tool = { - "tool_name": serialized["name"], - "input_str": input_str, - "output_str": "", - "status": Status.agent_action, - "run_id": run_id.hex, - "llm_token": "", - "final_answer": "", - "error": "", - } - # print("\nInput Str:",self.cur_tool["input_str"]) - self.queue.put_nowait(dumps(self.cur_tool)) - - async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None, - tags: List[str] | None = None, **kwargs: Any) -> None: - self.out = True ## 重置输出 - self.cur_tool.update( - status=Status.tool_finish, - output_str=output.replace("Answer:", ""), - ) - self.queue.put_nowait(dumps(self.cur_tool)) + async def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + print("on_tool_start") + async def on_tool_end( + self, + output: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + print("on_tool_end") async def 
on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID, parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None: self.cur_tool.update( @@ -73,23 +60,6 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): ) self.queue.put_nowait(dumps(self.cur_tool)) - # async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - # if "Action" in token: ## 减少重复输出 - # before_action = token.split("Action")[0] - # self.cur_tool.update( - # status=Status.running, - # llm_token=before_action + "\n", - # ) - # self.queue.put_nowait(dumps(self.cur_tool)) - # - # self.out = False - # - # if token and self.out: - # self.cur_tool.update( - # status=Status.running, - # llm_token=token, - # ) - # self.queue.put_nowait(dumps(self.cur_tool)) async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: special_tokens = ["Action", "<|observation|>"] for stoken in special_tokens: @@ -103,7 +73,7 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): self.out = False break - if token and self.out: + if token is not None and token != "" and self.out: self.cur_tool.update( status=Status.running, llm_token=token, @@ -116,16 +86,17 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): llm_token="", ) self.queue.put_nowait(dumps(self.cur_tool)) + async def on_chat_model_start( - self, - serialized: Dict[str, Any], - messages: List[List], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, + self, + serialized: Dict[str, Any], + messages: List[List], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, ) -> None: self.cur_tool.update( status=Status.start, @@ -136,8 +107,9 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: self.cur_tool.update( status=Status.complete, - llm_token="\n", + llm_token="", ) + self.out = True self.queue.put_nowait(dumps(self.cur_tool)) async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None: @@ -147,15 +119,44 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): ) self.queue.put_nowait(dumps(self.cur_tool)) + async def on_agent_action( + self, + action: AgentAction, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + self.cur_tool.update( + status=Status.agent_action, + tool_name=action.tool, + tool_input=action.tool_input, + ) + self.queue.put_nowait(dumps(self.cur_tool)) async def on_agent_finish( self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: - # 返回最终答案 self.cur_tool.update( status=Status.agent_finish, - final_answer=finish.return_values["output"], + agent_finish=finish.return_values["output"], ) self.queue.put_nowait(dumps(self.cur_tool)) - self.cur_tool = {} + + async def aiter(self) -> AsyncIterator[str]: + while not self.queue.empty() or not self.done.is_set(): + done, other = await asyncio.wait( + [ + asyncio.ensure_future(self.queue.get()), + asyncio.ensure_future(self.done.wait()), + ], + return_when=asyncio.FIRST_COMPLETED, + ) + if other: + other.pop().cancel() + token_or_done = cast(Union[str, Literal[True]], done.pop().result()) + 
if token_or_done is True: + break + yield token_or_done diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py deleted file mode 100644 index 41bf5bab..00000000 --- a/server/chat/agent_chat.py +++ /dev/null @@ -1,178 +0,0 @@ -import json -import asyncio - -from fastapi import Body -from sse_starlette.sse import EventSourceResponse -from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL - -from langchain.chains import LLMChain -from langchain.memory import ConversationBufferWindowMemory -from langchain.agents import LLMSingleActionAgent, AgentExecutor -from typing import AsyncIterable, Optional, List - -from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template -from server.knowledge_base.kb_service.base import get_kb_details -from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent -from server.agent.tools_select import tools, tool_names -from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status -from server.chat.utils import History -from server.agent import model_container -from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate - - -async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), - history: List[History] = Body([], - description="历史对话", - examples=[[ - {"role": "user", "content": "请使用知识库工具查询今天北京天气"}, - {"role": "assistant", - "content": "使用天气查询工具查询到今天北京多云,10-14摄氏度,东北风2级,易感冒"}]] - ), - stream: bool = Body(False, description="流式输出"), - model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), - temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), - prompt_name: str = Body("default", - description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), - ): - history = [History.from_data(h) for h in history] - - async def agent_chat_iterator( - query: str, - history: Optional[List[History]], - model_name: str = LLM_MODELS[0], - prompt_name: str = prompt_name, - ) -> AsyncIterable[str]: - nonlocal max_tokens - callback = CustomAsyncIteratorCallbackHandler() - if isinstance(max_tokens, int) and max_tokens <= 0: - max_tokens = None - - model = get_ChatOpenAI( - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - callbacks=[callback], - ) - - kb_list = {x["kb_name"]: x for x in get_kb_details()} - model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()} - - if Agent_MODEL: - model_agent = get_ChatOpenAI( - model_name=Agent_MODEL, - temperature=temperature, - max_tokens=max_tokens, - callbacks=[callback], - ) - model_container.MODEL = model_agent - else: - model_container.MODEL = model - - prompt_template = get_prompt_template("agent_chat", prompt_name) - prompt_template_agent = CustomPromptTemplate( - template=prompt_template, - tools=tools, - input_variables=["input", "intermediate_steps", "history"] - ) - output_parser = CustomOutputParser() - llm_chain = LLMChain(llm=model, prompt=prompt_template_agent) - memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2) - for message in history: - if message.role == 'user': - memory.chat_memory.add_user_message(message.content) - else: - memory.chat_memory.add_ai_message(message.content) - if "chatglm3" in model_container.MODEL.model_name or "zhipu-api" in model_container.MODEL.model_name: - agent_executor = initialize_glm3_agent( - llm=model, - tools=tools, - callback_manager=None, - prompt=prompt_template, - input_variables=["input", 
"intermediate_steps", "history"], - memory=memory, - verbose=True, - ) - else: - agent = LLMSingleActionAgent( - llm_chain=llm_chain, - output_parser=output_parser, - stop=["\nObservation:", "Observation"], - allowed_tools=tool_names, - ) - agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, - tools=tools, - verbose=True, - memory=memory, - ) - while True: - try: - task = asyncio.create_task(wrap_done( - agent_executor.acall(query, callbacks=[callback], include_run_info=True), - callback.done)) - break - except: - pass - - if stream: - async for chunk in callback.aiter(): - tools_use = [] - # Use server-sent-events to stream the response - data = json.loads(chunk) - if data["status"] == Status.start or data["status"] == Status.complete: - continue - elif data["status"] == Status.error: - tools_use.append("\n```\n") - tools_use.append("工具名称: " + data["tool_name"]) - tools_use.append("工具状态: " + "调用失败") - tools_use.append("错误信息: " + data["error"]) - tools_use.append("重新开始尝试") - tools_use.append("\n```\n") - yield json.dumps({"tools": tools_use}, ensure_ascii=False) - elif data["status"] == Status.tool_finish: - tools_use.append("\n```\n") - tools_use.append("工具名称: " + data["tool_name"]) - tools_use.append("工具状态: " + "调用成功") - tools_use.append("工具输入: " + data["input_str"]) - tools_use.append("工具输出: " + data["output_str"]) - tools_use.append("\n```\n") - yield json.dumps({"tools": tools_use}, ensure_ascii=False) - elif data["status"] == Status.agent_finish: - yield json.dumps({"final_answer": data["final_answer"]}, ensure_ascii=False) - else: - yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False) - - - else: - answer = "" - final_answer = "" - async for chunk in callback.aiter(): - data = json.loads(chunk) - if data["status"] == Status.start or data["status"] == Status.complete: - continue - if data["status"] == Status.error: - answer += "\n```\n" - answer += "工具名称: " + data["tool_name"] + "\n" - answer += "工具状态: " + "调用失败" + "\n" - answer += "错误信息: " + data["error"] + "\n" - answer += "\n```\n" - if data["status"] == Status.tool_finish: - answer += "\n```\n" - answer += "工具名称: " + data["tool_name"] + "\n" - answer += "工具状态: " + "调用成功" + "\n" - answer += "工具输入: " + data["input_str"] + "\n" - answer += "工具输出: " + data["output_str"] + "\n" - answer += "\n```\n" - if data["status"] == Status.agent_finish: - final_answer = data["final_answer"] - else: - answer += data["llm_token"] - - yield json.dumps({"answer": answer, "final_answer": final_answer}, ensure_ascii=False) - await task - - return EventSourceResponse(agent_chat_iterator(query=query, - history=history, - model_name=model_name, - prompt_name=prompt_name), - ) diff --git a/server/chat/chat.py b/server/chat/chat.py index 5783829c..ccf0982b 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -1,20 +1,43 @@ from fastapi import Body -from sse_starlette.sse import EventSourceResponse -from configs import LLM_MODELS, TEMPERATURE +from fastapi.responses import StreamingResponse +from langchain.agents import initialize_agent, AgentType +from langchain_core.callbacks import AsyncCallbackManager, BaseCallbackManager +from langchain_core.output_parsers import StrOutputParser +from langchain_core.runnables import RunnableBranch +from server.agent.agent_factory import initialize_glm3_agent +from server.agent.tools_factory.tools_registry import all_tools from server.utils import wrap_done, get_ChatOpenAI from langchain.chains import LLMChain -from langchain.callbacks import AsyncIteratorCallbackHandler -from typing 
import AsyncIterable +from typing import AsyncIterable, Dict import asyncio import json from langchain.prompts.chat import ChatPromptTemplate -from typing import List, Optional, Union +from typing import List, Union from server.chat.utils import History from langchain.prompts import PromptTemplate from server.utils import get_prompt_template from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory from server.db.repository import add_message_to_db -from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler +from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler + + +def create_models_from_config(configs: dict = {}, callbacks: list = []): + models = {} + prompts = {} + for model_type, model_configs in configs.items(): + for model_name, params in model_configs.items(): + callback = callbacks if params.get('callbacks', False) else None + model_instance = get_ChatOpenAI( + model_name=model_name, + temperature=params.get('temperature', 0.5), + max_tokens=params.get('max_tokens', 1000), + callbacks=callback + ) + models[model_type] = model_instance + prompt_name = params.get('prompt_name', 'default') + prompt_template = get_prompt_template(type=model_type, name=prompt_name) + prompts[model_type] = prompt_template + return models, prompts async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), @@ -28,76 +51,106 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 {"role": "assistant", "content": "虎头虎脑"}]] ), stream: bool = Body(False, description="流式输出"), - model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), - temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0), - max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), - # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), - prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), + model_config: Dict = Body({}, description="LLM 模型配置。"), + tool_config: Dict = Body({}, description="工具配置"), ): async def chat_iterator() -> AsyncIterable[str]: - nonlocal history, max_tokens - callback = AsyncIteratorCallbackHandler() - callbacks = [callback] + nonlocal history memory = None + message_id = None + chat_prompt = None - # 负责保存llm response到message db - message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id) - conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id, - chat_type="llm_chat", - query=query) - callbacks.append(conversation_callback) + callback = CustomAsyncIteratorCallbackHandler() + callbacks = [callback] + models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config) - if isinstance(max_tokens, int) and max_tokens <= 0: - max_tokens = None + if conversation_id: + message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id) + tools = [tool for tool in all_tools if tool.name in tool_config] - model = get_ChatOpenAI( - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - callbacks=callbacks, - ) - - if history: # 优先使用前端传入的历史消息 + if history: history = [History.from_data(h) for h in history] - prompt_template = get_prompt_template("llm_chat", prompt_name) - input_msg = History(role="user", content=prompt_template).to_msg_template(False) + input_msg = History(role="user", 
content=prompts["llm_model"]).to_msg_template(False) chat_prompt = ChatPromptTemplate.from_messages( [i.to_msg_template() for i in history] + [input_msg]) - elif conversation_id and history_len > 0: # 前端要求从数据库取历史消息 - # 使用memory 时必须 prompt 必须含有memory.memory_key 对应的变量 - prompt = get_prompt_template("llm_chat", "with_history") - chat_prompt = PromptTemplate.from_template(prompt) - # 根据conversation_id 获取message 列表进而拼凑 memory - memory = ConversationBufferDBMemory(conversation_id=conversation_id, - llm=model, + elif conversation_id and history_len > 0: + memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=models["llm_model"], message_limit=history_len) else: - prompt_template = get_prompt_template("llm_chat", prompt_name) - input_msg = History(role="user", content=prompt_template).to_msg_template(False) + input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False) chat_prompt = ChatPromptTemplate.from_messages([input_msg]) - chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory) - - # Begin a task that runs in the background. - task = asyncio.create_task(wrap_done( - chain.acall({"input": query}), - callback.done), + chain = LLMChain(prompt=chat_prompt, llm=models["llm_model"], memory=memory) + classifier_chain = ( + PromptTemplate.from_template(prompts["preprocess_model"]) + | models["preprocess_model"] + | StrOutputParser() ) - - if stream: - async for token in callback.aiter(): - # Use server-sent-events to stream the response - yield json.dumps( - {"text": token, "message_id": message_id}, - ensure_ascii=False) + if "chatglm3" in models["action_model"].model_name.lower(): + agent_executor = initialize_glm3_agent( + llm=models["action_model"], + tools=tools, + prompt=prompts["action_model"], + input_variables=["input", "intermediate_steps", "history"], + memory=memory, + callback_manager=BaseCallbackManager(handlers=callbacks), + verbose=True, + ) else: - answer = "" - async for token in callback.aiter(): - answer += token - yield json.dumps( - {"text": answer, "message_id": message_id}, - ensure_ascii=False) - + agent_executor = initialize_agent( + llm=models["action_model"], + tools=tools, + callbacks=callbacks, + agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, + memory=memory, + verbose=True, + ) + branch = RunnableBranch( + (lambda x: "1" in x["topic"].lower(), agent_executor), + chain + ) + full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch) + task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done)) + if stream: + async for chunk in callback.aiter(): + data = json.loads(chunk) + if data["status"] == Status.start: + continue + elif data["status"] == Status.agent_action: + tool_info = { + "tool_name": data["tool_name"], + "tool_input": data["tool_input"] + } + yield json.dumps({"agent_action": tool_info, "message_id": message_id}, ensure_ascii=False) + elif data["status"] == Status.agent_finish: + yield json.dumps({"agent_finish": data["agent_finish"], "message_id": message_id}, + ensure_ascii=False) + else: + yield json.dumps({"text": data["llm_token"], "message_id": message_id}, ensure_ascii=False) + else: + text = "" + agent_finish = "" + tool_info = None + async for chunk in callback.aiter(): + # Use server-sent-events to stream the response + data = json.loads(chunk) + if data["status"] == Status.agent_action: + tool_info = { + "tool_name": data["tool_name"], + "tool_input": data["tool_input"] + } + if data["status"] == 
Status.agent_finish: + agent_finish = data["agent_finish"] + else: + text += data["llm_token"] + if tool_info: + yield json.dumps( + {"text": text, "agent_action": tool_info, "agent_finish": agent_finish, "message_id": message_id}, + ensure_ascii=False) + else: + yield json.dumps( + {"text": text, "message_id": message_id}, + ensure_ascii=False) await task return EventSourceResponse(chat_iterator()) diff --git a/server/chat/completion.py b/server/chat/completion.py index acf17366..f93abce8 100644 --- a/server/chat/completion.py +++ b/server/chat/completion.py @@ -1,6 +1,6 @@ from fastapi import Body -from sse_starlette.sse import EventSourceResponse -from configs import LLM_MODELS, TEMPERATURE +from fastapi.responses import StreamingResponse +from configs import LLM_MODEL_CONFIG from server.utils import wrap_done, get_OpenAI from langchain.chains import LLMChain from langchain.callbacks import AsyncIteratorCallbackHandler @@ -14,8 +14,8 @@ from server.utils import get_prompt_template async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), stream: bool = Body(False, description="流式输出"), echo: bool = Body(False, description="除了输出之外,还回显输入"), - model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), - temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), + model_name: str = Body(None, description="LLM 模型名称。"), + temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"), # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), prompt_name: str = Body("default", @@ -24,7 +24,7 @@ async def completion(query: str = Body(..., description="用户输入", examples #todo 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理 async def completion_iterator(query: str, - model_name: str = LLM_MODELS[0], + model_name: str = None, prompt_name: str = prompt_name, echo: bool = echo, ) -> AsyncIterable[str]: diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py index a3bbdfc6..275371b5 100644 --- a/server/chat/file_chat.py +++ b/server/chat/file_chat.py @@ -1,7 +1,6 @@ from fastapi import Body, File, Form, UploadFile -from sse_starlette.sse import EventSourceResponse -from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE, - CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE) +from fastapi.responses import StreamingResponse +from configs import (VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE) from server.utils import (wrap_done, get_ChatOpenAI, BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool) from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool @@ -15,8 +14,6 @@ from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter from server.knowledge_base.utils import KnowledgeFile import json import os -from pathlib import Path - def _parse_files_in_thread( files: List[UploadFile], @@ -102,8 +99,8 @@ async def file_chat(query: str = Body(..., description="用户输入", examples= "content": "虎头虎脑"}]] ), stream: bool = Body(False, description="流式输出"), - model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), - temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), + model_name: str = Body(None, description="LLM 模型名称。"), + temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, 
description="限制LLM生成Token数量,默认None代表模型最大值"), prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), ): diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py deleted file mode 100644 index 60956b44..00000000 --- a/server/chat/knowledge_base_chat.py +++ /dev/null @@ -1,147 +0,0 @@ -from fastapi import Body, Request -from sse_starlette.sse import EventSourceResponse -from fastapi.concurrency import run_in_threadpool -from configs import (LLM_MODELS, - VECTOR_SEARCH_TOP_K, - SCORE_THRESHOLD, - TEMPERATURE, - USE_RERANKER, - RERANKER_MODEL, - RERANKER_MAX_LENGTH, - MODEL_PATH) -from server.utils import wrap_done, get_ChatOpenAI -from server.utils import BaseResponse, get_prompt_template -from langchain.chains import LLMChain -from langchain.callbacks import AsyncIteratorCallbackHandler -from typing import AsyncIterable, List, Optional -import asyncio -from langchain.prompts.chat import ChatPromptTemplate -from server.chat.utils import History -from server.knowledge_base.kb_service.base import KBServiceFactory -import json -from urllib.parse import urlencode -from server.knowledge_base.kb_doc_api import search_docs -from server.reranker.reranker import LangchainReranker -from server.utils import embedding_device -async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]), - knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), - top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), - score_threshold: float = Body( - SCORE_THRESHOLD, - description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", - ge=0, - le=2 - ), - history: List[History] = Body( - [], - description="历史对话", - examples=[[ - {"role": "user", - "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", - "content": "虎头虎脑"}]] - ), - stream: bool = Body(False, description="流式输出"), - model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), - temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: Optional[int] = Body( - None, - description="限制LLM生成Token数量,默认None代表模型最大值" - ), - prompt_name: str = Body( - "default", - description="使用的prompt模板名称(在configs/prompt_config.py中配置)" - ), - request: Request = None, - ): - kb = KBServiceFactory.get_service_by_name(knowledge_base_name) - if kb is None: - return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - - history = [History.from_data(h) for h in history] - - async def knowledge_base_chat_iterator( - query: str, - top_k: int, - history: Optional[List[History]], - model_name: str = model_name, - prompt_name: str = prompt_name, - ) -> AsyncIterable[str]: - nonlocal max_tokens - callback = AsyncIteratorCallbackHandler() - if isinstance(max_tokens, int) and max_tokens <= 0: - max_tokens = None - - model = get_ChatOpenAI( - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - callbacks=[callback], - ) - docs = await run_in_threadpool(search_docs, - query=query, - knowledge_base_name=knowledge_base_name, - top_k=top_k, - score_threshold=score_threshold) - - # 加入reranker - if USE_RERANKER: - reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large") - print("-----------------model path------------------") - print(reranker_model_path) - reranker_model = LangchainReranker(top_n=top_k, - device=embedding_device(), - max_length=RERANKER_MAX_LENGTH, - model_name_or_path=reranker_model_path - ) - print(docs) - docs = 
reranker_model.compress_documents(documents=docs, - query=query) - print("---------after rerank------------------") - print(docs) - context = "\n".join([doc.page_content for doc in docs]) - - if len(docs) == 0: # 如果没有找到相关文档,使用empty模板 - prompt_template = get_prompt_template("knowledge_base_chat", "empty") - else: - prompt_template = get_prompt_template("knowledge_base_chat", prompt_name) - input_msg = History(role="user", content=prompt_template).to_msg_template(False) - chat_prompt = ChatPromptTemplate.from_messages( - [i.to_msg_template() for i in history] + [input_msg]) - - chain = LLMChain(prompt=chat_prompt, llm=model) - - # Begin a task that runs in the background. - task = asyncio.create_task(wrap_done( - chain.acall({"context": context, "question": query}), - callback.done), - ) - - source_documents = [] - for inum, doc in enumerate(docs): - filename = doc.metadata.get("source") - parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename}) - base_url = request.base_url - url = f"{base_url}knowledge_base/download_doc?" + parameters - text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n""" - source_documents.append(text) - - if len(source_documents) == 0: # 没有找到相关文档 - source_documents.append(f"未找到相关文档,该回答为大模型自身能力解答!") - - if stream: - async for token in callback.aiter(): - # Use server-sent-events to stream the response - yield json.dumps({"answer": token}, ensure_ascii=False) - yield json.dumps({"docs": source_documents}, ensure_ascii=False) - else: - answer = "" - async for token in callback.aiter(): - answer += token - yield json.dumps({"answer": answer, - "docs": source_documents}, - ensure_ascii=False) - await task - - return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name)) - diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py deleted file mode 100644 index 42bef3c2..00000000 --- a/server/chat/search_engine_chat.py +++ /dev/null @@ -1,208 +0,0 @@ -from langchain.utilities.bing_search import BingSearchAPIWrapper -from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper -from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY, - LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE, OVERLAP_SIZE) -from langchain.chains import LLMChain -from langchain.callbacks import AsyncIteratorCallbackHandler - -from langchain.prompts.chat import ChatPromptTemplate -from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain.docstore.document import Document -from fastapi import Body -from fastapi.concurrency import run_in_threadpool -from sse_starlette import EventSourceResponse -from server.utils import wrap_done, get_ChatOpenAI -from server.utils import BaseResponse, get_prompt_template -from server.chat.utils import History -from typing import AsyncIterable -import asyncio -import json -from typing import List, Optional, Dict -from strsimpy.normalized_levenshtein import NormalizedLevenshtein -from markdownify import markdownify - - -def bing_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs): - if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY): - return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV", - "title": "env info is not found", - "link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}] - search = BingSearchAPIWrapper(bing_subscription_key=BING_SUBSCRIPTION_KEY, - bing_search_url=BING_SEARCH_URL) - return search.results(text, 
result_len) - - -def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs): - search = DuckDuckGoSearchAPIWrapper() - return search.results(text, result_len) - - -def metaphor_search( - text: str, - result_len: int = SEARCH_ENGINE_TOP_K, - split_result: bool = False, - chunk_size: int = 500, - chunk_overlap: int = OVERLAP_SIZE, -) -> List[Dict]: - from metaphor_python import Metaphor - - if not METAPHOR_API_KEY: - return [] - - client = Metaphor(METAPHOR_API_KEY) - search = client.search(text, num_results=result_len, use_autoprompt=True) - contents = search.get_contents().contents - for x in contents: - x.extract = markdownify(x.extract) - - # metaphor 返回的内容都是长文本,需要分词再检索 - if split_result: - docs = [Document(page_content=x.extract, - metadata={"link": x.url, "title": x.title}) - for x in contents] - text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "], - chunk_size=chunk_size, - chunk_overlap=chunk_overlap) - splitted_docs = text_splitter.split_documents(docs) - - # 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档 - if len(splitted_docs) > result_len: - normal = NormalizedLevenshtein() - for x in splitted_docs: - x.metadata["score"] = normal.similarity(text, x.page_content) - splitted_docs.sort(key=lambda x: x.metadata["score"], reverse=True) - splitted_docs = splitted_docs[:result_len] - - docs = [{"snippet": x.page_content, - "link": x.metadata["link"], - "title": x.metadata["title"]} - for x in splitted_docs] - else: - docs = [{"snippet": x.extract, - "link": x.url, - "title": x.title} - for x in contents] - - return docs - - -SEARCH_ENGINES = {"bing": bing_search, - "duckduckgo": duckduckgo_search, - "metaphor": metaphor_search, - } - - -def search_result2docs(search_results): - docs = [] - for result in search_results: - doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "", - metadata={"source": result["link"] if "link" in result.keys() else "", - "filename": result["title"] if "title" in result.keys() else ""}) - docs.append(doc) - return docs - - -async def lookup_search_engine( - query: str, - search_engine_name: str, - top_k: int = SEARCH_ENGINE_TOP_K, - split_result: bool = False, -): - search_engine = SEARCH_ENGINES[search_engine_name] - results = await run_in_threadpool(search_engine, query, result_len=top_k, split_result=split_result) - docs = search_result2docs(results) - return docs - -async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]), - search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]), - top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"), - history: List[History] = Body([], - description="历史对话", - examples=[[ - {"role": "user", - "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", - "content": "虎头虎脑"}]] - ), - stream: bool = Body(False, description="流式输出"), - model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), - temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: Optional[int] = Body(None, - description="限制LLM生成Token数量,默认None代表模型最大值"), - prompt_name: str = Body("default", - description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), - split_result: bool = Body(False, - description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)") - ): - if search_engine_name not in SEARCH_ENGINES.keys(): - return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}") - - if search_engine_name == "bing" and not BING_SUBSCRIPTION_KEY: - return BaseResponse(code=404, msg=f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`") - - 
history = [History.from_data(h) for h in history] - - async def search_engine_chat_iterator(query: str, - search_engine_name: str, - top_k: int, - history: Optional[List[History]], - model_name: str = LLM_MODELS[0], - prompt_name: str = prompt_name, - ) -> AsyncIterable[str]: - nonlocal max_tokens - callback = AsyncIteratorCallbackHandler() - if isinstance(max_tokens, int) and max_tokens <= 0: - max_tokens = None - - model = get_ChatOpenAI( - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - callbacks=[callback], - ) - - docs = await lookup_search_engine(query, search_engine_name, top_k, split_result=split_result) - context = "\n".join([doc.page_content for doc in docs]) - - prompt_template = get_prompt_template("search_engine_chat", prompt_name) - input_msg = History(role="user", content=prompt_template).to_msg_template(False) - chat_prompt = ChatPromptTemplate.from_messages( - [i.to_msg_template() for i in history] + [input_msg]) - - chain = LLMChain(prompt=chat_prompt, llm=model) - - # Begin a task that runs in the background. - task = asyncio.create_task(wrap_done( - chain.acall({"context": context, "question": query}), - callback.done), - ) - - source_documents = [ - f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n""" - for inum, doc in enumerate(docs) - ] - - if len(source_documents) == 0: # 没有找到相关资料(不太可能) - source_documents.append(f"""未找到相关文档,该回答为大模型自身能力解答!""") - - if stream: - async for token in callback.aiter(): - # Use server-sent-events to stream the response - yield json.dumps({"answer": token}, ensure_ascii=False) - yield json.dumps({"docs": source_documents}, ensure_ascii=False) - else: - answer = "" - async for token in callback.aiter(): - answer += token - yield json.dumps({"answer": answer, - "docs": source_documents}, - ensure_ascii=False) - await task - - return EventSourceResponse(search_engine_chat_iterator(query=query, - search_engine_name=search_engine_name, - top_k=top_k, - history=history, - model_name=model_name, - prompt_name=prompt_name), - ) diff --git a/server/db/models/conversation_model.py b/server/db/models/conversation_model.py index 9cc6d5b6..c9a53bbc 100644 --- a/server/db/models/conversation_model.py +++ b/server/db/models/conversation_model.py @@ -9,7 +9,6 @@ class ConversationModel(Base): __tablename__ = 'conversation' id = Column(String(32), primary_key=True, comment='对话框ID') name = Column(String(50), comment='对话框名称') - # chat/agent_chat等 chat_type = Column(String(50), comment='聊天类型') create_time = Column(DateTime, default=func.now(), comment='创建时间') diff --git a/server/db/models/message_model.py b/server/db/models/message_model.py index 7b76df19..de0bc340 100644 --- a/server/db/models/message_model.py +++ b/server/db/models/message_model.py @@ -10,7 +10,6 @@ class MessageModel(Base): __tablename__ = 'message' id = Column(String(32), primary_key=True, comment='聊天记录ID') conversation_id = Column(String(32), default=None, index=True, comment='对话框ID') - # chat/agent_chat等 chat_type = Column(String(50), comment='聊天类型') query = Column(String(4096), comment='用户问题') response = Column(String(4096), comment='模型回答') diff --git a/server/knowledge_base/kb_summary_api.py b/server/knowledge_base/kb_summary_api.py index 6558f877..d0d49280 100644 --- a/server/knowledge_base/kb_summary_api.py +++ b/server/knowledge_base/kb_summary_api.py @@ -10,7 +10,6 @@ from typing import List, Optional from server.knowledge_base.kb_summary.base import KBSummaryService from 
server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter from server.utils import wrap_done, get_ChatOpenAI, BaseResponse -from configs import LLM_MODELS, TEMPERATURE from server.knowledge_base.model.kb_document_model import DocumentWithVSId def recreate_summary_vector_store( @@ -19,8 +18,8 @@ def recreate_summary_vector_store( vs_type: str = Body(DEFAULT_VS_TYPE), embed_model: str = Body(EMBEDDING_MODEL), file_description: str = Body(''), - model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), - temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), + model_name: str = Body(None, description="LLM 模型名称。"), + temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), ): """ @@ -100,8 +99,8 @@ def summary_file_to_vector_store( vs_type: str = Body(DEFAULT_VS_TYPE), embed_model: str = Body(EMBEDDING_MODEL), file_description: str = Body(''), - model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), - temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), + model_name: str = Body(None, description="LLM 模型名称。"), + temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), ): """ @@ -172,8 +171,8 @@ def summary_doc_ids_to_vector_store( vs_type: str = Body(DEFAULT_VS_TYPE), embed_model: str = Body(EMBEDDING_MODEL), file_description: str = Body(''), - model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), - temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), + model_name: str = Body(None, description="LLM 模型名称。"), + temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), ) -> BaseResponse: """ diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index f2ddbfd0..61b3625c 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -7,7 +7,6 @@ from configs import ( logger, log_verbose, text_splitter_dict, - LLM_MODELS, TEXT_SPLITTER_NAME, ) import importlib @@ -187,10 +186,10 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None): def make_text_splitter( - splitter_name: str = TEXT_SPLITTER_NAME, - chunk_size: int = CHUNK_SIZE, - chunk_overlap: int = OVERLAP_SIZE, - llm_model: str = LLM_MODELS[0], + splitter_name, + chunk_size, + chunk_overlap, + llm_model, ): """ 根据参数获取特定的分词器 diff --git a/server/llm_api.py b/server/llm_api.py index fbac4937..d642eeac 100644 --- a/server/llm_api.py +++ b/server/llm_api.py @@ -1,5 +1,5 @@ from fastapi import Body -from configs import logger, log_verbose, LLM_MODELS, HTTPX_DEFAULT_TIMEOUT +from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models, get_httpx_client, get_model_worker_config) from typing import List @@ -62,7 +62,7 @@ def get_model_config( def stop_llm_model( - model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODELS[0]]), + model_name: str = Body(..., description="要停止的LLM模型名称", examples=[]), controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) ) -> BaseResponse: ''' @@ -86,8 +86,8 @@ def stop_llm_model( def change_llm_model( - model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODELS[0]]), - 
new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODELS[0]]), + model_name: str = Body(..., description="当前运行模型"), + new_model_name: str = Body(..., description="要切换的新模型"), controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) ): ''' @@ -108,9 +108,3 @@ def change_llm_model( return BaseResponse( code=500, msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}") - - -def list_search_engines() -> BaseResponse: - from server.chat.search_engine_chat import SEARCH_ENGINES - - return BaseResponse(data=list(SEARCH_ENGINES)) diff --git a/server/model_workers/base.py b/server/model_workers/base.py index 234ab47a..b6e88d31 100644 --- a/server/model_workers/base.py +++ b/server/model_workers/base.py @@ -1,5 +1,5 @@ from fastchat.conversation import Conversation -from configs import LOG_PATH, TEMPERATURE +from configs import LOG_PATH import fastchat.constants fastchat.constants.LOGDIR = LOG_PATH from fastchat.serve.base_model_worker import BaseModelWorker @@ -63,7 +63,7 @@ class ApiModelParams(ApiConfigParams): deployment_name: Optional[str] = None # for azure resource_name: Optional[str] = None # for azure - temperature: float = TEMPERATURE + temperature: float = 0.9 max_tokens: Optional[int] = None top_p: Optional[float] = 1.0 diff --git a/server/model_workers/qwen.py b/server/model_workers/qwen.py index 2741b74d..f9ae6cb2 100644 --- a/server/model_workers/qwen.py +++ b/server/model_workers/qwen.py @@ -1,11 +1,6 @@ -import json import sys - from fastchat.conversation import Conversation -from configs import TEMPERATURE -from http import HTTPStatus from typing import List, Literal, Dict - from fastchat import conversation as conv from server.model_workers.base import * from server.model_workers.base import ApiEmbeddingsParams diff --git a/server/utils.py b/server/utils.py index 7fed5f8c..dac08a04 100644 --- a/server/utils.py +++ b/server/utils.py @@ -4,7 +4,7 @@ from typing import List from fastapi import FastAPI from pathlib import Path import asyncio -from configs import (LLM_MODELS, LLM_DEVICE, EMBEDDING_DEVICE, +from configs import (LLM_MODEL_CONFIG, LLM_DEVICE, EMBEDDING_DEVICE, MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose, FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT) import os @@ -402,8 +402,8 @@ def fschat_controller_address() -> str: return f"http://{host}:{port}" -def fschat_model_worker_address(model_name: str = LLM_MODELS[0]) -> str: - if model := get_model_worker_config(model_name): +def fschat_model_worker_address(model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model']))) -> str: + if model := get_model_worker_config(model_name): # TODO: depends fastchat host = model["host"] if host == "0.0.0.0": host = "127.0.0.1" @@ -443,7 +443,7 @@ def webui_address() -> str: def get_prompt_template(type: str, name: str) -> Optional[str]: ''' 从prompt_config中加载模板内容 - type: "llm_chat","agent_chat","knowledge_base_chat","search_engine_chat"的其中一种,如果有新功能,应该进行加入。 + type: "llm_chat","knowledge_base_chat","search_engine_chat"的其中一种,如果有新功能,应该进行加入。 ''' from configs import prompt_config @@ -617,26 +617,6 @@ def get_server_configs() -> Dict: ''' 获取configs中的原始配置项,供前端使用 ''' - from configs.kb_config import ( - DEFAULT_KNOWLEDGE_BASE, - DEFAULT_SEARCH_ENGINE, - DEFAULT_VS_TYPE, - CHUNK_SIZE, - OVERLAP_SIZE, - SCORE_THRESHOLD, - VECTOR_SEARCH_TOP_K, - SEARCH_ENGINE_TOP_K, - ZH_TITLE_ENHANCE, - text_splitter_dict, - TEXT_SPLITTER_NAME, - ) - from configs.model_config import ( - 
LLM_MODELS, - HISTORY_LEN, - TEMPERATURE, - ) - from configs.prompt_config import PROMPT_TEMPLATES - _custom = { "controller_address": fschat_controller_address(), "openai_api_address": fschat_openai_api_address(), diff --git a/startup.py b/startup.py index 70b5eccc..60e4b071 100644 --- a/startup.py +++ b/startup.py @@ -8,6 +8,7 @@ from datetime import datetime from pprint import pprint from langchain_core._api import deprecated +# 设置numexpr最大线程数,默认为CPU核心数 try: import numexpr @@ -21,7 +22,7 @@ from configs import ( LOG_PATH, log_verbose, logger, - LLM_MODELS, + LLM_MODEL_CONFIG, EMBEDDING_MODEL, TEXT_SPLITTER_NAME, FSCHAT_CONTROLLER, @@ -32,13 +33,21 @@ from configs import ( HTTPX_DEFAULT_TIMEOUT, ) from server.utils import (fschat_controller_address, fschat_model_worker_address, - fschat_openai_api_address, get_httpx_client, get_model_worker_config, + fschat_openai_api_address, get_httpx_client, + get_model_worker_config, MakeFastAPIOffline, FastAPI, llm_device, embedding_device) from server.knowledge_base.migrate import create_tables import argparse from typing import List, Dict from configs import VERSION +all_model_names = set() +for model_category in LLM_MODEL_CONFIG.values(): + for model_name in model_category.keys(): + if model_name not in all_model_names: + all_model_names.add(model_name) + +all_model_names_list = list(all_model_names) @deprecated( since="0.3.0", @@ -109,9 +118,9 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI: import fastchat.serve.vllm_worker from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id from vllm import AsyncLLMEngine - from vllm.engine.arg_utils import AsyncEngineArgs + from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs - args.tokenizer = args.model_path + args.tokenizer = args.model_path # 如果tokenizer与model_path不一致在此处添加 args.tokenizer_mode = 'auto' args.trust_remote_code = True args.download_dir = None @@ -130,7 +139,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI: args.conv_template = None args.limit_worker_concurrency = 5 args.no_register = False - args.num_gpus = 1 # vllm worker的切分是tensor并行,这里填写显卡的数量 + args.num_gpus = 4 # vllm worker的切分是tensor并行,这里填写显卡的数量 args.engine_use_ray = False args.disable_log_requests = False @@ -365,7 +374,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None): def run_model_worker( - model_name: str = LLM_MODELS[0], + model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model'])), controller_address: str = "", log_level: str = "INFO", q: mp.Queue = None, @@ -502,7 +511,7 @@ def parse_args() -> argparse.ArgumentParser: "--model-worker", action="store_true", help="run fastchat's model_worker server with specified model name. " - "specify --model-name if not using default LLM_MODELS", + "specify --model-name if not using default llm models", dest="model_worker", ) parser.add_argument( @@ -510,7 +519,7 @@ def parse_args() -> argparse.ArgumentParser: "--model-name", type=str, nargs="+", - default=LLM_MODELS, + default=all_model_names_list, help="specify model name for model worker. " "add addition names with space seperated to start multiple model workers.", dest="model_name", @@ -574,7 +583,7 @@ def dump_server_info(after_start=False, args=None): print(f"langchain版本:{langchain.__version__}. 
fastchat版本:{fastchat.__version__}") print("\n") - models = LLM_MODELS + models = list(LLM_MODEL_CONFIG['llm_model'].keys()) if args and args.model_name: models = args.model_name @@ -769,17 +778,17 @@ async def start_main_server(): if p := processes.get("api"): p.start() p.name = f"{p.name} ({p.pid})" - api_started.wait() + api_started.wait() # 等待api.py启动完成 if p := processes.get("webui"): p.start() p.name = f"{p.name} ({p.pid})" - webui_started.wait() + webui_started.wait() # 等待webui.py启动完成 dump_server_info(after_start=True, args=args) while True: - cmd = queue.get() + cmd = queue.get() # 收到切换模型的消息 e = manager.Event() if isinstance(cmd, list): model_name, cmd, new_model_name = cmd @@ -877,20 +886,5 @@ if __name__ == "__main__": loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - loop.run_until_complete(start_main_server()) -# 服务启动后接口调用示例: -# import openai -# openai.api_key = "EMPTY" # Not support yet -# openai.api_base = "http://localhost:8888/v1" - -# model = "chatglm3-6b" - -# # create a chat completion -# completion = openai.ChatCompletion.create( -# model=model, -# messages=[{"role": "user", "content": "Hello! What is your name?"}] -# ) -# # print the completion -# print(completion.choices[0].message.content) diff --git a/tests/api/test_server_state_api.py b/tests/api/test_server_state_api.py index 59c0985f..2edfb496 100644 --- a/tests/api/test_server_state_api.py +++ b/tests/api/test_server_state_api.py @@ -29,15 +29,7 @@ def test_server_configs(): assert len(configs) > 0 -def test_list_search_engines(): - engines = api.list_search_engines() - pprint(engines) - - assert isinstance(engines, list) - assert len(engines) > 0 - - -@pytest.mark.parametrize("type", ["llm_chat", "agent_chat"]) +@pytest.mark.parametrize("type", ["llm_chat"]) def test_get_prompt_template(type): print(f"prompt template for: {type}") template = api.get_prompt_template(type=type) diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py index bf7b0571..daacc560 100644 --- a/tests/api/test_stream_chat_api.py +++ b/tests/api/test_stream_chat_api.py @@ -85,29 +85,3 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"): assert "docs" in data and len(data["docs"]) > 0 assert response.status_code == 200 - -def test_search_engine_chat(api="/chat/search_engine_chat"): - global data - - data["query"] = "室温超导最新进展是什么样?" 
-
-    url = f"{api_base_url}{api}"
-    for se in ["bing", "duckduckgo"]:
-        data["search_engine_name"] = se
-        dump_input(data, api + f" by {se}")
-        response = requests.post(url, json=data, stream=True)
-        if se == "bing" and not BING_SUBSCRIPTION_KEY:
-            data = response.json()
-            assert data["code"] == 404
-            assert data["msg"] == f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`"
-
-        print("\n")
-        print("=" * 30 + api + f" by {se} output" + "="*30)
-        for line in response.iter_content(None, decode_unicode=True):
-            data = json.loads(line[6:])
-            if "answer" in data:
-                print(data["answer"], end="", flush=True)
-        assert "docs" in data and len(data["docs"]) > 0
-        pprint(data["docs"])
-        assert response.status_code == 200
-
diff --git a/tests/test_online_api.py b/tests/test_online_api.py
index b33d1344..372fad4c 100644
--- a/tests/test_online_api.py
+++ b/tests/test_online_api.py
@@ -56,15 +56,4 @@ def test_embeddings(worker):
     assert isinstance(embeddings, list) and len(embeddings) > 0
     assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
     assert isinstance(embeddings[0][0], float)
-    print("向量长度:", len(embeddings[0]))
-
-
-# @pytest.mark.parametrize("worker", workers)
-# def test_completion(worker):
-#     params = ApiCompletionParams(prompt="五十六个民族")
-
-#     print(f"\completion with {worker} \n")
-
-#     worker_class = get_model_worker_config(worker)["worker_class"]
-#     resp = worker_class().do_completion(params)
-#     pprint(resp)
+    print("向量长度:", len(embeddings[0]))
\ No newline at end of file
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index bf4a6e69..be4934c7 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -6,8 +6,7 @@ from datetime import datetime
 import os
 import re
 import time
-from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES, LLM_MODELS,
-                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
+from configs import (TOOL_CONFIG, LLM_MODEL_CONFIG)
 from server.knowledge_base.utils import LOADER_DICT
 import uuid
 from typing import List, Dict
@@ -55,6 +54,7 @@ def parse_command(text: str, modal: Modal) -> bool:
             /new {session_name}。如果未提供名称,默认为“会话X”
             /del {session_name}。如果未提供名称,在会话数量>1的情况下,删除当前会话。
             /clear {session_name}。如果未提供名称,默认清除当前会话
+            /stop {session_name}。如果未提供名称,默认停止当前会话
             /help。查看命令帮助
         返回值:输入的是命令返回True,否则返回False
     '''
@@ -117,36 +117,38 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             st.write("\n\n".join(cmds))

     with st.sidebar:
-        # 多会话
         conv_names = list(st.session_state["conversation_ids"].keys())
         index = 0
+
+        tools = list(TOOL_CONFIG.keys())
+        selected_tool_configs = {}
+
+        with st.expander("工具栏"):
+            for tool in tools:
+                is_selected = st.checkbox(tool, value=TOOL_CONFIG[tool]["use"], key=tool)
+                if is_selected:
+                    selected_tool_configs[tool] = TOOL_CONFIG[tool]
+
         if st.session_state.get("cur_conv_name") in conv_names:
             index = conv_names.index(st.session_state.get("cur_conv_name"))
-        conversation_name = st.selectbox("当前会话:", conv_names, index=index)
+        conversation_name = st.selectbox("当前会话", conv_names, index=index)
         chat_box.use_chat_name(conversation_name)
         conversation_id = st.session_state["conversation_ids"][conversation_name]

-        def on_mode_change():
-            mode = st.session_state.dialogue_mode
-            text = f"已切换到 {mode} 模式。"
-            if mode == "知识库问答":
-                cur_kb = st.session_state.get("selected_kb")
-                if cur_kb:
-                    text = f"{text} 当前知识库: `{cur_kb}`。"
-            st.toast(text)
+        # def on_mode_change():
+        #     mode = st.session_state.dialogue_mode
+        #     text = f"已切换到 {mode} 模式。"
+        #     st.toast(text)

-        dialogue_modes = ["LLM 对话",
-                          "知识库问答",
-                          "文件对话",
-                          "搜索引擎问答",
-                          "自定义Agent问答",
-                          ]
-        dialogue_mode = st.selectbox("请选择对话模式:",
-                                     dialogue_modes,
-                                     index=0,
-                                     on_change=on_mode_change,
-                                     key="dialogue_mode",
-                                     )
+        # dialogue_modes = ["智能对话",
+        #                   "文件对话",
+        #                   ]
+        # dialogue_mode = st.selectbox("请选择对话模式:",
+        #                              dialogue_modes,
+        #                              index=0,
+        #                              on_change=on_mode_change,
+        #                              key="dialogue_mode",
+        #                              )

         def on_llm_change():
             if llm_model:
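In the sidebar hunk above, tool selection is now driven entirely by TOOL_CONFIG: each configured tool becomes a checkbox whose default comes from its "use" flag, and the checked entries are forwarded to the backend as selected_tool_configs. A minimal sketch of the shape this code expects; the concrete tool names and extra fields below are illustrative, not the project's actual defaults:

# Hypothetical TOOL_CONFIG entries: the checkbox loop above only requires a "use"
# flag, any other fields simply travel to the backend with the selected tool.
TOOL_CONFIG = {
    "search_local_knowledgebase": {"use": False, "top_k": 3, "score_threshold": 1.0},
    "calculate": {"use": True},
}

# What the sidebar produces when the user keeps these defaults:
selected_tool_configs = {name: cfg for name, cfg in TOOL_CONFIG.items() if cfg["use"]}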
对话", - "知识库问答", - "文件对话", - "搜索引擎问答", - "自定义Agent问答", - ] - dialogue_mode = st.selectbox("请选择对话模式:", - dialogue_modes, - index=0, - on_change=on_mode_change, - key="dialogue_mode", - ) + # dialogue_modes = ["智能对话", + # "文件对话", + # ] + # dialogue_mode = st.selectbox("请选择对话模式:", + # dialogue_modes, + # index=0, + # on_change=on_mode_change, + # key="dialogue_mode", + # ) def on_llm_change(): if llm_model: @@ -164,7 +166,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): available_models = [] config_models = api.list_config_models() if not is_lite: - for k, v in config_models.get("local", {}).items(): + for k, v in config_models.get("local", {}).items(): # 列出配置了有效本地路径的模型 if (v.get("model_path_exists") and k not in running_models): available_models.append(k) @@ -177,103 +179,47 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): index = llm_models.index(cur_llm_model) else: index = 0 - llm_model = st.selectbox("选择LLM模型:", + llm_model = st.selectbox("选择LLM模型", llm_models, index, format_func=llm_model_format_func, on_change=on_llm_change, key="llm_model", ) - if (st.session_state.get("prev_llm_model") != llm_model - and not is_lite - and not llm_model in config_models.get("online", {}) - and not llm_model in config_models.get("langchain", {}) - and llm_model not in running_models): - with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"): - prev_model = st.session_state.get("prev_llm_model") - r = api.change_llm_model(prev_model, llm_model) - if msg := check_error_msg(r): - st.error(msg) - elif msg := check_success_msg(r): - st.success(msg) - st.session_state["prev_llm_model"] = llm_model - index_prompt = { - "LLM 对话": "llm_chat", - "自定义Agent问答": "agent_chat", - "搜索引擎问答": "search_engine_chat", - "知识库问答": "knowledge_base_chat", - "文件对话": "knowledge_base_chat", - } - prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys()) - prompt_template_name = prompt_templates_kb_list[0] - if "prompt_template_select" not in st.session_state: - st.session_state.prompt_template_select = prompt_templates_kb_list[0] - def prompt_change(): - text = f"已切换为 {prompt_template_name} 模板。" - st.toast(text) + # 传入后端的内容 + model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()} - prompt_template_select = st.selectbox( - "请选择Prompt模板:", - prompt_templates_kb_list, - index=0, - on_change=prompt_change, - key="prompt_template_select", - ) - prompt_template_name = st.session_state.prompt_template_select - temperature = st.slider("Temperature:", 0.0, 2.0, TEMPERATURE, 0.05) - history_len = st.number_input("历史对话轮数:", 0, 20, HISTORY_LEN) + for key in LLM_MODEL_CONFIG: + if key == 'llm_model': + continue + if LLM_MODEL_CONFIG[key]: + first_key = next(iter(LLM_MODEL_CONFIG[key])) + model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key] - def on_kb_change(): - st.toast(f"已加载知识库: {st.session_state.selected_kb}") + if llm_model is not None: + model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model] - if dialogue_mode == "知识库问答": - with st.expander("知识库配置", True): - kb_list = api.list_knowledge_bases() - index = 0 - if DEFAULT_KNOWLEDGE_BASE in kb_list: - index = kb_list.index(DEFAULT_KNOWLEDGE_BASE) - selected_kb = st.selectbox( - "请选择知识库:", - kb_list, - index=index, - on_change=on_kb_change, - key="selected_kb", - ) - kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K) - - ## Bge 模型会超过1 - score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01) - elif dialogue_mode == "文件对话": - with st.expander("文件对话配置", True): - 
-                                         [i for ls in LOADER_DICT.values() for i in ls],
-                                         accept_multiple_files=True,
-                                         )
-                kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
-
-                ## Bge 模型会超过1
-                score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
-                if st.button("开始上传", disabled=len(files) == 0):
-                    st.session_state["file_chat_id"] = upload_temp_docs(files, api)
-        elif dialogue_mode == "搜索引擎问答":
-            search_engine_list = api.list_search_engines()
-            if DEFAULT_SEARCH_ENGINE in search_engine_list:
-                index = search_engine_list.index(DEFAULT_SEARCH_ENGINE)
-            else:
-                index = search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0
-            with st.expander("搜索引擎配置", True):
-                search_engine = st.selectbox(
-                    label="请选择搜索引擎",
-                    options=search_engine_list,
-                    index=index,
-                )
-                se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)
+        print(model_config)
+        files = st.file_uploader("上传附件",
+                                 type=[i for ls in LOADER_DICT.values() for i in ls],
+                                 accept_multiple_files=True)
+        # if dialogue_mode == "文件对话":
+        #     with st.expander("文件对话配置", True):
+        #         files = st.file_uploader("上传知识文件:",
+        #                                  [i for ls in LOADER_DICT.values() for i in ls],
+        #                                  accept_multiple_files=True,
+        #                                  )
+        #         kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
+        #         score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
+        #         if st.button("开始上传", disabled=len(files) == 0):
+        #             st.session_state["file_chat_id"] = upload_temp_docs(files, api)

     # Display chat messages from history on app rerun
-    chat_box.output_messages()
+
+    chat_box.output_messages()

     chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "

     def on_feedback(
@@ -297,140 +243,76 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         if parse_command(text=prompt, modal=modal):  # 用户输入自定义命令
            st.rerun()
         else:
-            history = get_messages_history(history_len)
+            history = get_messages_history(
+                model_config["llm_model"][next(iter(model_config["llm_model"]))]["history_len"])
             chat_box.user_say(prompt)
-            if dialogue_mode == "LLM 对话":
-                chat_box.ai_say("正在思考...")
-                text = ""
-                message_id = ""
-                r = api.chat_chat(prompt,
-                                  history=history,
-                                  conversation_id=conversation_id,
-                                  model=llm_model,
-                                  prompt_name=prompt_template_name,
-                                  temperature=temperature)
-                for t in r:
-                    if error_msg := check_error_msg(t):  # check whether error occured
-                        st.error(error_msg)
-                        break
-                    text += t.get("text", "")
-                    chat_box.update_msg(text)
-                message_id = t.get("message_id", "")
+            chat_box.ai_say("正在思考...")
+            text = ""
+            message_id = ""
+            element_index = 0
+            for d in api.chat_chat(query=prompt,
+                                   history=history,
+                                   model_config=model_config,
+                                   conversation_id=conversation_id,
+                                   tool_config=selected_tool_configs,
+                                   ):
+                try:
+                    d = json.loads(d)
+                except:
+                    pass
+                message_id = d.get("message_id", "")
                 metadata = {
                     "message_id": message_id,
                 }
-                chat_box.update_msg(text, streaming=False, metadata=metadata)  # 更新最终的字符串,去除光标
-                chat_box.show_feedback(**feedback_kwargs,
-                                       key=message_id,
-                                       on_submit=on_feedback,
-                                       kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
-
-            elif dialogue_mode == "自定义Agent问答":
-                if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
-                    chat_box.ai_say([
-                        f"正在思考... \n\n 该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!\n\n\n",
-                        Markdown("...", in_expander=True, title="思考过程", state="complete"),
-
-                    ])
-                else:
-                    chat_box.ai_say([
-                        f"正在思考...",
-                        Markdown("...", in_expander=True, title="思考过程", state="complete"),
-
-                    ])
-                text = ""
-                ans = ""
-                for d in api.agent_chat(prompt,
-                                        history=history,
-                                        model=llm_model,
-                                        prompt_name=prompt_template_name,
-                                        temperature=temperature,
-                                        ):
-                    try:
-                        d = json.loads(d)
-                    except:
-                        pass
-                    if error_msg := check_error_msg(d):  # check whether error occured
-                        st.error(error_msg)
-                    if chunk := d.get("answer"):
-                        text += chunk
-                        chat_box.update_msg(text, element_index=1)
-                    if chunk := d.get("final_answer"):
-                        ans += chunk
-                        chat_box.update_msg(ans, element_index=0)
-                    if chunk := d.get("tools"):
-                        text += "\n\n".join(d.get("tools", []))
-                        chat_box.update_msg(text, element_index=1)
-                chat_box.update_msg(ans, element_index=0, streaming=False)
-                chat_box.update_msg(text, element_index=1, streaming=False)
-            elif dialogue_mode == "知识库问答":
-                chat_box.ai_say([
-                    f"正在查询知识库 `{selected_kb}` ...",
-                    Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
-                ])
-                text = ""
-                for d in api.knowledge_base_chat(prompt,
-                                                 knowledge_base_name=selected_kb,
-                                                 top_k=kb_top_k,
-                                                 score_threshold=score_threshold,
-                                                 history=history,
-                                                 model=llm_model,
-                                                 prompt_name=prompt_template_name,
-                                                 temperature=temperature):
-                    if error_msg := check_error_msg(d):  # check whether error occured
-                        st.error(error_msg)
-                    elif chunk := d.get("answer"):
-                        text += chunk
-                        chat_box.update_msg(text, element_index=0)
-                chat_box.update_msg(text, element_index=0, streaming=False)
-                chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
-            elif dialogue_mode == "文件对话":
-                if st.session_state["file_chat_id"] is None:
-                    st.error("请先上传文件再进行对话")
-                    st.stop()
-                chat_box.ai_say([
-                    f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
-                    Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
-                ])
-                text = ""
-                for d in api.file_chat(prompt,
-                                       knowledge_id=st.session_state["file_chat_id"],
-                                       top_k=kb_top_k,
-                                       score_threshold=score_threshold,
-                                       history=history,
-                                       model=llm_model,
-                                       prompt_name=prompt_template_name,
-                                       temperature=temperature):
-                    if error_msg := check_error_msg(d):  # check whether error occured
-                        st.error(error_msg)
-                    elif chunk := d.get("answer"):
-                        text += chunk
-                        chat_box.update_msg(text, element_index=0)
-                chat_box.update_msg(text, element_index=0, streaming=False)
-                chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
-            elif dialogue_mode == "搜索引擎问答":
-                chat_box.ai_say([
-                    f"正在执行 `{search_engine}` 搜索...",
-                    Markdown("...", in_expander=True, title="网络搜索结果", state="complete"),
-                ])
-                text = ""
-                for d in api.search_engine_chat(prompt,
-                                                search_engine_name=search_engine,
-                                                top_k=se_top_k,
-                                                history=history,
-                                                model=llm_model,
-                                                prompt_name=prompt_template_name,
-                                                temperature=temperature,
-                                                split_result=se_top_k > 1):
-                    if error_msg := check_error_msg(d):  # check whether error occured
-                        st.error(error_msg)
-                    elif chunk := d.get("answer"):
-                        text += chunk
-                        chat_box.update_msg(text, element_index=0)
-                chat_box.update_msg(text, element_index=0, streaming=False)
-                chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
+                if error_msg := check_error_msg(d):
+                    st.error(error_msg)
+                if chunk := d.get("agent_action"):
+                    chat_box.insert_msg(Markdown("...", in_expander=True, title="Tools", state="complete"))
+                    element_index = 1
+                    formatted_data = {
+                        "action": chunk["tool_name"],
+                        "action_input": chunk["tool_input"]
+                    }
+                    formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
+                    text += f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"
+                    chat_box.update_msg(text, element_index=element_index, metadata=metadata)
+                if chunk := d.get("text"):
+                    text += chunk
+                    chat_box.update_msg(text, element_index=element_index, metadata=metadata)
+                if chunk := d.get("agent_finish"):
+                    element_index = 0
+                    text = chunk
+                    chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
+            chat_box.show_feedback(**feedback_kwargs,
+                                   key=message_id,
+                                   on_submit=on_feedback,
+                                   kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
+            # elif dialogue_mode == "文件对话":
+            #     if st.session_state["file_chat_id"] is None:
+            #         st.error("请先上传文件再进行对话")
+            #         st.stop()
+            #     chat_box.ai_say([
+            #         f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
+            #         Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
+            #     ])
+            #     text = ""
+            #     for d in api.file_chat(prompt,
+            #                            knowledge_id=st.session_state["file_chat_id"],
+            #                            top_k=kb_top_k,
+            #                            score_threshold=score_threshold,
+            #                            history=history,
+            #                            model=llm_model,
+            #                            prompt_name=prompt_template_name,
+            #                            temperature=temperature):
+            #         if error_msg := check_error_msg(d):
+            #             st.error(error_msg)
+            #         elif chunk := d.get("answer"):
+            #             text += chunk
+            #             chat_box.update_msg(text, element_index=0)
+            #     chat_box.update_msg(text, element_index=0, streaming=False)
+            #     chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)

     if st.session_state.get("need_rerun"):
         st.session_state["need_rerun"] = False
         st.rerun()
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 818748d6..834873d4 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -3,18 +3,15 @@
 from typing import *
 from pathlib import Path
-# 此处导入的配置为发起请求(如WEBUI)机器上的配置,主要用于为前端设置默认值。分布式部署时可以与服务器上的不同
 from configs import (
     EMBEDDING_MODEL,
     DEFAULT_VS_TYPE,
-    LLM_MODELS,
-    TEMPERATURE,
+    LLM_MODEL_CONFIG,
     SCORE_THRESHOLD,
     CHUNK_SIZE,
     OVERLAP_SIZE,
     ZH_TITLE_ENHANCE,
     VECTOR_SEARCH_TOP_K,
-    SEARCH_ENGINE_TOP_K,
     HTTPX_DEFAULT_TIMEOUT,
     logger, log_verbose,
 )
@@ -26,7 +23,6 @@
 from io import BytesIO
 from server.utils import set_httpx_config, api_address, get_httpx_client
 from pprint import pprint
-from langchain_core._api import deprecated

 set_httpx_config()
@@ -247,10 +243,6 @@ class ApiRequest:
         response = self.post("/server/configs", **kwargs)
         return self._get_response_value(response, as_json=True)

-    def list_search_engines(self, **kwargs) -> List:
-        response = self.post("/server/list_search_engines", **kwargs)
-        return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"])
-
     def get_prompt_template(
         self,
         type: str = "llm_chat",
@@ -272,10 +264,8 @@ class ApiRequest:
         history_len: int = -1,
         history: List[Dict] = [],
         stream: bool = True,
-        model: str = LLM_MODELS[0],
-        temperature: float = TEMPERATURE,
-        max_tokens: int = None,
-        prompt_name: str = "default",
+        model_config: Dict = None,
+        tool_config: Dict = None,
         **kwargs,
     ):
         '''
@@ -287,10 +277,8 @@
             "history_len": history_len,
             "history": history,
             "stream": stream,
-            "model_name": model,
-            "temperature": temperature,
-            "max_tokens": max_tokens,
-            "prompt_name": prompt_name,
+            "model_config": model_config,
+            "tool_config": tool_config,
         }

         # print(f"received input message:")
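With the dedicated agent, knowledge-base and search-engine methods removed, ApiRequest.chat_chat is the single entry point, and the old model/temperature/prompt_name arguments give way to the model_config and tool_config dictionaries shown above. A minimal sketch of how a caller might drive the new interface; the configuration values and server address are illustrative assumptions, not project defaults:

# Illustrative usage of the reworked chat_chat interface (not part of the patch).
from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")  # assumed API address

model_config = {                       # mirrors one entry per LLM_MODEL_CONFIG section
    "preprocess_model": {},
    "llm_model": {
        "chatglm3-6b": {"temperature": 0.9, "max_tokens": 4096,
                        "history_len": 3, "prompt_name": "default", "callbacks": True},
    },
    "action_model": {},
}
tool_config = {}                       # e.g. the selected_tool_configs from the sidebar

for d in api.chat_chat(query="你好", history=[], conversation_id=None,
                       model_config=model_config, tool_config=tool_config):
    # streamed chunks may carry "text", "agent_action", "agent_finish" and "message_id"
    print(d)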
@@ -299,78 +287,6 @@ class ApiRequest:
         response = self.post("/chat/chat", json=data, stream=True, **kwargs)
         return self._httpx_stream2generator(response, as_json=True)

-    @deprecated(
-        since="0.3.0",
-        message="自定义Agent问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
-        removal="0.3.0")
-    def agent_chat(
-        self,
-        query: str,
-        history: List[Dict] = [],
-        stream: bool = True,
-        model: str = LLM_MODELS[0],
-        temperature: float = TEMPERATURE,
-        max_tokens: int = None,
-        prompt_name: str = "default",
-    ):
-        '''
-        对应api.py/chat/agent_chat 接口
-        '''
-        data = {
-            "query": query,
-            "history": history,
-            "stream": stream,
-            "model_name": model,
-            "temperature": temperature,
-            "max_tokens": max_tokens,
-            "prompt_name": prompt_name,
-        }
-
-        # print(f"received input message:")
-        # pprint(data)
-
-        response = self.post("/chat/agent_chat", json=data, stream=True)
-        return self._httpx_stream2generator(response, as_json=True)
-
-    def knowledge_base_chat(
-        self,
-        query: str,
-        knowledge_base_name: str,
-        top_k: int = VECTOR_SEARCH_TOP_K,
-        score_threshold: float = SCORE_THRESHOLD,
-        history: List[Dict] = [],
-        stream: bool = True,
-        model: str = LLM_MODELS[0],
-        temperature: float = TEMPERATURE,
-        max_tokens: int = None,
-        prompt_name: str = "default",
-    ):
-        '''
-        对应api.py/chat/knowledge_base_chat接口
-        '''
-        data = {
-            "query": query,
-            "knowledge_base_name": knowledge_base_name,
-            "top_k": top_k,
-            "score_threshold": score_threshold,
-            "history": history,
-            "stream": stream,
-            "model_name": model,
-            "temperature": temperature,
-            "max_tokens": max_tokens,
-            "prompt_name": prompt_name,
-        }
-
-        # print(f"received input message:")
-        # pprint(data)
-
-        response = self.post(
-            "/chat/knowledge_base_chat",
-            json=data,
-            stream=True,
-        )
-        return self._httpx_stream2generator(response, as_json=True)
-
     def upload_temp_docs(
         self,
         files: List[Union[str, Path, bytes]],
@@ -416,8 +332,8 @@ class ApiRequest:
         score_threshold: float = SCORE_THRESHOLD,
         history: List[Dict] = [],
         stream: bool = True,
-        model: str = LLM_MODELS[0],
-        temperature: float = TEMPERATURE,
+        model: str = None,
+        temperature: float = 0.9,
         max_tokens: int = None,
         prompt_name: str = "default",
     ):
@@ -444,50 +360,6 @@ class ApiRequest:
         )
         return self._httpx_stream2generator(response, as_json=True)

-    @deprecated(
-        since="0.3.0",
-        message="搜索引擎问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
-        removal="0.3.0"
-    )
-    def search_engine_chat(
-        self,
-        query: str,
-        search_engine_name: str,
-        top_k: int = SEARCH_ENGINE_TOP_K,
-        history: List[Dict] = [],
-        stream: bool = True,
-        model: str = LLM_MODELS[0],
-        temperature: float = TEMPERATURE,
-        max_tokens: int = None,
-        prompt_name: str = "default",
-        split_result: bool = False,
-    ):
-        '''
-        对应api.py/chat/search_engine_chat接口
-        '''
-        data = {
-            "query": query,
-            "search_engine_name": search_engine_name,
-            "top_k": top_k,
-            "history": history,
-            "stream": stream,
-            "model_name": model,
-            "temperature": temperature,
-            "max_tokens": max_tokens,
-            "prompt_name": prompt_name,
-            "split_result": split_result,
-        }
-
-        # print(f"received input message:")
-        # pprint(data)
-
-        response = self.post(
-            "/chat/search_engine_chat",
-            json=data,
-            stream=True,
-        )
-        return self._httpx_stream2generator(response, as_json=True)
-
     # 知识库相关操作

     def list_knowledge_bases(
@@ -769,7 +641,7 @@
     def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]:
         '''
         从服务器上获取当前运行的LLM模型。
-        当 local_first=True 时,优先返回运行中的本地模型,否则优先按LLM_MODELS配置顺序返回。
+        当 local_first=True 时,优先返回运行中的本地模型,否则优先按 LLM_MODEL_CONFIG['llm_model']配置顺序返回。
         返回类型为(model_name, is_local_model)
         '''
@@ -779,7 +651,7 @@
             return "", False
         model = ""
-        for m in LLM_MODELS:
+        for m in LLM_MODEL_CONFIG['llm_model']:
             if m not in running_models:
                 continue
             is_local = not running_models[m].get("online_api")
             if local_first and not is_local:
                 continue
             else:
                 model = m
                 break

-        if not model:  # LLM_MODELS中配置的模型都不在running_models里
+        if not model:  # LLM_MODEL_CONFIG['llm_model']中配置的模型都不在running_models里
             model = list(running_models)[0]
             is_local = not running_models[model].get("online_api")
         return model, is_local
@@ -800,7 +672,7 @@
             return "", False

         model = ""
-        for m in LLM_MODELS:
+        for m in LLM_MODEL_CONFIG['llm_model']:
             if m not in running_models:
                 continue
             is_local = not running_models[m].get("online_api")
@@ -810,7 +682,7 @@
                 model = m
                 break

-        if not model:  # LLM_MODELS中配置的模型都不在running_models里
+        if not model:  # LLM_MODEL_CONFIG['llm_model']中配置的模型都不在running_models里
             model = list(running_models)[0]
             is_local = not running_models[model].get("online_api")
         return model, is_local
@@ -852,15 +724,6 @@ class ApiRequest:
         )
         return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

-    def list_search_engines(self) -> List[str]:
-        '''
-        获取服务器支持的搜索引擎
-        '''
-        response = self.post(
-            "/server/list_search_engines",
-        )
-        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
-
     def stop_llm_model(
         self,
         model_name: str,