Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git, synced 2026-02-08 16:10:18 +08:00
Dev (#2280)
* Fix the bug where the Azure worker did not set max tokens
* Rewrite the agent: 1. change the Agent implementation to support multiple parameters; only ChatGLM3-6B and OpenAI GPT-4 remain supported, the remaining models will temporarily lack Agent functionality; 2. remove agent_chat and merge it into llm_chat; 3. rewrite most tools to fit the new Agent
* Update the architecture
* Remove web_chat; it is merged in automatically
* Remove all separate chat modes; everything is now Agent-controlled
* Update the configuration files
* Update the configuration templates and prompts
* Fix a parameter-selection bug
This commit is contained in:
parent 808ca227c5
commit 253168a187
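The central change in this commit is the new per-role model registry in configs/model_config.py. As a rough orientation before reading the diff, the sketch below shows how code might read one entry from LLM_MODEL_CONFIG; the helper function is a hypothetical illustration, not code from this repository.

```python
# Minimal sketch, assuming LLM_MODEL_CONFIG is importable from configs.model_config
# as added by this commit; get_model_params itself is a hypothetical helper.
from configs.model_config import LLM_MODEL_CONFIG

def get_model_params(role: str, model_name: str) -> dict:
    """Look up the per-role sampling parameters for one model."""
    return LLM_MODEL_CONFIG.get(role, {}).get(model_name, {})

params = get_model_params("action_model", "chatglm3-6b")
print(params.get("temperature"))   # 0.01 in the shipped template
print(params.get("prompt_name"))   # "ChatGLM3"
```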
@ -1,22 +0,0 @@
-from server.utils import get_ChatOpenAI
-from configs.model_config import LLM_MODELS, TEMPERATURE
-from langchain.chains import LLMChain
-from langchain.prompts.chat import (
-    ChatPromptTemplate,
-    HumanMessagePromptTemplate,
-)
-
-model = get_ChatOpenAI(model_name=LLM_MODELS[0], temperature=TEMPERATURE)
-
-
-human_prompt = "{input}"
-human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)
-
-chat_prompt = ChatPromptTemplate.from_messages(
-    [("human", "我们来玩成语接龙,我先来,生龙活虎"),
-     ("ai", "虎头虎脑"),
-     ("human", "{input}")])
-
-
-chain = LLMChain(prompt=chat_prompt, llm=model, verbose=True)
-print(chain({"input": "恼羞成怒"}))
@ -1,14 +1,6 @@
 import os

-# 可以指定一个绝对路径,统一存放所有的Embedding和LLM模型。
-# 每个模型可以是一个单独的目录,也可以是某个目录下的二级子目录。
-# 如果模型目录名称和 MODEL_PATH 中的 key 或 value 相同,程序会自动检测加载,无需修改 MODEL_PATH 中的路径。
 MODEL_ROOT_PATH = ""
+EMBEDDING_MODEL = "bge-large-zh-v1.5"  # bge-large-zh
-# 选用的 Embedding 名称
-EMBEDDING_MODEL = "bge-large-zh-v1.5"
-
-# Embedding 模型运行设备。设为 "auto" 会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
 EMBEDDING_DEVICE = "auto"
-
 # 选用的reranker模型
@ -21,65 +13,170 @@ RERANKER_MAX_LENGTH = 1024
 EMBEDDING_KEYWORD_FILE = "keywords.txt"
 EMBEDDING_MODEL_OUTPUT_PATH = "output"

-# 要运行的 LLM 名称,可以包括本地模型和在线模型。列表中本地模型将在启动项目时全部加载。
-# 列表中第一个模型将作为 API 和 WEBUI 的默认模型。
-# 在这里,我们使用目前主流的两个离线模型,其中,chatglm3-6b 为默认加载模型。
-# 如果你的显存不足,可使用 Qwen-1_8B-Chat, 该模型 FP16 仅需 3.8G显存。
-LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"]
-Agent_MODEL = None
+LLM_MODEL_CONFIG = {
+    # 意图识别不需要输出,模型后台知道就行
+    "preprocess_model": {
+        "zhipu-api": {
+            "temperature": 0.4,
+            "max_tokens": 2048,
+            "history_len": 100,
+            "prompt_name": "default",
+            "callbacks": False
+        },
+    },
+    "llm_model": {
+        "chatglm3-6b": {
+            "temperature": 0.9,
+            "max_tokens": 4096,
+            "history_len": 3,
+            "prompt_name": "default",
+            "callbacks": True
+        },
+        "zhipu-api": {
+            "temperature": 0.9,
+            "max_tokens": 4000,
+            "history_len": 5,
+            "prompt_name": "default",
+            "callbacks": True
+        },
+        "Qwen-1_8B-Chat": {
+            "temperature": 0.4,
+            "max_tokens": 2048,
+            "history_len": 100,
+            "prompt_name": "default",
+            "callbacks": False
+        },
+    },
+    "action_model": {
+        "chatglm3-6b": {
+            "temperature": 0.01,
+            "max_tokens": 4096,
+            "prompt_name": "ChatGLM3",
+            "callbacks": True
+        },
+        "zhipu-api": {
+            "temperature": 0.01,
+            "max_tokens": 2096,
+            "history_len": 5,
+            "prompt_name": "ChatGLM3",
+            "callbacks": True
+        },
+    },
+    "postprocess_model": {
+        "chatglm3-6b": {
+            "temperature": 0.01,
+            "max_tokens": 4096,
+            "prompt_name": "default",
+            "callbacks": True
+        }
+    },
+}
+
+TOOL_CONFIG = {
+    "search_local_knowledgebase": {
+        "use": True,
+        "top_k": 10,
+        "score_threshold": 1,
+        "conclude_prompt": {
+            "with_result":
+                '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题",'
+                '不允许在答案中添加编造成分,答案请使用中文。 </指令>\n'
+                '<已知信息>{{ context }}</已知信息>\n'
+                '<问题>{{ question }}</问题>\n',
+            "without_result":
+                '请你根据我的提问回答我的问题:\n'
+                '{{ question }}\n'
+                '请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n',
+        }
+    },
+    "search_internet": {
+        "use": True,
+        "search_engine_name": "bing",
+        "search_engine_config": {
+            "bing": {
+                "result_len": 3,
+                "bing_search_url": "https://api.bing.microsoft.com/v7.0/search",
+                "bing_key": "",
+            },
+            "metaphor": {
+                "result_len": 3,
+                "metaphor_api_key": "",
+                "split_result": False,
+                "chunk_size": 500,
+                "chunk_overlap": 0,
+            },
+            "duckduckgo": {
+                "result_len": 3
+            }
+        },
+        "top_k": 10,
+        "verbose": "Origin",
+        "conclude_prompt":
+            "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 "
+            "</指令>\n<已知信息>{{ context }}</已知信息>\n"
+            "<问题>\n"
+            "{{ question }}\n"
+            "</问题>\n"
+    },
+    "arxiv": {
+        "use": True,
+    },
+    "shell": {
+        "use": True,
+    },
+    "weather_check": {
+        "use": True,
+        "api-key": "",
+    },
+    "search_youtube": {
+        "use": False,
+    },
+    "wolfram": {
+        "use": False,
+    },
+    "calculate": {
+        "use": False,
+    },
+}

 # LLM 模型运行设备。设为"auto"会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
 LLM_DEVICE = "auto"

-HISTORY_LEN = 3
-
-MAX_TOKENS = 2048
-
-TEMPERATURE = 0.7
-
 ONLINE_LLM_MODEL = {
     "openai-api": {
-        "model_name": "gpt-4",
+        "model_name": "gpt-4-1106-preview",
         "api_base_url": "https://api.openai.com/v1",
         "api_key": "",
         "openai_proxy": "",
     },

-    # 智谱AI API,具体注册及api key获取请前往 http://open.bigmodel.cn
     "zhipu-api": {
         "api_key": "",
-        "version": "glm-4",
+        "version": "chatglm_turbo",
         "provider": "ChatGLMWorker",
     },

-    # 具体注册及api key获取请前往 https://api.minimax.chat/
     "minimax-api": {
         "group_id": "",
         "api_key": "",
         "is_pro": False,
         "provider": "MiniMaxWorker",
     },

-    # 具体注册及api key获取请前往 https://xinghuo.xfyun.cn/
     "xinghuo-api": {
         "APPID": "",
         "APISecret": "",
         "api_key": "",
-        "version": "v3.5",  # 你使用的讯飞星火大模型版本,可选包括 "v3.5","v3.0", "v2.0", "v1.5"
+        "version": "v3.0",
         "provider": "XingHuoWorker",
     },

-    # 百度千帆 API,申请方式请参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf
     "qianfan-api": {
-        "version": "ERNIE-Bot",  # 注意大小写。当前支持 "ERNIE-Bot" 或 "ERNIE-Bot-turbo", 更多的见官方文档。
+        "version": "ernie-bot-4",
-        "version_url": "",  # 也可以不填写version,直接填写在千帆申请模型发布的API地址
+        "version_url": "",
         "api_key": "",
         "secret_key": "",
         "provider": "QianFanWorker",
     },

-    # 火山方舟 API,文档参考 https://www.volcengine.com/docs/82379
     "fangzhou-api": {
         "version": "chatglm-6b-model",
         "version_url": "",
@ -87,28 +184,22 @@ ONLINE_LLM_MODEL = {
         "secret_key": "",
         "provider": "FangZhouWorker",
     },

-    # 阿里云通义千问 API,文档参考 https://help.aliyun.com/zh/dashscope/developer-reference/api-details
     "qwen-api": {
         "version": "qwen-max",
         "api_key": "",
         "provider": "QwenWorker",
         "embed_model": "text-embedding-v1"  # embedding 模型名称
     },

-    # 百川 API,申请方式请参考 https://www.baichuan-ai.com/home#api-enter
     "baichuan-api": {
         "version": "Baichuan2-53B",
         "api_key": "",
         "secret_key": "",
         "provider": "BaiChuanWorker",
     },

-    # Azure API
     "azure-api": {
-        "deployment_name": "",  # 部署容器的名字
+        "deployment_name": "",
-        "resource_name": "",  # https://{resource_name}.openai.azure.com/openai/ 填写resource_name的部分,其他部分不要填写
+        "resource_name": "",
-        "api_version": "",  # API的版本,不是模型版本
+        "api_version": "2023-07-01-preview",
         "api_key": "",
         "provider": "AzureWorker",
     },
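The TOOL_CONFIG registry introduced earlier in this file's diff is a simple on/off switchboard plus per-tool settings. A hedged sketch of how a caller might consult it at runtime; the helper function is illustrative and not part of the repository.

```python
# Minimal sketch, assuming TOOL_CONFIG is importable from configs.model_config
# as defined in the hunk above; tool_enabled is a hypothetical helper.
from configs.model_config import TOOL_CONFIG

def tool_enabled(name: str) -> bool:
    # A tool is offered to the agent only when its entry sets "use": True.
    return TOOL_CONFIG.get(name, {}).get("use", False)

if tool_enabled("search_internet"):
    engine = TOOL_CONFIG["search_internet"]["search_engine_name"]  # "bing" in the template
    print(f"internet search will go through {engine}")
```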
@ -153,50 +244,33 @@ MODEL_PATH = {

     "bge-small-zh": "BAAI/bge-small-zh",
     "bge-base-zh": "BAAI/bge-base-zh",
-    "bge-large-zh": "BAAI/bge-large-zh",
+    "bge-large-zh": "/media/zr/Data/Models/Embedding/bge-large-zh",
     "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
     "bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5",
-    "bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5",
+    "bge-large-zh-v1.5": "/Models/bge-large-zh-v1.5",

-    "bge-m3": "BAAI/bge-m3",

     "piccolo-base-zh": "sensenova/piccolo-base-zh",
     "piccolo-large-zh": "sensenova/piccolo-large-zh",
-    "nlp_gte_sentence-embedding_chinese-large": "damo/nlp_gte_sentence-embedding_chinese-large",
+    "nlp_gte_sentence-embedding_chinese-large": "/Models/nlp_gte_sentence-embedding_chinese-large",
-    "text-embedding-ada-002": "your OPENAI_API_KEY",
+    "text-embedding-ada-002": "Just write your OpenAI key like "sk-o3IGBhC9g8AiFvTGWVKsT*****" ",
     },

     "llm_model": {
     "chatglm2-6b": "THUDM/chatglm2-6b",
     "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
-    "chatglm3-6b": "THUDM/chatglm3-6b",
+    "chatglm3-6b": "/Models/chatglm3-6b",
     "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",

-    "Orion-14B-Chat": "OrionStarAI/Orion-14B-Chat",
-    "Orion-14B-Chat-Plugin": "OrionStarAI/Orion-14B-Chat-Plugin",
-    "Orion-14B-LongChat": "OrionStarAI/Orion-14B-LongChat",
+    "Yi-34B-Chat": "/data/share/models/Yi-34B-Chat",
+    "BlueLM-7B-Chat": "/Models/BlueLM-7B-Chat",

-    "Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
-    "Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
-    "Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
+    "baichuan2-13b": "/media/zr/Data/Models/LLM/Baichuan2-13B-Chat",
+    "baichuan2-7b": "/media/zr/Data/Models/LLM/Baichuan2-7B-Chat",
+    "baichuan-7b": "baichuan-inc/Baichuan-7B",
+    "baichuan-13b": "baichuan-inc/Baichuan-13B",
+    'baichuan-13b-chat': 'baichuan-inc/Baichuan-13B-Chat',

-    "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
-    "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
-    "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
-    "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
+    "aquila-7b": "BAAI/Aquila-7B",
+    "aquilachat-7b": "BAAI/AquilaChat-7B",

-    # Qwen1.5 模型 VLLM可能出现问题
-    "Qwen1.5-0.5B-Chat": "Qwen/Qwen1.5-0.5B-Chat",
-    "Qwen1.5-1.8B-Chat": "Qwen/Qwen1.5-1.8B-Chat",
-    "Qwen1.5-4B-Chat": "Qwen/Qwen1.5-4B-Chat",
-    "Qwen1.5-7B-Chat": "Qwen/Qwen1.5-7B-Chat",
-    "Qwen1.5-14B-Chat": "Qwen/Qwen1.5-14B-Chat",
-    "Qwen1.5-72B-Chat": "Qwen/Qwen1.5-72B-Chat",

-    "baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
-    "baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
-    "baichuan2-7b-chat": "baichuan-inc/Baichuan2-7B-Chat",
-    "baichuan2-13b-chat": "baichuan-inc/Baichuan2-13B-Chat",

     "internlm-7b": "internlm/internlm-7b",
     "internlm-chat-7b": "internlm/internlm-chat-7b",
@ -235,42 +309,43 @@ MODEL_PATH = {
     "oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
     "dolly-v2-12b": "databricks/dolly-v2-12b",
     "stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b",

+    "Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf",
+    "Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
+    "open_llama_13b": "openlm-research/open_llama_13b",
+    "vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3",
+    "koala": "young-geng/koala",
+
+    "mpt-7b": "mosaicml/mpt-7b",
+    "mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter",
+    "mpt-30b": "mosaicml/mpt-30b",
+    "opt-66b": "facebook/opt-66b",
+    "opt-iml-max-30b": "facebook/opt-iml-max-30b",
+
+    "Qwen-1_8B-Chat":"Qwen/Qwen-1_8B-Chat"
+    "Qwen-7B": "Qwen/Qwen-7B",
+    "Qwen-14B": "Qwen/Qwen-14B",
+    "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
+    "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
+    "Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8",  # 确保已经安装了auto-gptq optimum flash-attn
+    "Qwen-14B-Chat-Int4": "/media/zr/Data/Models/LLM/Qwen-14B-Chat-Int4",  # 确保已经安装了auto-gptq optimum flash-attn
     },

-    "reranker": {
-        "bge-reranker-large": "BAAI/bge-reranker-large",
-        "bge-reranker-base": "BAAI/bge-reranker-base",
-    }
-}

-# 通常情况下不需要更改以下内容

-# nltk 模型存储路径
 NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")

 # 使用VLLM可能导致模型推理能力下降,无法完成Agent任务
 VLLM_MODEL_DICT = {
-    "chatglm2-6b": "THUDM/chatglm2-6b",
-    "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
-    "chatglm3-6b": "THUDM/chatglm3-6b",
-    "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
+    "aquila-7b": "BAAI/Aquila-7B",
+    "aquilachat-7b": "BAAI/AquilaChat-7B",

-    "Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
-    "Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
-    "Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
+    "baichuan-7b": "baichuan-inc/Baichuan-7B",
+    "baichuan-13b": "baichuan-inc/Baichuan-13B",
+    'baichuan-13b-chat': 'baichuan-inc/Baichuan-13B-Chat',

-    "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
-    "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
-    "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
-    "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
+    'chatglm2-6b': 'THUDM/chatglm2-6b',
+    'chatglm2-6b-32k': 'THUDM/chatglm2-6b-32k',
+    'chatglm3-6b': 'THUDM/chatglm3-6b',
+    'chatglm3-6b-32k': 'THUDM/chatglm3-6b-32k',

-    "baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
-    "baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
-    "baichuan2-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
-    "baichuan2-13b-chat": "baichuan-inc/Baichuan-13B-Chat",

-    "BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat",
-    "BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k",

     "internlm-7b": "internlm/internlm-7b",
     "internlm-chat-7b": "internlm/internlm-chat-7b",
@ -301,14 +376,13 @@ VLLM_MODEL_DICT = {
     "opt-66b": "facebook/opt-66b",
     "opt-iml-max-30b": "facebook/opt-iml-max-30b",

-}
-
-SUPPORT_AGENT_MODEL = [
-    "openai-api",  # GPT4 模型
-    "qwen-api",  # Qwen Max模型
-    "zhipu-api",  # 智谱AI GLM4模型
-    "Qwen",  # 所有Qwen系列本地模型
-    "chatglm3-6b",
-    "internlm2-chat-20b",
-    "Orion-14B-Chat-Plugin",
-]
+    "Qwen-7B": "Qwen/Qwen-7B",
+    "Qwen-14B": "Qwen/Qwen-14B",
+    "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
+    "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
+
+    "agentlm-7b": "THUDM/agentlm-7b",
+    "agentlm-13b": "THUDM/agentlm-13b",
+    "agentlm-70b": "THUDM/agentlm-70b",
+}
@ -1,26 +1,19 @@
-# prompt模板使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号
-# 本配置文件支持热加载,修改prompt模板后无需重启服务。
-
-# LLM对话支持的变量:
-# - input: 用户输入内容
-
-# 知识库和搜索引擎对话支持的变量:
-# - context: 从检索结果拼接的知识文本
-# - question: 用户提出的问题
-
-# Agent对话支持的变量:
-# - tools: 可用的工具列表
-# - tool_names: 可用的工具名称列表
-# - history: 用户和Agent的对话历史
-# - input: 用户输入内容
-# - agent_scratchpad: Agent的思维记录
-
 PROMPT_TEMPLATES = {
-    "llm_chat": {
+    "preprocess_model": {
+        "default":
+            '请你根据我的描述和我们对话的历史,来判断本次跟我交流是否需要使用工具,还是可以直接凭借你的知识或者历史记录跟我对话。你只要回答一个数字。1 或者 0,1代表需要使用工具,0代表不需要使用工具。\n'
+            '以下几种情况要使用工具,请返回1\n'
+            '1. 实时性的问题,例如天气,日期,地点等信息\n'
+            '2. 需要数学计算的问题\n'
+            '3. 需要查询数据,地点等精确数据\n'
+            '4. 需要行业知识的问题\n'
+            '<question>'
+            '{input}'
+            '</question>'
+    },
+    "llm_model": {
         "default":
             '{{ input }}',

         "with_history":
             'The following is a friendly conversation between a human and an AI. '
             'The AI is talkative and provides lots of specific details from its context. '
@ -29,72 +22,42 @@ PROMPT_TEMPLATES = {
             '{history}\n'
             'Human: {input}\n'
             'AI:',

-        "py":
-            '你是一个聪明的代码助手,请你给我写出简单的py代码。 \n'
-            '{{ input }}',
     },

-    "knowledge_base_chat": {
-        "default":
-            '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,'
-            '不允许在答案中添加编造成分,答案请使用中文。 </指令>\n'
-            '<已知信息>{{ context }}</已知信息>\n'
-            '<问题>{{ question }}</问题>\n',
-        "text":
-            '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>\n'
-            '<已知信息>{{ context }}</已知信息>\n'
-            '<问题>{{ question }}</问题>\n',
-        "empty":  # 搜不到知识库的时候使用
-            '请你回答我的问题:\n'
-            '{{ question }}\n\n',
-    },
-
-    "search_engine_chat": {
-        "default":
-            '<指令>这是我搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。'
-            '如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 </指令>\n'
-            '<已知信息>{{ context }}</已知信息>\n'
-            '<问题>{{ question }}</问题>\n',
-
-        "search":
-            '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>\n'
-            '<已知信息>{{ context }}</已知信息>\n'
-            '<问题>{{ question }}</问题>\n',
-    },
-
-    "agent_chat": {
-        "default":
-            'Answer the following questions as best you can. If it is in order, you can use some tools appropriately. '
-            'You have access to the following tools:\n\n'
-            '{tools}\n\n'
-            'Use the following format:\n'
-            'Question: the input question you must answer1\n'
-            'Thought: you should always think about what to do and what tools to use.\n'
-            'Action: the action to take, should be one of [{tool_names}]\n'
-            'Action Input: the input to the action\n'
+    "action_model": {
+        "GPT-4":
+            'Answer the following questions as best you can. You have access to the following tools:\n'
+            'The way you use the tools is by specifying a json blob.\n'
+            'Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here).\n'
+            'The only values that should be in the "action" field are: {tool_names}\n'
+            'The $JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB:\n'
+            '```\n\n'
+            '{{{{\n'
+            ' "action": $TOOL_NAME,\n'
+            ' "action_input": $INPUT\n'
+            '}}}}\n'
+            '```\n\n'
+            'ALWAYS use the following format:\n'
+            'Question: the input question you must answer\n'
+            'Thought: you should always think about what to do\n'
+            'Action:\n'
+            '```\n\n'
+            '$JSON_BLOB'
+            '```\n\n'
             'Observation: the result of the action\n'
-            '... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n'
+            '... (this Thought/Action/Observation can repeat N times)\n'
             'Thought: I now know the final answer\n'
             'Final Answer: the final answer to the original input question\n'
-            'Begin!\n\n'
+            'Begin! Reminder to always use the exact characters `Final Answer` when responding.\n'
-            'history: {history}\n\n'
+            'history: {history}\n'
-            'Question: {input}\n\n'
+            'Question:{input}\n'
-            'Thought: {agent_scratchpad}\n',
+            'Thought:{agent_scratchpad}\n',

         "ChatGLM3":
-            'You can answer using the tools, or answer directly using your knowledge without using the tools. '
-            'Respond to the human as helpfully and accurately as possible.\n'
+            'You can answer using the tools.Respond to the human as helpfully and accurately as possible.\n'
             'You have access to the following tools:\n'
             '{tools}\n'
-            'Use a json blob to specify a tool by providing an action key (tool name) '
+            'Use a json blob to specify a tool by providing an action key (tool name)\n'
             'and an action_input key (tool input).\n'
-            'Valid "action" values: "Final Answer" or [{tool_names}]'
+            'Valid "action" values: "Final Answer" or [{tool_names}]\n'
             'Provide only ONE action per $JSON_BLOB, as shown:\n\n'
             '```\n'
             '{{{{\n'
@ -118,10 +81,13 @@ PROMPT_TEMPLATES = {
             ' "action": "Final Answer",\n'
             ' "action_input": "Final response to human"\n'
             '}}}}\n'
-            'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. '
+            'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary.\n'
             'Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.\n'
             'history: {history}\n\n'
             'Question: {input}\n\n'
-            'Thought: {agent_scratchpad}',
+            'Thought: {agent_scratchpad}\n',
+    },
+    "postprocess_model": {
+        "default": "{{input}}",
     }
 }
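The header comment removed from prompt_config.py above noted that these templates use Jinja2 syntax, that is, double braces instead of f-string single braces, and the knowledge-base prompts do reference {{ context }} and {{ question }}. A small, self-contained rendering example with the jinja2 package; the sample values are invented.

```python
from jinja2 import Template

# The knowledge-base prompts in PROMPT_TEMPLATES and TOOL_CONFIG use {{ context }} and {{ question }}.
template = Template("<已知信息>{{ context }}</已知信息>\n<问题>{{ question }}</问题>")
print(template.render(context="检索到的文档片段", question="大数据专业的就业情况如何?"))
```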
@ -57,7 +57,7 @@ FSCHAT_MODEL_WORKERS = {
     # "awq_ckpt": None,
     # "awq_wbits": 16,
     # "awq_groupsize": -1,
-    # "model_names": LLM_MODELS,
+    # "model_names": None,
     # "conv_template": None,
     # "limit_worker_concurrency": 5,
     # "stream_interval": 2,
@ -1,29 +0,0 @@
-
-# 实现基于ES的数据插入、检索、删除、更新
-```shell
-author: 唐国梁Tommy
-e-mail: flytang186@qq.com
-
-如果遇到任何问题,可以与我联系,我这边部署后服务是没有问题的。
-```
-
-## 第1步:ES docker部署
-```shell
-docker network create elastic
-docker run -id --name elasticsearch --net elastic -p 9200:9200 -p 9300:9300 -e "discovery.type=single-node" -e "xpack.security.enabled=false" -e "xpack.security.http.ssl.enabled=false" -t docker.elastic.co/elasticsearch/elasticsearch:8.8.2
-```
-
-### 第2步:Kibana docker部署
-**注意:Kibana版本与ES保持一致**
-```shell
-docker pull docker.elastic.co/kibana/kibana:{version}
-docker run --name kibana --net elastic -p 5601:5601 docker.elastic.co/kibana/kibana:{version}
-```
-
-### 第3步:核心代码
-```shell
-1. 核心代码路径
-server/knowledge_base/kb_service/es_kb_service.py
-
-2. 需要在 configs/model_config.py 中 配置 ES参数(IP, PORT)等;
-```
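Before wiring the container from step 1 into server/knowledge_base/kb_service/es_kb_service.py, it can help to confirm the node is reachable. A quick check with the official Python client, assuming the single-node setup above with security disabled; the index name is a made-up placeholder.

```python
# Connectivity check for the docker deployment described above; assumes the
# `elasticsearch` Python client is installed and xpack security is disabled.
from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")
print(es.info()["version"]["number"])                 # expect 8.8.2 for the image above
print(es.indices.exists(index="langchain_chatchat"))  # hypothetical index name
```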
@ -1,11 +1,7 @@
-torch==2.1.2
-torchvision==0.16.2
-torchaudio==2.1.2
-xformers==0.0.23.post1
-transformers==4.37.2
-sentence_transformers==2.2.2
-langchain==0.0.354
-langchain-experimental==0.0.47
+# API requirements
+langchain>=0.0.346
+langchain-experimental>=0.0.42
 pydantic==1.10.13
 fschat==0.2.35
 openai==1.9.0
@ -1,11 +1,7 @@
-torch~=2.1.2
-torchvision~=0.16.2
-torchaudio~=2.1.2
-xformers>=0.0.23.post1
-transformers==4.37.2
-sentence_transformers==2.2.2
-langchain==0.0.354
-langchain-experimental==0.0.47
+# API requirements
+langchain>=0.0.346
+langchain-experimental>=0.0.42
 pydantic==1.10.13
 fschat==0.2.35
 openai~=1.9.0
@ -1,4 +0,0 @@
-from .model_contain import *
-from .callbacks import *
-from .custom_template import *
-from .tools import *
server/agent/agent_factory/__init__.py (new file, 1 line)
@ -0,0 +1 @@
+from .glm3_agent import initialize_glm3_agent
@ -3,6 +3,14 @@ This file is a modified version for ChatGLM3-6B the original glm3_agent.py file
 """
 from __future__ import annotations

+import yaml
+from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
+from langchain.memory import ConversationBufferWindowMemory
+from typing import Any, List, Sequence, Tuple, Optional, Union
+import os
+from langchain.agents.agent import Agent
+from langchain.chains.llm import LLMChain
+from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
 import json
 import logging
 from typing import Any, List, Sequence, Tuple, Optional, Union
@ -22,6 +30,7 @@ from langchain.agents.agent import AgentExecutor
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
+from pydantic.schema import model_schema

 HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
 logger = logging.getLogger(__name__)
@ -42,8 +51,9 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
         if "tool_call" in text:
             action_end = text.find("```")
             action = text[:action_end].strip()

             params_str_start = text.find("(") + 1
-            params_str_end = text.rfind(")")
+            params_str_end = text.find(")")
             params_str = text[params_str_start:params_str_end]

             params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param]
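The parser above recovers a ChatGLM3 tool call by slicing the raw text: the part before the first code fence carries the tool name, and the key=value pairs between the parentheses become the arguments. A standalone sketch of that strategy; the sample model output below is invented and only mimics the expected shape.

```python
# Standalone illustration of the slicing done in StructuredChatOutputParserWithRetries;
# the sample text is made up, the triple backtick is built to avoid a literal fence here.
text = 'weather_check\n(location="Beijing", unit="celsius")\n' + "`" * 3

action_end = text.find("`" * 3)                       # cut off at the closing code fence
action = text[:action_end].strip().splitlines()[0]     # first line holds the tool name
params_str = text[text.find("(") + 1:text.rfind(")")]
params = {k.strip(): v.strip().strip('"')
          for k, v in (p.split("=", 1) for p in params_str.split(",") if "=" in p)}
print(action, params)   # weather_check {'location': 'Beijing', 'unit': 'celsius'}
```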
@ -1,67 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
from langchain.agents import Tool, AgentOutputParser
|
|
||||||
from langchain.prompts import StringPromptTemplate
|
|
||||||
from typing import List
|
|
||||||
from langchain.schema import AgentAction, AgentFinish
|
|
||||||
|
|
||||||
from configs import SUPPORT_AGENT_MODEL
|
|
||||||
from server.agent import model_container
|
|
||||||
class CustomPromptTemplate(StringPromptTemplate):
|
|
||||||
template: str
|
|
||||||
tools: List[Tool]
|
|
||||||
|
|
||||||
def format(self, **kwargs) -> str:
|
|
||||||
intermediate_steps = kwargs.pop("intermediate_steps")
|
|
||||||
thoughts = ""
|
|
||||||
for action, observation in intermediate_steps:
|
|
||||||
thoughts += action.log
|
|
||||||
thoughts += f"\nObservation: {observation}\nThought: "
|
|
||||||
kwargs["agent_scratchpad"] = thoughts
|
|
||||||
kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
|
|
||||||
kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
|
|
||||||
return self.template.format(**kwargs)
|
|
||||||
|
|
||||||
class CustomOutputParser(AgentOutputParser):
|
|
||||||
begin: bool = False
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.begin = True
|
|
||||||
|
|
||||||
def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
|
|
||||||
if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin:
|
|
||||||
self.begin = False
|
|
||||||
stop_words = ["Observation:"]
|
|
||||||
min_index = len(llm_output)
|
|
||||||
for stop_word in stop_words:
|
|
||||||
index = llm_output.find(stop_word)
|
|
||||||
if index != -1 and index < min_index:
|
|
||||||
min_index = index
|
|
||||||
llm_output = llm_output[:min_index]
|
|
||||||
|
|
||||||
if "Final Answer:" in llm_output:
|
|
||||||
self.begin = True
|
|
||||||
return AgentFinish(
|
|
||||||
return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()},
|
|
||||||
log=llm_output,
|
|
||||||
)
|
|
||||||
parts = llm_output.split("Action:")
|
|
||||||
if len(parts) < 2:
|
|
||||||
return AgentFinish(
|
|
||||||
return_values={"output": f"调用agent工具失败,该回答为大模型自身能力的回答:\n\n `{llm_output}`"},
|
|
||||||
log=llm_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
action = parts[1].split("Action Input:")[0].strip()
|
|
||||||
action_input = parts[1].split("Action Input:")[1].strip()
|
|
||||||
try:
|
|
||||||
ans = AgentAction(
|
|
||||||
tool=action,
|
|
||||||
tool_input=action_input.strip(" ").strip('"'),
|
|
||||||
log=llm_output
|
|
||||||
)
|
|
||||||
return ans
|
|
||||||
except:
|
|
||||||
return AgentFinish(
|
|
||||||
return_values={"output": f"调用agent失败: `{llm_output}`"},
|
|
||||||
log=llm_output,
|
|
||||||
)
|
|
||||||
@ -1,6 +0,0 @@
-class ModelContainer:
-    def __init__(self):
-        self.MODEL = None
-        self.DATABASE = None
-
-model_container = ModelContainer()
@ -1,11 +0,0 @@
-## 导入所有的工具类
-from .search_knowledgebase_simple import search_knowledgebase_simple
-from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput
-from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput
-from .calculate import calculate, CalculatorInput
-from .weather_check import weathercheck, WeatherInput
-from .shell import shell, ShellInput
-from .search_internet import search_internet, SearchInternetInput
-from .wolfram import wolfram, WolframInput
-from .search_youtube import search_youtube, YoutubeInput
-from .arxiv import arxiv, ArxivInput
@ -1,76 +0,0 @@
|
|||||||
from langchain.prompts import PromptTemplate
|
|
||||||
from langchain.chains import LLMMathChain
|
|
||||||
from server.agent import model_container
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
_PROMPT_TEMPLATE = """
|
|
||||||
将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。
|
|
||||||
问题: ${{包含数学问题的问题。}}
|
|
||||||
```text
|
|
||||||
${{解决问题的单行数学表达式}}
|
|
||||||
```
|
|
||||||
...numexpr.evaluate(query)...
|
|
||||||
```output
|
|
||||||
${{运行代码的输出}}
|
|
||||||
```
|
|
||||||
答案: ${{答案}}
|
|
||||||
|
|
||||||
这是两个例子:
|
|
||||||
|
|
||||||
问题: 37593 * 67是多少?
|
|
||||||
```text
|
|
||||||
37593 * 67
|
|
||||||
```
|
|
||||||
...numexpr.evaluate("37593 * 67")...
|
|
||||||
```output
|
|
||||||
2518731
|
|
||||||
|
|
||||||
答案: 2518731
|
|
||||||
|
|
||||||
问题: 37593的五次方根是多少?
|
|
||||||
```text
|
|
||||||
37593**(1/5)
|
|
||||||
```
|
|
||||||
...numexpr.evaluate("37593**(1/5)")...
|
|
||||||
```output
|
|
||||||
8.222831614237718
|
|
||||||
|
|
||||||
答案: 8.222831614237718
|
|
||||||
|
|
||||||
|
|
||||||
问题: 2的平方是多少?
|
|
||||||
```text
|
|
||||||
2 ** 2
|
|
||||||
```
|
|
||||||
...numexpr.evaluate("2 ** 2")...
|
|
||||||
```output
|
|
||||||
4
|
|
||||||
|
|
||||||
答案: 4
|
|
||||||
|
|
||||||
|
|
||||||
现在,这是我的问题:
|
|
||||||
问题: {question}
|
|
||||||
"""
|
|
||||||
|
|
||||||
PROMPT = PromptTemplate(
|
|
||||||
input_variables=["question"],
|
|
||||||
template=_PROMPT_TEMPLATE,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CalculatorInput(BaseModel):
|
|
||||||
query: str = Field()
|
|
||||||
|
|
||||||
def calculate(query: str):
|
|
||||||
model = model_container.MODEL
|
|
||||||
llm_math = LLMMathChain.from_llm(model, verbose=True, prompt=PROMPT)
|
|
||||||
ans = llm_math.run(query)
|
|
||||||
return ans
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
result = calculate("2的三次方")
|
|
||||||
print("答案:",result)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -1,37 +0,0 @@
|
|||||||
import json
|
|
||||||
from server.chat.search_engine_chat import search_engine_chat
|
|
||||||
from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS
|
|
||||||
import asyncio
|
|
||||||
from server.agent import model_container
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
async def search_engine_iter(query: str):
|
|
||||||
response = await search_engine_chat(query=query,
|
|
||||||
search_engine_name="bing", # 这里切换搜索引擎
|
|
||||||
model_name=model_container.MODEL.model_name,
|
|
||||||
temperature=0.01, # Agent 搜索互联网的时候,温度设置为0.01
|
|
||||||
history=[],
|
|
||||||
top_k = VECTOR_SEARCH_TOP_K,
|
|
||||||
max_tokens= MAX_TOKENS,
|
|
||||||
prompt_name = "default",
|
|
||||||
stream=False)
|
|
||||||
|
|
||||||
contents = ""
|
|
||||||
|
|
||||||
async for data in response.body_iterator: # 这里的data是一个json字符串
|
|
||||||
data = json.loads(data)
|
|
||||||
contents = data["answer"]
|
|
||||||
docs = data["docs"]
|
|
||||||
|
|
||||||
return contents
|
|
||||||
|
|
||||||
def search_internet(query: str):
|
|
||||||
return asyncio.run(search_engine_iter(query))
|
|
||||||
|
|
||||||
class SearchInternetInput(BaseModel):
|
|
||||||
location: str = Field(description="Query for Internet search")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
result = search_internet("今天星期几")
|
|
||||||
print("答案:",result)
|
|
||||||
@ -1,287 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
import warnings
|
|
||||||
from typing import Dict
|
|
||||||
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun
|
|
||||||
from langchain.chains.llm import LLMChain
|
|
||||||
from langchain.pydantic_v1 import Extra, root_validator
|
|
||||||
from langchain.schema import BasePromptTemplate
|
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
|
||||||
from typing import List, Any, Optional
|
|
||||||
from langchain.prompts import PromptTemplate
|
|
||||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
|
||||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
|
|
||||||
import asyncio
|
|
||||||
from server.agent import model_container
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
async def search_knowledge_base_iter(database: str, query: str) -> str:
|
|
||||||
response = await knowledge_base_chat(query=query,
|
|
||||||
knowledge_base_name=database,
|
|
||||||
model_name=model_container.MODEL.model_name,
|
|
||||||
temperature=0.01,
|
|
||||||
history=[],
|
|
||||||
top_k=VECTOR_SEARCH_TOP_K,
|
|
||||||
max_tokens=MAX_TOKENS,
|
|
||||||
prompt_name="default",
|
|
||||||
score_threshold=SCORE_THRESHOLD,
|
|
||||||
stream=False)
|
|
||||||
|
|
||||||
contents = ""
|
|
||||||
async for data in response.body_iterator: # 这里的data是一个json字符串
|
|
||||||
data = json.loads(data)
|
|
||||||
contents += data["answer"]
|
|
||||||
docs = data["docs"]
|
|
||||||
return contents
|
|
||||||
|
|
||||||
|
|
||||||
async def search_knowledge_multiple(queries) -> List[str]:
|
|
||||||
# queries 应该是一个包含多个 (database, query) 元组的列表
|
|
||||||
tasks = [search_knowledge_base_iter(database, query) for database, query in queries]
|
|
||||||
results = await asyncio.gather(*tasks)
|
|
||||||
# 结合每个查询结果,并在每个查询结果前添加一个自定义的消息
|
|
||||||
combined_results = []
|
|
||||||
for (database, _), result in zip(queries, results):
|
|
||||||
message = f"\n查询到 {database} 知识库的相关信息:\n{result}"
|
|
||||||
combined_results.append(message)
|
|
||||||
|
|
||||||
return combined_results
|
|
||||||
|
|
||||||
|
|
||||||
def search_knowledge(queries) -> str:
|
|
||||||
responses = asyncio.run(search_knowledge_multiple(queries))
|
|
||||||
# 输出每个整合的查询结果
|
|
||||||
contents = ""
|
|
||||||
for response in responses:
|
|
||||||
contents += response + "\n\n"
|
|
||||||
return contents
|
|
||||||
|
|
||||||
|
|
||||||
_PROMPT_TEMPLATE = """
|
|
||||||
用户会提出一个需要你查询知识库的问题,你应该对问题进行理解和拆解,并在知识库中查询相关的内容。
|
|
||||||
|
|
||||||
对于每个知识库,你输出的内容应该是一个一行的字符串,这行字符串包含知识库名称和查询内容,中间用逗号隔开,不要有多余的文字和符号。你可以同时查询多个知识库,下面这个例子就是同时查询两个知识库的内容。
|
|
||||||
|
|
||||||
例子:
|
|
||||||
|
|
||||||
robotic,机器人男女比例是多少
|
|
||||||
bigdata,大数据的就业情况如何
|
|
||||||
|
|
||||||
|
|
||||||
这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能,你应该参考他们的功能来帮助你思考
|
|
||||||
|
|
||||||
|
|
||||||
{database_names}
|
|
||||||
|
|
||||||
你的回答格式应该按照下面的内容,请注意```text 等标记都必须输出,这是我用来提取答案的标记。
|
|
||||||
不要输出中文的逗号,不要输出引号。
|
|
||||||
|
|
||||||
Question: ${{用户的问题}}
|
|
||||||
|
|
||||||
```text
|
|
||||||
${{知识库名称,查询问题,不要带有任何除了,之外的符号,比如不要输出中文的逗号,不要输出引号}}
|
|
||||||
|
|
||||||
```output
|
|
||||||
数据库查询的结果
|
|
||||||
|
|
||||||
现在,我们开始作答
|
|
||||||
问题: {question}
|
|
||||||
"""
|
|
||||||
|
|
||||||
PROMPT = PromptTemplate(
|
|
||||||
input_variables=["question", "database_names"],
|
|
||||||
template=_PROMPT_TEMPLATE,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LLMKnowledgeChain(LLMChain):
|
|
||||||
llm_chain: LLMChain
|
|
||||||
llm: Optional[BaseLanguageModel] = None
|
|
||||||
"""[Deprecated] LLM wrapper to use."""
|
|
||||||
prompt: BasePromptTemplate = PROMPT
|
|
||||||
"""[Deprecated] Prompt to use to translate to python if necessary."""
|
|
||||||
database_names: Dict[str, str] = None
|
|
||||||
input_key: str = "question" #: :meta private:
|
|
||||||
output_key: str = "answer" #: :meta private:
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
@root_validator(pre=True)
|
|
||||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
|
||||||
if "llm" in values:
|
|
||||||
warnings.warn(
|
|
||||||
"Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
|
|
||||||
"Please instantiate with llm_chain argument or using the from_llm "
|
|
||||||
"class method."
|
|
||||||
)
|
|
||||||
if "llm_chain" not in values and values["llm"] is not None:
|
|
||||||
prompt = values.get("prompt", PROMPT)
|
|
||||||
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
|
|
||||||
return values
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input_keys(self) -> List[str]:
|
|
||||||
"""Expect input key.
|
|
||||||
|
|
||||||
:meta private:
|
|
||||||
"""
|
|
||||||
return [self.input_key]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def output_keys(self) -> List[str]:
|
|
||||||
"""Expect output key.
|
|
||||||
|
|
||||||
:meta private:
|
|
||||||
"""
|
|
||||||
return [self.output_key]
|
|
||||||
|
|
||||||
def _evaluate_expression(self, queries) -> str:
|
|
||||||
try:
|
|
||||||
output = search_knowledge(queries)
|
|
||||||
except Exception as e:
|
|
||||||
output = "输入的信息有误或不存在知识库,错误信息如下:\n"
|
|
||||||
return output + str(e)
|
|
||||||
return output
|
|
||||||
|
|
||||||
def _process_llm_result(
|
|
||||||
self,
|
|
||||||
llm_output: str,
|
|
||||||
run_manager: CallbackManagerForChainRun
|
|
||||||
) -> Dict[str, str]:
|
|
||||||
|
|
||||||
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
|
||||||
|
|
||||||
llm_output = llm_output.strip()
|
|
||||||
# text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
|
|
||||||
text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
|
|
||||||
if text_match:
|
|
||||||
expression = text_match.group(1).strip()
|
|
||||||
cleaned_input_str = (expression.replace("\"", "").replace("“", "").
|
|
||||||
replace("”", "").replace("```", "").strip())
|
|
||||||
lines = cleaned_input_str.split("\n")
|
|
||||||
# 使用逗号分割每一行,然后形成一个(数据库,查询)元组的列表
|
|
||||||
|
|
||||||
try:
|
|
||||||
queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
|
|
||||||
except:
|
|
||||||
queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
|
|
||||||
run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", verbose=self.verbose)
|
|
||||||
output = self._evaluate_expression(queries)
|
|
||||||
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
|
||||||
run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
|
||||||
answer = "Answer: " + output
|
|
||||||
elif llm_output.startswith("Answer:"):
|
|
||||||
answer = llm_output
|
|
||||||
elif "Answer:" in llm_output:
|
|
||||||
answer = llm_output.split("Answer:")[-1]
|
|
||||||
else:
|
|
||||||
return {self.output_key: f"输入的格式不对:\n {llm_output}"}
|
|
||||||
return {self.output_key: answer}
|
|
||||||
|
|
||||||
async def _aprocess_llm_result(
|
|
||||||
self,
|
|
||||||
llm_output: str,
|
|
||||||
run_manager: AsyncCallbackManagerForChainRun,
|
|
||||||
) -> Dict[str, str]:
|
|
||||||
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
|
||||||
llm_output = llm_output.strip()
|
|
||||||
text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
|
|
||||||
if text_match:
|
|
||||||
|
|
||||||
expression = text_match.group(1).strip()
|
|
||||||
cleaned_input_str = (
|
|
||||||
expression.replace("\"", "").replace("“", "").replace("”", "").replace("```", "").strip())
|
|
||||||
lines = cleaned_input_str.split("\n")
|
|
||||||
try:
|
|
||||||
queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
|
|
||||||
except:
|
|
||||||
queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
|
|
||||||
await run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue",
|
|
||||||
verbose=self.verbose)
|
|
||||||
|
|
||||||
output = self._evaluate_expression(queries)
|
|
||||||
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
|
||||||
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
|
||||||
answer = "Answer: " + output
|
|
||||||
elif llm_output.startswith("Answer:"):
|
|
||||||
answer = llm_output
|
|
||||||
elif "Answer:" in llm_output:
|
|
||||||
answer = "Answer: " + llm_output.split("Answer:")[-1]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"unknown format from LLM: {llm_output}")
|
|
||||||
return {self.output_key: answer}
|
|
||||||
|
|
||||||
def _call(
|
|
||||||
self,
|
|
||||||
inputs: Dict[str, str],
|
|
||||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
|
||||||
) -> Dict[str, str]:
|
|
||||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
|
||||||
_run_manager.on_text(inputs[self.input_key])
|
|
||||||
self.database_names = model_container.DATABASE
|
|
||||||
data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
|
|
||||||
llm_output = self.llm_chain.predict(
|
|
||||||
database_names=data_formatted_str,
|
|
||||||
question=inputs[self.input_key],
|
|
||||||
stop=["```output"],
|
|
||||||
callbacks=_run_manager.get_child(),
|
|
||||||
)
|
|
||||||
return self._process_llm_result(llm_output, _run_manager)
|
|
||||||
|
|
||||||
async def _acall(
|
|
||||||
self,
|
|
||||||
inputs: Dict[str, str],
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
|
||||||
) -> Dict[str, str]:
|
|
||||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
|
||||||
await _run_manager.on_text(inputs[self.input_key])
|
|
||||||
self.database_names = model_container.DATABASE
|
|
||||||
data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
|
|
||||||
llm_output = await self.llm_chain.apredict(
|
|
||||||
database_names=data_formatted_str,
|
|
||||||
question=inputs[self.input_key],
|
|
||||||
stop=["```output"],
|
|
||||||
callbacks=_run_manager.get_child(),
|
|
||||||
)
|
|
||||||
return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _chain_type(self) -> str:
|
|
||||||
return "llm_knowledge_chain"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_llm(
|
|
||||||
cls,
|
|
||||||
llm: BaseLanguageModel,
|
|
||||||
prompt: BasePromptTemplate = PROMPT,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> LLMKnowledgeChain:
|
|
||||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
|
||||||
return cls(llm_chain=llm_chain, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def search_knowledgebase_complex(query: str):
|
|
||||||
model = model_container.MODEL
|
|
||||||
llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT)
|
|
||||||
ans = llm_knowledge.run(query)
|
|
||||||
return ans
|
|
||||||
|
|
||||||
class KnowledgeSearchInput(BaseModel):
|
|
||||||
location: str = Field(description="The query to be searched")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
result = search_knowledgebase_complex("机器人和大数据在代码教学上有什么区别")
|
|
||||||
print(result)
|
|
||||||
|
|
||||||
# 这是一个正常的切割
|
|
||||||
# queries = [
|
|
||||||
# ("bigdata", "大数据专业的男女比例"),
|
|
||||||
# ("robotic", "机器人专业的优势")
|
|
||||||
# ]
|
|
||||||
# result = search_knowledge(queries)
|
|
||||||
# print(result)
|
|
||||||
@ -1,234 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
import re
|
|
||||||
import warnings
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
|
||||||
AsyncCallbackManagerForChainRun,
|
|
||||||
CallbackManagerForChainRun,
|
|
||||||
)
|
|
||||||
from langchain.chains.llm import LLMChain
|
|
||||||
from langchain.pydantic_v1 import Extra, root_validator
|
|
||||||
from langchain.schema import BasePromptTemplate
|
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
|
||||||
from typing import List, Any, Optional
|
|
||||||
from langchain.prompts import PromptTemplate
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
|
||||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
|
||||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from server.agent import model_container
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
async def search_knowledge_base_iter(database: str, query: str):
|
|
||||||
response = await knowledge_base_chat(query=query,
|
|
||||||
knowledge_base_name=database,
|
|
||||||
model_name=model_container.MODEL.model_name,
|
|
||||||
temperature=0.01,
|
|
||||||
history=[],
|
|
||||||
top_k=VECTOR_SEARCH_TOP_K,
|
|
||||||
max_tokens=MAX_TOKENS,
|
|
||||||
prompt_name="knowledge_base_chat",
|
|
||||||
score_threshold=SCORE_THRESHOLD,
|
|
||||||
stream=False)
|
|
||||||
|
|
||||||
contents = ""
|
|
||||||
async for data in response.body_iterator: # 这里的data是一个json字符串
|
|
||||||
data = json.loads(data)
|
|
||||||
contents += data["answer"]
|
|
||||||
docs = data["docs"]
|
|
||||||
return contents
|
|
||||||
|
|
||||||
|
|
||||||
_PROMPT_TEMPLATE = """
|
|
||||||
用户会提出一个需要你查询知识库的问题,你应该按照我提供的思想进行思考
|
|
||||||
Question: ${{用户的问题}}
|
|
||||||
这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能:
|
|
||||||
|
|
||||||
{database_names}
|
|
||||||
|
|
||||||
你的回答格式应该按照下面的内容,请注意,格式内的```text 等标记都必须输出,这是我用来提取答案的标记。
|
|
||||||
```text
|
|
||||||
${{知识库的名称}}
|
|
||||||
```
|
|
||||||
```output
|
|
||||||
数据库查询的结果
|
|
||||||
```
|
|
||||||
答案: ${{答案}}
|
|
||||||
|
|
||||||
现在,这是我的问题:
|
|
||||||
问题: {question}
|
|
||||||
|
|
||||||
"""
|
|
||||||
PROMPT = PromptTemplate(
|
|
||||||
input_variables=["question", "database_names"],
|
|
||||||
template=_PROMPT_TEMPLATE,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LLMKnowledgeChain(LLMChain):
|
|
||||||
llm_chain: LLMChain
|
|
||||||
llm: Optional[BaseLanguageModel] = None
|
|
||||||
"""[Deprecated] LLM wrapper to use."""
|
|
||||||
prompt: BasePromptTemplate = PROMPT
|
|
||||||
"""[Deprecated] Prompt to use to translate to python if necessary."""
|
|
||||||
database_names: Dict[str, str] = model_container.DATABASE
|
|
||||||
input_key: str = "question" #: :meta private:
|
|
||||||
output_key: str = "answer" #: :meta private:
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
@root_validator(pre=True)
|
|
||||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
|
||||||
if "llm" in values:
|
|
||||||
warnings.warn(
|
|
||||||
"Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
|
|
||||||
"Please instantiate with llm_chain argument or using the from_llm "
|
|
||||||
"class method."
|
|
||||||
)
|
|
||||||
if "llm_chain" not in values and values["llm"] is not None:
|
|
||||||
prompt = values.get("prompt", PROMPT)
|
|
||||||
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
|
|
||||||
return values
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input_keys(self) -> List[str]:
|
|
||||||
"""Expect input key.
|
|
||||||
|
|
||||||
:meta private:
|
|
||||||
"""
|
|
||||||
return [self.input_key]
|
|
||||||
|
|
||||||
@property
|
|
||||||
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    def _evaluate_expression(self, dataset, query) -> str:
        try:
            output = asyncio.run(search_knowledge_base_iter(dataset, query))
        except Exception as e:
            output = "输入的信息有误或不存在知识库"
        return output

    def _process_llm_result(
            self,
            llm_output: str,
            llm_input: str,
            run_manager: CallbackManagerForChainRun
    ) -> Dict[str, str]:
        run_manager.on_text(llm_output, color="green", verbose=self.verbose)

        llm_output = llm_output.strip()
        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
        if text_match:
            database = text_match.group(1).strip()
            output = self._evaluate_expression(database, llm_input)
            run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = "Answer: " + llm_output.split("Answer:")[-1]
        else:
            return {self.output_key: f"输入的格式不对: {llm_output}"}
        return {self.output_key: answer}

    async def _aprocess_llm_result(
            self,
            llm_output: str,
            run_manager: AsyncCallbackManagerForChainRun,
    ) -> Dict[str, str]:
        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
        if text_match:
            expression = text_match.group(1)
            output = self._evaluate_expression(expression)
            await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            await run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = "Answer: " + llm_output.split("Answer:")[-1]
        else:
            raise ValueError(f"unknown format from LLM: {llm_output}")
        return {self.output_key: answer}

    def _call(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _run_manager.on_text(inputs[self.input_key])
        data_formatted_str = ',\n'.join([f'    "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = self.llm_chain.predict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return self._process_llm_result(llm_output, inputs[self.input_key], _run_manager)

    async def _acall(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        await _run_manager.on_text(inputs[self.input_key])
        data_formatted_str = ',\n'.join([f'    "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = await self.llm_chain.apredict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)

    @property
    def _chain_type(self) -> str:
        return "llm_knowledge_chain"

    @classmethod
    def from_llm(
            cls,
            llm: BaseLanguageModel,
            prompt: BasePromptTemplate = PROMPT,
            **kwargs: Any,
    ) -> LLMKnowledgeChain:
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)


def search_knowledgebase_once(query: str):
    model = model_container.MODEL
    llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT)
    ans = llm_knowledge.run(query)
    return ans


class KnowledgeSearchInput(BaseModel):
    location: str = Field(description="The query to be searched")


if __name__ == "__main__":
    result = search_knowledgebase_once("大数据的男女比例")
    print(result)
@ -1,32 +0,0 @@
from server.chat.knowledge_base_chat import knowledge_base_chat
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
import json
import asyncio
from server.agent import model_container


async def search_knowledge_base_iter(database: str, query: str) -> str:
    response = await knowledge_base_chat(query=query,
                                         knowledge_base_name=database,
                                         model_name=model_container.MODEL.model_name,
                                         temperature=0.01,
                                         history=[],
                                         top_k=VECTOR_SEARCH_TOP_K,
                                         max_tokens=MAX_TOKENS,
                                         prompt_name="knowledge_base_chat",
                                         score_threshold=SCORE_THRESHOLD,
                                         stream=False)

    contents = ""
    async for data in response.body_iterator:  # 这里的data是一个json字符串
        data = json.loads(data)
        contents = data["answer"]
        docs = data["docs"]
    return contents


def search_knowledgebase_simple(query: str):
    return asyncio.run(search_knowledge_base_iter(query))


if __name__ == "__main__":
    result = search_knowledgebase_simple("大数据男女比例")
    print("答案:", result)
server/agent/tools_factory/__init__.py (new file, 8 lines)
@ -0,0 +1,8 @@
from .search_local_knowledgebase import search_local_knowledgebase, SearchKnowledgeInput
from .calculate import calculate, CalculatorInput
from .weather_check import weather_check, WeatherInput
from .shell import shell, ShellInput
from .search_internet import search_internet, SearchInternetInput
from .wolfram import wolfram, WolframInput
from .search_youtube import search_youtube, YoutubeInput
from .arxiv import arxiv, ArxivInput
server/agent/tools_factory/calculate.py (new file, 23 lines)
@ -0,0 +1,23 @@
from pydantic import BaseModel, Field


def calculate(a: float, b: float, operator: str) -> float:
    if operator == "+":
        return a + b
    elif operator == "-":
        return a - b
    elif operator == "*":
        return a * b
    elif operator == "/":
        if b != 0:
            return a / b
        else:
            return float('inf')  # 防止除以零
    elif operator == "^":
        return a ** b
    else:
        raise ValueError("Unsupported operator")


class CalculatorInput(BaseModel):
    a: float = Field(description="first number")
    b: float = Field(description="second number")
    operator: str = Field(description="operator to use (e.g., +, -, *, /, ^)")
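For context, a minimal usage sketch (not part of this commit) of how calculate and CalculatorInput pair up through the StructuredTool API used later in tools_registry.py; the invocation below is illustrative only:

# Usage sketch, assuming langchain_core's StructuredTool as used elsewhere in this PR.
from langchain_core.tools import StructuredTool
from server.agent.tools_factory.calculate import calculate, CalculatorInput

calc_tool = StructuredTool.from_function(
    func=calculate,
    name="calculate",
    description="Useful for when you need to answer questions about simple calculations",
    args_schema=CalculatorInput,
)

# StructuredTool validates the arguments against CalculatorInput before calling calculate.
print(calc_tool.invoke({"a": 2, "b": 8, "operator": "^"}))  # expected: 256.0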
server/agent/tools_factory/search_internet.py (new file, 99 lines)
@ -0,0 +1,99 @@
from pydantic import BaseModel, Field
from langchain.utilities.bing_search import BingSearchAPIWrapper
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from configs import TOOL_CONFIG
from langchain.text_splitter import RecursiveCharacterTextSplitter
from typing import List, Dict
from langchain.docstore.document import Document
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from markdownify import markdownify


def bing_search(text, config):
    search = BingSearchAPIWrapper(bing_subscription_key=config["bing_key"],
                                  bing_search_url=config["bing_search_url"])
    return search.results(text, config["result_len"])


def duckduckgo_search(text, config):
    search = DuckDuckGoSearchAPIWrapper()
    return search.results(text, config["result_len"])


def metaphor_search(
        text: str,
        config: dict,
) -> List[Dict]:
    from metaphor_python import Metaphor

    client = Metaphor(config["metaphor_api_key"])
    search = client.search(text, num_results=config["result_len"], use_autoprompt=True)
    contents = search.get_contents().contents
    for x in contents:
        x.extract = markdownify(x.extract)
    if config["split_result"]:
        docs = [Document(page_content=x.extract,
                         metadata={"link": x.url, "title": x.title})
                for x in contents]
        text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
                                                       chunk_size=config["chunk_size"],
                                                       chunk_overlap=config["chunk_overlap"])
        splitted_docs = text_splitter.split_documents(docs)
        if len(splitted_docs) > config["result_len"]:
            normal = NormalizedLevenshtein()
            for x in splitted_docs:
                x.metadata["score"] = normal.similarity(text, x.page_content)
            splitted_docs.sort(key=lambda x: x.metadata["score"], reverse=True)
            splitted_docs = splitted_docs[:config["result_len"]]

        docs = [{"snippet": x.page_content,
                 "link": x.metadata["link"],
                 "title": x.metadata["title"]}
                for x in splitted_docs]
    else:
        docs = [{"snippet": x.extract,
                 "link": x.url,
                 "title": x.title}
                for x in contents]

    return docs


SEARCH_ENGINES = {"bing": bing_search,
                  "duckduckgo": duckduckgo_search,
                  "metaphor": metaphor_search,
                  }


def search_result2docs(search_results):
    docs = []
    for result in search_results:
        doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "",
                       metadata={"source": result["link"] if "link" in result.keys() else "",
                                 "filename": result["title"] if "title" in result.keys() else ""})
        docs.append(doc)
    return docs


def search_engine(query: str,
                  config: dict):
    search_engine_use = SEARCH_ENGINES[config["search_engine_name"]]
    results = search_engine_use(text=query,
                                config=config["search_engine_config"][
                                    config["search_engine_name"]])
    docs = search_result2docs(results)
    context = ""
    docs = [
        f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
        for inum, doc in enumerate(docs)
    ]

    for doc in docs:
        context += doc + "\n"
    return context


def search_internet(query: str):
    tool_config = TOOL_CONFIG["search_internet"]
    return search_engine(query=query, config=tool_config)


class SearchInternetInput(BaseModel):
    query: str = Field(description="Query for Internet search")
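The module above only reads TOOL_CONFIG["search_internet"]; a hypothetical sketch of the shape it expects follows, inferred from the keys accessed in search_engine() and the engine wrappers (values are placeholders, the real entry lives in the project's config templates):

# Hypothetical TOOL_CONFIG entry inferred from the keys this module reads.
TOOL_CONFIG = {
    "search_internet": {
        "search_engine_name": "duckduckgo",          # one of SEARCH_ENGINES
        "search_engine_config": {
            "bing": {"bing_key": "", "bing_search_url": "", "result_len": 3},
            "duckduckgo": {"result_len": 3},
            "metaphor": {"metaphor_api_key": "", "result_len": 3,
                         "split_result": False, "chunk_size": 500, "chunk_overlap": 0},
        },
    },
}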
server/agent/tools_factory/search_local_knowledgebase.py (new file, 40 lines)
@ -0,0 +1,40 @@
from urllib.parse import urlencode

from pydantic import BaseModel, Field

from server.knowledge_base.kb_doc_api import search_docs
from configs import TOOL_CONFIG


def search_knowledgebase(query: str, database: str, config: dict):
    docs = search_docs(
        query=query,
        knowledge_base_name=database,
        top_k=config["top_k"],
        score_threshold=config["score_threshold"])
    context = ""
    source_documents = []
    for inum, doc in enumerate(docs):
        filename = doc.metadata.get("source")
        parameters = urlencode({"knowledge_base_name": database, "file_name": filename})
        url = f"download_doc?" + parameters
        text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
        source_documents.append(text)

    if len(source_documents) == 0:
        context = "没有找到相关文档,请更换关键词重试"
    else:
        for doc in source_documents:
            context += doc + "\n"

    return context


class SearchKnowledgeInput(BaseModel):
    database: str = Field(description="Database for Knowledge Search")
    query: str = Field(description="Query for Knowledge Search")


def search_local_knowledgebase(database: str, query: str):
    tool_config = TOOL_CONFIG["search_local_knowledgebase"]
    return search_knowledgebase(query=query, database=database, config=tool_config)
@ -6,4 +6,4 @@ def search_youtube(query: str):
     return tool.run(tool_input=query)


 class YoutubeInput(BaseModel):
-    location: str = Field(description="Query for Videos search")
+    query: str = Field(description="Query for Videos search")
@ -6,4 +6,4 @@ def shell(query: str):
     return tool.run(tool_input=query)


 class ShellInput(BaseModel):
-    query: str = Field(description="一个能在Linux命令行运行的Shell命令")
+    query: str = Field(description="The command to execute")
server/agent/tools_factory/tools_registry.py (new file, 59 lines)
@ -0,0 +1,59 @@
from langchain_core.tools import StructuredTool
from server.agent.tools_factory import *
from configs import KB_INFO

template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool."
KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
template_knowledge = template.format(KB_info=KB_info_str)

all_tools = [
    StructuredTool.from_function(
        func=calculate,
        name="calculate",
        description="Useful for when you need to answer questions about simple calculations",
        args_schema=CalculatorInput,
    ),
    StructuredTool.from_function(
        func=arxiv,
        name="arxiv",
        description="A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.",
        args_schema=ArxivInput,
    ),
    StructuredTool.from_function(
        func=shell,
        name="shell",
        description="Use Shell to execute Linux commands",
        args_schema=ShellInput,
    ),
    StructuredTool.from_function(
        func=wolfram,
        name="wolfram",
        description="Useful for when you need to calculate difficult formulas",
        args_schema=WolframInput,
    ),
    StructuredTool.from_function(
        func=search_youtube,
        name="search_youtube",
        description="use this tools_factory to search youtube videos",
        args_schema=YoutubeInput,
    ),
    StructuredTool.from_function(
        func=weather_check,
        name="weather_check",
        description="Use this tool to check the weather at a specific location",
        args_schema=WeatherInput,
    ),
    StructuredTool.from_function(
        func=search_internet,
        name="search_internet",
        description="Use this tool to use bing search engine to search the internet and get information",
        args_schema=SearchInternetInput,
    ),
    StructuredTool.from_function(
        func=search_local_knowledgebase,
        name="search_local_knowledgebase",
        description=template_knowledge,
        args_schema=SearchKnowledgeInput,
    ),
]
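A small sketch of how a caller narrows this registry per request, mirroring the filtering the new chat endpoint does further down in this diff; the tool_config keys shown here are hypothetical:

# Sketch: enable only a subset of all_tools for one request.
from server.agent.tools_factory.tools_registry import all_tools

tool_config = {"calculate": {}, "search_internet": {}}          # hypothetical request payload
tools = [tool for tool in all_tools if tool.name in tool_config]
print([t.name for t in tools])                                  # ['calculate', 'search_internet']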
@ -1,10 +1,8 @@
 """
-更简单的单参数输入工具实现,用于查询现在天气的情况
+简单的单参数输入工具实现,用于查询现在天气的情况
 """
 from pydantic import BaseModel, Field
 import requests
-from configs.kb_config import SENIVERSE_API_KEY


 def weather(location: str, api_key: str):
     url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
@ -21,9 +19,7 @@ def weather(location: str, api_key: str):
             f"Failed to retrieve weather: {response.status_code}")


-def weathercheck(location: str):
-    return weather(location, SENIVERSE_API_KEY)
+def weather_check(location: str):
+    return weather(location, "S8vrB4U_-c5mvAMiK")


 class WeatherInput(BaseModel):
-    location: str = Field(description="City name,include city and county")
+    location: str = Field(description="City name,include city and county,like '厦门'")
@ -8,4 +8,4 @@ def wolfram(query: str):
     return ans


 class WolframInput(BaseModel):
-    location: str = Field(description="需要运算的具体问题")
+    formula: str = Field(description="The formula to be calculated")
@ -1,55 +0,0 @@
from langchain.tools import Tool
from server.agent.tools import *

tools = [
    Tool.from_function(
        func=calculate,
        name="calculate",
        description="Useful for when you need to answer questions about simple calculations",
        args_schema=CalculatorInput,
    ),
    Tool.from_function(
        func=arxiv,
        name="arxiv",
        description="A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.",
        args_schema=ArxivInput,
    ),
    Tool.from_function(
        func=weathercheck,
        name="weather_check",
        description="",
        args_schema=WeatherInput,
    ),
    Tool.from_function(
        func=shell,
        name="shell",
        description="Use Shell to execute Linux commands",
        args_schema=ShellInput,
    ),
    Tool.from_function(
        func=search_knowledgebase_complex,
        name="search_knowledgebase_complex",
        description="Use Use this tool to search local knowledgebase and get information",
        args_schema=KnowledgeSearchInput,
    ),
    Tool.from_function(
        func=search_internet,
        name="search_internet",
        description="Use this tool to use bing search engine to search the internet",
        args_schema=SearchInternetInput,
    ),
    Tool.from_function(
        func=wolfram,
        name="Wolfram",
        description="Useful for when you need to calculate difficult formulas",
        args_schema=WolframInput,
    ),
    Tool.from_function(
        func=search_youtube,
        name="search_youtube",
        description="use this tools to search youtube videos",
        args_schema=YoutubeInput,
    ),
]

tool_names = [tool.name for tool in tools]
@ -13,13 +13,12 @@ from fastapi import Body
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
 from server.chat.chat import chat
-from server.chat.search_engine_chat import search_engine_chat
 from server.chat.completion import completion
 from server.chat.feedback import chat_feedback
 from server.embeddings_api import embed_texts_endpoint
 from server.llm_api import (list_running_models, list_config_models,
                             change_llm_model, stop_llm_model,
-                            get_model_config, list_search_engines)
+                            get_model_config)
 from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
                           get_server_configs, get_prompt_template)
 from typing import List, Literal
@ -63,11 +62,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
              summary="与llm模型对话(通过LLMChain)",
              )(chat)

-    app.post("/chat/search_engine_chat",
-             tags=["Chat"],
-             summary="与搜索引擎对话",
-             )(search_engine_chat)
-
     app.post("/chat/feedback",
              tags=["Chat"],
              summary="返回llm模型对话评分",
@ -110,16 +104,12 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
              summary="获取服务器原始配置信息",
              )(get_server_configs)

-    app.post("/server/list_search_engines",
-             tags=["Server State"],
-             summary="获取服务器支持的搜索引擎",
-             )(list_search_engines)
-
     @app.post("/server/get_prompt_template",
               tags=["Server State"],
               summary="获取服务区配置的 prompt 模板")
     def get_server_prompt_template(
-        type: Literal["llm_chat", "knowledge_base_chat", "search_engine_chat", "agent_chat"]=Body("llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat,search_engine_chat,agent_chat"),
+        type: Literal["llm_chat", "knowledge_base_chat"]=Body("llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat"),
         name: str = Body("default", description="模板名称"),
     ) -> str:
         return get_prompt_template(type=type, name=name)
@ -139,7 +129,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
 def mount_knowledge_routes(app: FastAPI):
     from server.chat.knowledge_base_chat import knowledge_base_chat
     from server.chat.file_chat import upload_temp_docs, file_chat
-    from server.chat.agent_chat import agent_chat
     from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
     from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
                                                   update_docs, download_doc, recreate_vector_store,
@ -155,11 +144,6 @@ def mount_knowledge_routes(app: FastAPI):
              summary="文件对话"
              )(file_chat)

-    app.post("/chat/agent_chat",
-             tags=["Chat"],
-             summary="与agent对话")(agent_chat)
-
     # Tag: Knowledge Base Management
     app.get("/knowledge_base/list_knowledge_bases",
             tags=["Knowledge Base Management"],
             response_model=ListResponse,
@ -1,13 +1,11 @@
 from __future__ import annotations
 from uuid import UUID
-from langchain.callbacks import AsyncIteratorCallbackHandler
 import json
-import asyncio
-from typing import Any, Dict, List, Optional

 from langchain.schema import AgentFinish, AgentAction
-from langchain.schema.output import LLMResult
+import asyncio
+from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast, Optional
+from langchain_core.outputs import LLMResult
+from langchain.callbacks.base import AsyncCallbackHandler

 def dumps(obj: Dict) -> str:
     return json.dumps(obj, ensure_ascii=False)
@ -23,7 +21,7 @@ class Status:
     tool_finish: int = 7


-class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
+class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
     def __init__(self):
         super().__init__()
         self.queue = asyncio.Queue()
@ -31,40 +29,29 @@ class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
         self.cur_tool = {}
         self.out = True

-    async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
-                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
-                            metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
-
-        # 对于截断不能自理的大模型,我来帮他截断
-        stop_words = ["Observation:", "Thought","\"","(", "\n","\t"]
-        for stop_word in stop_words:
-            index = input_str.find(stop_word)
-            if index != -1:
-                input_str = input_str[:index]
-                break
-
-        self.cur_tool = {
-            "tool_name": serialized["name"],
-            "input_str": input_str,
-            "output_str": "",
-            "status": Status.agent_action,
-            "run_id": run_id.hex,
-            "llm_token": "",
-            "final_answer": "",
-            "error": "",
-        }
-        # print("\nInput Str:",self.cur_tool["input_str"])
-        self.queue.put_nowait(dumps(self.cur_tool))
-
-    async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
-                          tags: List[str] | None = None, **kwargs: Any) -> None:
-        self.out = True  ## 重置输出
-        self.cur_tool.update(
-            status=Status.tool_finish,
-            output_str=output.replace("Answer:", ""),
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+    async def on_tool_start(
+            self,
+            serialized: Dict[str, Any],
+            input_str: str,
+            *,
+            run_id: UUID,
+            parent_run_id: Optional[UUID] = None,
+            tags: Optional[List[str]] = None,
+            metadata: Optional[Dict[str, Any]] = None,
+            **kwargs: Any,
+    ) -> None:
+        print("on_tool_start")
+
+    async def on_tool_end(
+            self,
+            output: str,
+            *,
+            run_id: UUID,
+            parent_run_id: Optional[UUID] = None,
+            tags: Optional[List[str]] = None,
+            **kwargs: Any,
+    ) -> None:
+        print("on_tool_end")
     async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
                             parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
         self.cur_tool.update(
@ -73,23 +60,6 @@ class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
         )
         self.queue.put_nowait(dumps(self.cur_tool))

-    # async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-    #     if "Action" in token:  ## 减少重复输出
-    #         before_action = token.split("Action")[0]
-    #         self.cur_tool.update(
-    #             status=Status.running,
-    #             llm_token=before_action + "\n",
-    #         )
-    #         self.queue.put_nowait(dumps(self.cur_tool))
-    #
-    #         self.out = False
-    #
-    #     if token and self.out:
-    #         self.cur_tool.update(
-    #             status=Status.running,
-    #             llm_token=token,
-    #         )
-    #         self.queue.put_nowait(dumps(self.cur_tool))
     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         special_tokens = ["Action", "<|observation|>"]
         for stoken in special_tokens:
@ -103,7 +73,7 @@ class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
                 self.out = False
                 break

-        if token and self.out:
+        if token is not None and token != "" and self.out:
             self.cur_tool.update(
                 status=Status.running,
                 llm_token=token,
@ -116,16 +86,17 @@ class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
                 llm_token="",
             )
             self.queue.put_nowait(dumps(self.cur_tool))

     async def on_chat_model_start(
             self,
             serialized: Dict[str, Any],
             messages: List[List],
             *,
             run_id: UUID,
             parent_run_id: Optional[UUID] = None,
             tags: Optional[List[str]] = None,
             metadata: Optional[Dict[str, Any]] = None,
             **kwargs: Any,
     ) -> None:
         self.cur_tool.update(
             status=Status.start,
@ -136,8 +107,9 @@ class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         self.cur_tool.update(
             status=Status.complete,
-            llm_token="\n",
+            llm_token="",
         )
+        self.out = True
         self.queue.put_nowait(dumps(self.cur_tool))

     async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
@ -147,15 +119,44 @@ class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
         )
         self.queue.put_nowait(dumps(self.cur_tool))

+    async def on_agent_action(
+            self,
+            action: AgentAction,
+            *,
+            run_id: UUID,
+            parent_run_id: Optional[UUID] = None,
+            tags: Optional[List[str]] = None,
+            **kwargs: Any,
+    ) -> None:
+        self.cur_tool.update(
+            status=Status.agent_action,
+            tool_name=action.tool,
+            tool_input=action.tool_input,
+        )
+        self.queue.put_nowait(dumps(self.cur_tool))
     async def on_agent_finish(
             self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
             tags: Optional[List[str]] = None,
             **kwargs: Any,
     ) -> None:
-        # 返回最终答案
         self.cur_tool.update(
             status=Status.agent_finish,
-            final_answer=finish.return_values["output"],
+            agent_finish=finish.return_values["output"],
         )
         self.queue.put_nowait(dumps(self.cur_tool))
-        self.cur_tool = {}
+    async def aiter(self) -> AsyncIterator[str]:
+        while not self.queue.empty() or not self.done.is_set():
+            done, other = await asyncio.wait(
+                [
+                    asyncio.ensure_future(self.queue.get()),
+                    asyncio.ensure_future(self.done.wait()),
+                ],
+                return_when=asyncio.FIRST_COMPLETED,
+            )
+            if other:
+                other.pop().cancel()
+            token_or_done = cast(Union[str, Literal[True]], done.pop().result())
+            if token_or_done is True:
+                break
+            yield token_or_done
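For orientation, a minimal consumer sketch (not part of this diff) of the handler above; it assumes the caller wraps its coroutine with wrap_done(coro, callback.done), as the chat endpoint later in this commit does, so that aiter() terminates:

# Sketch: drain the handler's queue while a chain/agent run is in flight (assumptions above).
import asyncio
import json

async def stream_run(run_coro, callback):
    # run_coro is expected to already be wrapped so callback.done is set on completion.
    task = asyncio.create_task(run_coro)
    async for chunk in callback.aiter():          # each chunk is dumps(self.cur_tool)
        data = json.loads(chunk)
        print(data["status"], data.get("llm_token", ""), data.get("agent_finish", ""))
    await task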
@ -1,178 +0,0 @@
import json
import asyncio

from fastapi import Body
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL

from langchain.chains import LLMChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents import LLMSingleActionAgent, AgentExecutor
from typing import AsyncIterable, Optional, List

from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from server.knowledge_base.kb_service.base import get_kb_details
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from server.chat.utils import History
from server.agent import model_container
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate


async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                     history: List[History] = Body([],
                                                   description="历史对话",
                                                   examples=[[
                                                       {"role": "user", "content": "请使用知识库工具查询今天北京天气"},
                                                       {"role": "assistant",
                                                        "content": "使用天气查询工具查询到今天北京多云,10-14摄氏度,东北风2级,易感冒"}]]
                                                   ),
                     stream: bool = Body(False, description="流式输出"),
                     model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                     temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                     max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                     prompt_name: str = Body("default",
                                             description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                     ):
    history = [History.from_data(h) for h in history]

    async def agent_chat_iterator(
            query: str,
            history: Optional[List[History]],
            model_name: str = LLM_MODELS[0],
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        nonlocal max_tokens
        callback = CustomAsyncIteratorCallbackHandler()
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )

        kb_list = {x["kb_name"]: x for x in get_kb_details()}
        model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}

        if Agent_MODEL:
            model_agent = get_ChatOpenAI(
                model_name=Agent_MODEL,
                temperature=temperature,
                max_tokens=max_tokens,
                callbacks=[callback],
            )
            model_container.MODEL = model_agent
        else:
            model_container.MODEL = model

        prompt_template = get_prompt_template("agent_chat", prompt_name)
        prompt_template_agent = CustomPromptTemplate(
            template=prompt_template,
            tools=tools,
            input_variables=["input", "intermediate_steps", "history"]
        )
        output_parser = CustomOutputParser()
        llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)
        memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
        for message in history:
            if message.role == 'user':
                memory.chat_memory.add_user_message(message.content)
            else:
                memory.chat_memory.add_ai_message(message.content)
        if "chatglm3" in model_container.MODEL.model_name or "zhipu-api" in model_container.MODEL.model_name:
            agent_executor = initialize_glm3_agent(
                llm=model,
                tools=tools,
                callback_manager=None,
                prompt=prompt_template,
                input_variables=["input", "intermediate_steps", "history"],
                memory=memory,
                verbose=True,
            )
        else:
            agent = LLMSingleActionAgent(
                llm_chain=llm_chain,
                output_parser=output_parser,
                stop=["\nObservation:", "Observation"],
                allowed_tools=tool_names,
            )
            agent_executor = AgentExecutor.from_agent_and_tools(agent=agent,
                                                                tools=tools,
                                                                verbose=True,
                                                                memory=memory,
                                                                )
        while True:
            try:
                task = asyncio.create_task(wrap_done(
                    agent_executor.acall(query, callbacks=[callback], include_run_info=True),
                    callback.done))
                break
            except:
                pass

        if stream:
            async for chunk in callback.aiter():
                tools_use = []
                # Use server-sent-events to stream the response
                data = json.loads(chunk)
                if data["status"] == Status.start or data["status"] == Status.complete:
                    continue
                elif data["status"] == Status.error:
                    tools_use.append("\n```\n")
                    tools_use.append("工具名称: " + data["tool_name"])
                    tools_use.append("工具状态: " + "调用失败")
                    tools_use.append("错误信息: " + data["error"])
                    tools_use.append("重新开始尝试")
                    tools_use.append("\n```\n")
                    yield json.dumps({"tools": tools_use}, ensure_ascii=False)
                elif data["status"] == Status.tool_finish:
                    tools_use.append("\n```\n")
                    tools_use.append("工具名称: " + data["tool_name"])
                    tools_use.append("工具状态: " + "调用成功")
                    tools_use.append("工具输入: " + data["input_str"])
                    tools_use.append("工具输出: " + data["output_str"])
                    tools_use.append("\n```\n")
                    yield json.dumps({"tools": tools_use}, ensure_ascii=False)
                elif data["status"] == Status.agent_finish:
                    yield json.dumps({"final_answer": data["final_answer"]}, ensure_ascii=False)
                else:
                    yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)

        else:
            answer = ""
            final_answer = ""
            async for chunk in callback.aiter():
                data = json.loads(chunk)
                if data["status"] == Status.start or data["status"] == Status.complete:
                    continue
                if data["status"] == Status.error:
                    answer += "\n```\n"
                    answer += "工具名称: " + data["tool_name"] + "\n"
                    answer += "工具状态: " + "调用失败" + "\n"
                    answer += "错误信息: " + data["error"] + "\n"
                    answer += "\n```\n"
                if data["status"] == Status.tool_finish:
                    answer += "\n```\n"
                    answer += "工具名称: " + data["tool_name"] + "\n"
                    answer += "工具状态: " + "调用成功" + "\n"
                    answer += "工具输入: " + data["input_str"] + "\n"
                    answer += "工具输出: " + data["output_str"] + "\n"
                    answer += "\n```\n"
                if data["status"] == Status.agent_finish:
                    final_answer = data["final_answer"]
                else:
                    answer += data["llm_token"]

            yield json.dumps({"answer": answer, "final_answer": final_answer}, ensure_ascii=False)
        await task

    return EventSourceResponse(agent_chat_iterator(query=query,
                                                   history=history,
                                                   model_name=model_name,
                                                   prompt_name=prompt_name),
                               )
@ -1,20 +1,43 @@
 from fastapi import Body
-from sse_starlette.sse import EventSourceResponse
-from configs import LLM_MODELS, TEMPERATURE
+from fastapi.responses import StreamingResponse
+from langchain.agents import initialize_agent, AgentType
+from langchain_core.callbacks import AsyncCallbackManager, BaseCallbackManager
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnableBranch
+from server.agent.agent_factory import initialize_glm3_agent
+from server.agent.tools_factory.tools_registry import all_tools
 from server.utils import wrap_done, get_ChatOpenAI
 from langchain.chains import LLMChain
-from langchain.callbacks import AsyncIteratorCallbackHandler
-from typing import AsyncIterable
+from typing import AsyncIterable, Dict
 import asyncio
 import json
 from langchain.prompts.chat import ChatPromptTemplate
-from typing import List, Optional, Union
+from typing import List, Union
 from server.chat.utils import History
 from langchain.prompts import PromptTemplate
 from server.utils import get_prompt_template
 from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
 from server.db.repository import add_message_to_db
-from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
+from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler


+def create_models_from_config(configs: dict = {}, callbacks: list = []):
+    models = {}
+    prompts = {}
+    for model_type, model_configs in configs.items():
+        for model_name, params in model_configs.items():
+            callback = callbacks if params.get('callbacks', False) else None
+            model_instance = get_ChatOpenAI(
+                model_name=model_name,
+                temperature=params.get('temperature', 0.5),
+                max_tokens=params.get('max_tokens', 1000),
+                callbacks=callback
+            )
+            models[model_type] = model_instance
+            prompt_name = params.get('prompt_name', 'default')
+            prompt_template = get_prompt_template(type=model_type, name=prompt_name)
+            prompts[model_type] = prompt_template
+    return models, prompts

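To make the new helper concrete, a hypothetical call sketch follows (model names and parameter values are placeholders, not part of this commit); it mirrors the nested model-type → model-name → params layout that create_models_from_config iterates over:

# Hypothetical usage sketch for create_models_from_config.
callback = CustomAsyncIteratorCallbackHandler()
example_config = {
    "preprocess_model": {"my-small-model": {"temperature": 0.4, "max_tokens": 2048,
                                            "prompt_name": "default", "callbacks": False}},
    "llm_model": {"my-chat-model": {"temperature": 0.9, "max_tokens": 4096,
                                    "prompt_name": "default", "callbacks": True}},
    "action_model": {"my-agent-model": {"temperature": 0.01, "max_tokens": 4096,
                                        "prompt_name": "default", "callbacks": True}},
}
models, prompts = create_models_from_config(configs=example_config, callbacks=[callback])
# models["preprocess_model"], models["llm_model"], models["action_model"] are ChatOpenAI
# instances; prompts[...] holds the template resolved by get_prompt_template for each type.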
 async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
@ -28,76 +51,106 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                         {"role": "assistant", "content": "虎头虎脑"}]]
                     ),
                stream: bool = Body(False, description="流式输出"),
-               model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
-               temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
-               max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
-               # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
-               prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
+               model_config: Dict = Body({}, description="LLM 模型配置。"),
+               tool_config: Dict = Body({}, description="工具配置"),
                ):
     async def chat_iterator() -> AsyncIterable[str]:
-        nonlocal history, max_tokens
-        callback = AsyncIteratorCallbackHandler()
-        callbacks = [callback]
+        nonlocal history
         memory = None
+        message_id = None
+        chat_prompt = None

-        # 负责保存llm response到message db
-        message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
-        conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
-                                                            chat_type="llm_chat",
-                                                            query=query)
-        callbacks.append(conversation_callback)
-
-        if isinstance(max_tokens, int) and max_tokens <= 0:
-            max_tokens = None
-
-        model = get_ChatOpenAI(
-            model_name=model_name,
-            temperature=temperature,
-            max_tokens=max_tokens,
-            callbacks=callbacks,
-        )
+        callback = CustomAsyncIteratorCallbackHandler()
+        callbacks = [callback]
+        models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)

-        if history:  # 优先使用前端传入的历史消息
+        if conversation_id:
+            message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
+        tools = [tool for tool in all_tools if tool.name in tool_config]
+
+        if history:
             history = [History.from_data(h) for h in history]
-            prompt_template = get_prompt_template("llm_chat", prompt_name)
-            input_msg = History(role="user", content=prompt_template).to_msg_template(False)
+            input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
             chat_prompt = ChatPromptTemplate.from_messages(
                 [i.to_msg_template() for i in history] + [input_msg])
-        elif conversation_id and history_len > 0:  # 前端要求从数据库取历史消息
-            # 使用memory 时必须 prompt 必须含有memory.memory_key 对应的变量
-            prompt = get_prompt_template("llm_chat", "with_history")
-            chat_prompt = PromptTemplate.from_template(prompt)
-            # 根据conversation_id 获取message 列表进而拼凑 memory
-            memory = ConversationBufferDBMemory(conversation_id=conversation_id,
-                                                llm=model,
-                                                message_limit=history_len)
+        elif conversation_id and history_len > 0:
+            memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=models["llm_model"],
+                                                message_limit=history_len)
         else:
-            prompt_template = get_prompt_template("llm_chat", prompt_name)
-            input_msg = History(role="user", content=prompt_template).to_msg_template(False)
+            input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
             chat_prompt = ChatPromptTemplate.from_messages([input_msg])

-        chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
-
-        # Begin a task that runs in the background.
-        task = asyncio.create_task(wrap_done(
-            chain.acall({"input": query}),
-            callback.done),
-        )
-
-        if stream:
-            async for token in callback.aiter():
-                # Use server-sent-events to stream the response
-                yield json.dumps(
-                    {"text": token, "message_id": message_id},
-                    ensure_ascii=False)
-        else:
-            answer = ""
-            async for token in callback.aiter():
-                answer += token
-            yield json.dumps(
-                {"text": answer, "message_id": message_id},
-                ensure_ascii=False)
+        chain = LLMChain(prompt=chat_prompt, llm=models["llm_model"], memory=memory)
+        classifier_chain = (
+            PromptTemplate.from_template(prompts["preprocess_model"])
+            | models["preprocess_model"]
+            | StrOutputParser()
+        )
+        if "chatglm3" in models["action_model"].model_name.lower():
+            agent_executor = initialize_glm3_agent(
+                llm=models["action_model"],
+                tools=tools,
+                prompt=prompts["action_model"],
+                input_variables=["input", "intermediate_steps", "history"],
+                memory=memory,
+                callback_manager=BaseCallbackManager(handlers=callbacks),
+                verbose=True,
+            )
+        else:
+            agent_executor = initialize_agent(
+                llm=models["action_model"],
+                tools=tools,
+                callbacks=callbacks,
+                agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
+                memory=memory,
+                verbose=True,
+            )
+        branch = RunnableBranch(
+            (lambda x: "1" in x["topic"].lower(), agent_executor),
+            chain
+        )
+        full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
+        task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
+        if stream:
+            async for chunk in callback.aiter():
+                data = json.loads(chunk)
+                if data["status"] == Status.start:
+                    continue
+                elif data["status"] == Status.agent_action:
+                    tool_info = {
+                        "tool_name": data["tool_name"],
+                        "tool_input": data["tool_input"]
+                    }
+                    yield json.dumps({"agent_action": tool_info, "message_id": message_id}, ensure_ascii=False)
+                elif data["status"] == Status.agent_finish:
+                    yield json.dumps({"agent_finish": data["agent_finish"], "message_id": message_id},
+                                     ensure_ascii=False)
+                else:
+                    yield json.dumps({"text": data["llm_token"], "message_id": message_id}, ensure_ascii=False)
+        else:
+            text = ""
+            agent_finish = ""
+            tool_info = None
+            async for chunk in callback.aiter():
+                # Use server-sent-events to stream the response
+                data = json.loads(chunk)
+                if data["status"] == Status.agent_action:
+                    tool_info = {
+                        "tool_name": data["tool_name"],
+                        "tool_input": data["tool_input"]
+                    }
+                if data["status"] == Status.agent_finish:
+                    agent_finish = data["agent_finish"]
+                else:
+                    text += data["llm_token"]
+            if tool_info:
+                yield json.dumps(
+                    {"text": text, "agent_action": tool_info, "agent_finish": agent_finish, "message_id": message_id},
+                    ensure_ascii=False)
+            else:
+                yield json.dumps(
+                    {"text": text, "message_id": message_id},
+                    ensure_ascii=False)
         await task

     return EventSourceResponse(chat_iterator())
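The routing introduced above is easiest to see in isolation. A self-contained sketch follows, with RunnableLambda stand-ins for the real classifier_chain, agent_executor and chain; the keyword trigger is illustrative only, the real classifier is the preprocess model:

# Sketch of the preprocess-model routing: "1" from the classifier sends the input to the agent.
from langchain_core.runnables import RunnableBranch, RunnableLambda

classifier = RunnableLambda(lambda x: "1" if "查询" in x["input"] else "0")  # stand-in classifier
agent_path = RunnableLambda(lambda x: f"[agent] {x['input']}")              # stand-in agent_executor
chat_path = RunnableLambda(lambda x: f"[chat] {x['input']}")                # stand-in LLMChain

branch = RunnableBranch(
    (lambda x: "1" in x["topic"].lower(), agent_path),
    chat_path,
)
full_chain = {"topic": classifier, "input": lambda x: x["input"]} | branch
print(full_chain.invoke({"input": "帮我查询北京天气"}))  # routed to the agent path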
@ -1,6 +1,6 @@
 from fastapi import Body
-from sse_starlette.sse import EventSourceResponse
-from configs import LLM_MODELS, TEMPERATURE
+from fastapi.responses import StreamingResponse
+from configs import LLM_MODEL_CONFIG
 from server.utils import wrap_done, get_OpenAI
 from langchain.chains import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
@ -14,8 +14,8 @@ from server.utils import get_prompt_template
 async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                      stream: bool = Body(False, description="流式输出"),
                      echo: bool = Body(False, description="除了输出之外,还回显输入"),
-                     model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
-                     temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+                     model_name: str = Body(None, description="LLM 模型名称。"),
+                     temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
                      max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"),
                      # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
                      prompt_name: str = Body("default",
@ -24,7 +24,7 @@ async def completion(query: str = Body(..., description="用户输入", examples

     #todo 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理
     async def completion_iterator(query: str,
-                                  model_name: str = LLM_MODELS[0],
+                                  model_name: str = None,
                                   prompt_name: str = prompt_name,
                                   echo: bool = echo,
                                   ) -> AsyncIterable[str]:
@ -1,7 +1,6 @@
 from fastapi import Body, File, Form, UploadFile
-from sse_starlette.sse import EventSourceResponse
-from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE,
-                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
+from fastapi.responses import StreamingResponse
+from configs import (VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
 from server.utils import (wrap_done, get_ChatOpenAI,
                           BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool)
 from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
@ -15,8 +14,6 @@ from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
 from server.knowledge_base.utils import KnowledgeFile
 import json
 import os
-from pathlib import Path


 def _parse_files_in_thread(
     files: List[UploadFile],
@ -102,8 +99,8 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
                                                    "content": "虎头虎脑"}]]
                                       ),
                     stream: bool = Body(False, description="流式输出"),
-                    model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
-                    temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+                    model_name: str = Body(None, description="LLM 模型名称。"),
+                    temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
                     max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                     prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                     ):
@ -1,147 +0,0 @@
from fastapi import Body, Request
from sse_starlette.sse import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS,
                     VECTOR_SEARCH_TOP_K,
                     SCORE_THRESHOLD,
                     TEMPERATURE,
                     USE_RERANKER,
                     RERANKER_MODEL,
                     RERANKER_MAX_LENGTH,
                     MODEL_PATH)
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
import json
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
from server.reranker.reranker import LangchainReranker
from server.utils import embedding_device


async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                              knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                              top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                              score_threshold: float = Body(
                                  SCORE_THRESHOLD,
                                  description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右",
                                  ge=0,
                                  le=2
                              ),
                              history: List[History] = Body(
                                  [],
                                  description="历史对话",
                                  examples=[[
                                      {"role": "user",
                                       "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                      {"role": "assistant",
                                       "content": "虎头虎脑"}]]
                              ),
                              stream: bool = Body(False, description="流式输出"),
                              model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                              max_tokens: Optional[int] = Body(
                                  None,
                                  description="限制LLM生成Token数量,默认None代表模型最大值"
                              ),
                              prompt_name: str = Body(
                                  "default",
                                  description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
                              ),
                              request: Request = None,
                              ):
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    history = [History.from_data(h) for h in history]

    async def knowledge_base_chat_iterator(
            query: str,
            top_k: int,
            history: Optional[List[History]],
            model_name: str = model_name,
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        nonlocal max_tokens
        callback = AsyncIteratorCallbackHandler()
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )
        docs = await run_in_threadpool(search_docs,
                                       query=query,
                                       knowledge_base_name=knowledge_base_name,
                                       top_k=top_k,
                                       score_threshold=score_threshold)

        # 加入reranker
        if USE_RERANKER:
            reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
            print("-----------------model path------------------")
            print(reranker_model_path)
            reranker_model = LangchainReranker(top_n=top_k,
                                               device=embedding_device(),
                                               max_length=RERANKER_MAX_LENGTH,
                                               model_name_or_path=reranker_model_path
                                               )
            print(docs)
            docs = reranker_model.compress_documents(documents=docs,
                                                     query=query)
            print("---------after rerank------------------")
            print(docs)
        context = "\n".join([doc.page_content for doc in docs])

        if len(docs) == 0:  # 如果没有找到相关文档,使用empty模板
            prompt_template = get_prompt_template("knowledge_base_chat", "empty")
        else:
            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_template() for i in history] + [input_msg])

        chain = LLMChain(prompt=chat_prompt, llm=model)

        # Begin a task that runs in the background.
        task = asyncio.create_task(wrap_done(
            chain.acall({"context": context, "question": query}),
            callback.done),
        )

        source_documents = []
        for inum, doc in enumerate(docs):
            filename = doc.metadata.get("source")
            parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
            base_url = request.base_url
            url = f"{base_url}knowledge_base/download_doc?" + parameters
            text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
            source_documents.append(text)

        if len(source_documents) == 0:  # 没有找到相关文档
            source_documents.append(f"<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>")
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
async for token in callback.aiter():
|
|
||||||
# Use server-sent-events to stream the response
|
|
||||||
yield json.dumps({"answer": token}, ensure_ascii=False)
|
|
||||||
yield json.dumps({"docs": source_documents}, ensure_ascii=False)
|
|
||||||
else:
|
|
||||||
answer = ""
|
|
||||||
async for token in callback.aiter():
|
|
||||||
answer += token
|
|
||||||
yield json.dumps({"answer": answer,
|
|
||||||
"docs": source_documents},
|
|
||||||
ensure_ascii=False)
|
|
||||||
await task
|
|
||||||
|
|
||||||
return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name))
|
|
||||||
|
|
||||||
@ -1,208 +0,0 @@
|
|||||||
from langchain.utilities.bing_search import BingSearchAPIWrapper
|
|
||||||
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
|
|
||||||
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
|
|
||||||
LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE, OVERLAP_SIZE)
|
|
||||||
from langchain.chains import LLMChain
|
|
||||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
|
||||||
|
|
||||||
from langchain.prompts.chat import ChatPromptTemplate
|
|
||||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
||||||
from langchain.docstore.document import Document
|
|
||||||
from fastapi import Body
|
|
||||||
from fastapi.concurrency import run_in_threadpool
|
|
||||||
from sse_starlette import EventSourceResponse
|
|
||||||
from server.utils import wrap_done, get_ChatOpenAI
|
|
||||||
from server.utils import BaseResponse, get_prompt_template
|
|
||||||
from server.chat.utils import History
|
|
||||||
from typing import AsyncIterable
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
from typing import List, Optional, Dict
|
|
||||||
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
|
|
||||||
from markdownify import markdownify
|
|
||||||
|
|
||||||
|
|
||||||
def bing_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
|
|
||||||
if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
|
|
||||||
return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
|
|
||||||
"title": "env info is not found",
|
|
||||||
"link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
|
|
||||||
search = BingSearchAPIWrapper(bing_subscription_key=BING_SUBSCRIPTION_KEY,
|
|
||||||
bing_search_url=BING_SEARCH_URL)
|
|
||||||
return search.results(text, result_len)
|
|
||||||
|
|
||||||
|
|
||||||
def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
|
|
||||||
search = DuckDuckGoSearchAPIWrapper()
|
|
||||||
return search.results(text, result_len)
|
|
||||||
|
|
||||||
|
|
||||||
def metaphor_search(
|
|
||||||
text: str,
|
|
||||||
result_len: int = SEARCH_ENGINE_TOP_K,
|
|
||||||
split_result: bool = False,
|
|
||||||
chunk_size: int = 500,
|
|
||||||
chunk_overlap: int = OVERLAP_SIZE,
|
|
||||||
) -> List[Dict]:
|
|
||||||
from metaphor_python import Metaphor
|
|
||||||
|
|
||||||
if not METAPHOR_API_KEY:
|
|
||||||
return []
|
|
||||||
|
|
||||||
client = Metaphor(METAPHOR_API_KEY)
|
|
||||||
search = client.search(text, num_results=result_len, use_autoprompt=True)
|
|
||||||
contents = search.get_contents().contents
|
|
||||||
for x in contents:
|
|
||||||
x.extract = markdownify(x.extract)
|
|
||||||
|
|
||||||
# metaphor 返回的内容都是长文本,需要分词再检索
|
|
||||||
if split_result:
|
|
||||||
docs = [Document(page_content=x.extract,
|
|
||||||
metadata={"link": x.url, "title": x.title})
|
|
||||||
for x in contents]
|
|
||||||
text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
|
|
||||||
chunk_size=chunk_size,
|
|
||||||
chunk_overlap=chunk_overlap)
|
|
||||||
splitted_docs = text_splitter.split_documents(docs)
|
|
||||||
|
|
||||||
# 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档
|
|
||||||
if len(splitted_docs) > result_len:
|
|
||||||
normal = NormalizedLevenshtein()
|
|
||||||
for x in splitted_docs:
|
|
||||||
x.metadata["score"] = normal.similarity(text, x.page_content)
|
|
||||||
splitted_docs.sort(key=lambda x: x.metadata["score"], reverse=True)
|
|
||||||
splitted_docs = splitted_docs[:result_len]
|
|
||||||
|
|
||||||
docs = [{"snippet": x.page_content,
|
|
||||||
"link": x.metadata["link"],
|
|
||||||
"title": x.metadata["title"]}
|
|
||||||
for x in splitted_docs]
|
|
||||||
else:
|
|
||||||
docs = [{"snippet": x.extract,
|
|
||||||
"link": x.url,
|
|
||||||
"title": x.title}
|
|
||||||
for x in contents]
|
|
||||||
|
|
||||||
return docs
|
|
||||||
|
|
||||||
|
|
||||||
SEARCH_ENGINES = {"bing": bing_search,
|
|
||||||
"duckduckgo": duckduckgo_search,
|
|
||||||
"metaphor": metaphor_search,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def search_result2docs(search_results):
|
|
||||||
docs = []
|
|
||||||
for result in search_results:
|
|
||||||
doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "",
|
|
||||||
metadata={"source": result["link"] if "link" in result.keys() else "",
|
|
||||||
"filename": result["title"] if "title" in result.keys() else ""})
|
|
||||||
docs.append(doc)
|
|
||||||
return docs
|
|
||||||
|
|
||||||
|
|
||||||
async def lookup_search_engine(
|
|
||||||
query: str,
|
|
||||||
search_engine_name: str,
|
|
||||||
top_k: int = SEARCH_ENGINE_TOP_K,
|
|
||||||
split_result: bool = False,
|
|
||||||
):
|
|
||||||
search_engine = SEARCH_ENGINES[search_engine_name]
|
|
||||||
results = await run_in_threadpool(search_engine, query, result_len=top_k, split_result=split_result)
|
|
||||||
docs = search_result2docs(results)
|
|
||||||
return docs
|
|
||||||
|
|
||||||
async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
|
|
||||||
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
|
|
||||||
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
|
|
||||||
history: List[History] = Body([],
|
|
||||||
description="历史对话",
|
|
||||||
examples=[[
|
|
||||||
{"role": "user",
|
|
||||||
"content": "我们来玩成语接龙,我先来,生龙活虎"},
|
|
||||||
{"role": "assistant",
|
|
||||||
"content": "虎头虎脑"}]]
|
|
||||||
),
|
|
||||||
stream: bool = Body(False, description="流式输出"),
|
|
||||||
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
|
|
||||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
|
||||||
max_tokens: Optional[int] = Body(None,
|
|
||||||
description="限制LLM生成Token数量,默认None代表模型最大值"),
|
|
||||||
prompt_name: str = Body("default",
|
|
||||||
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
|
||||||
split_result: bool = Body(False,
|
|
||||||
description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)")
|
|
||||||
):
|
|
||||||
if search_engine_name not in SEARCH_ENGINES.keys():
|
|
||||||
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
|
|
||||||
|
|
||||||
if search_engine_name == "bing" and not BING_SUBSCRIPTION_KEY:
|
|
||||||
return BaseResponse(code=404, msg=f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`")
|
|
||||||
|
|
||||||
history = [History.from_data(h) for h in history]
|
|
||||||
|
|
||||||
async def search_engine_chat_iterator(query: str,
|
|
||||||
search_engine_name: str,
|
|
||||||
top_k: int,
|
|
||||||
history: Optional[List[History]],
|
|
||||||
model_name: str = LLM_MODELS[0],
|
|
||||||
prompt_name: str = prompt_name,
|
|
||||||
) -> AsyncIterable[str]:
|
|
||||||
nonlocal max_tokens
|
|
||||||
callback = AsyncIteratorCallbackHandler()
|
|
||||||
if isinstance(max_tokens, int) and max_tokens <= 0:
|
|
||||||
max_tokens = None
|
|
||||||
|
|
||||||
model = get_ChatOpenAI(
|
|
||||||
model_name=model_name,
|
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
callbacks=[callback],
|
|
||||||
)
|
|
||||||
|
|
||||||
docs = await lookup_search_engine(query, search_engine_name, top_k, split_result=split_result)
|
|
||||||
context = "\n".join([doc.page_content for doc in docs])
|
|
||||||
|
|
||||||
prompt_template = get_prompt_template("search_engine_chat", prompt_name)
|
|
||||||
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
|
|
||||||
chat_prompt = ChatPromptTemplate.from_messages(
|
|
||||||
[i.to_msg_template() for i in history] + [input_msg])
|
|
||||||
|
|
||||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
|
||||||
|
|
||||||
# Begin a task that runs in the background.
|
|
||||||
task = asyncio.create_task(wrap_done(
|
|
||||||
chain.acall({"context": context, "question": query}),
|
|
||||||
callback.done),
|
|
||||||
)
|
|
||||||
|
|
||||||
source_documents = [
|
|
||||||
f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
|
|
||||||
for inum, doc in enumerate(docs)
|
|
||||||
]
|
|
||||||
|
|
||||||
if len(source_documents) == 0: # 没有找到相关资料(不太可能)
|
|
||||||
source_documents.append(f"""<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>""")
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
async for token in callback.aiter():
|
|
||||||
# Use server-sent-events to stream the response
|
|
||||||
yield json.dumps({"answer": token}, ensure_ascii=False)
|
|
||||||
yield json.dumps({"docs": source_documents}, ensure_ascii=False)
|
|
||||||
else:
|
|
||||||
answer = ""
|
|
||||||
async for token in callback.aiter():
|
|
||||||
answer += token
|
|
||||||
yield json.dumps({"answer": answer,
|
|
||||||
"docs": source_documents},
|
|
||||||
ensure_ascii=False)
|
|
||||||
await task
|
|
||||||
|
|
||||||
return EventSourceResponse(search_engine_chat_iterator(query=query,
|
|
||||||
search_engine_name=search_engine_name,
|
|
||||||
top_k=top_k,
|
|
||||||
history=history,
|
|
||||||
model_name=model_name,
|
|
||||||
prompt_name=prompt_name),
|
|
||||||
)
|
|
||||||
@@ -9,7 +9,6 @@ class ConversationModel(Base):
     __tablename__ = 'conversation'
     id = Column(String(32), primary_key=True, comment='对话框ID')
     name = Column(String(50), comment='对话框名称')
-    # chat/agent_chat等
     chat_type = Column(String(50), comment='聊天类型')
     create_time = Column(DateTime, default=func.now(), comment='创建时间')

@@ -10,7 +10,6 @@ class MessageModel(Base):
     __tablename__ = 'message'
     id = Column(String(32), primary_key=True, comment='聊天记录ID')
     conversation_id = Column(String(32), default=None, index=True, comment='对话框ID')
-    # chat/agent_chat等
     chat_type = Column(String(50), comment='聊天类型')
     query = Column(String(4096), comment='用户问题')
     response = Column(String(4096), comment='模型回答')
@@ -10,7 +10,6 @@ from typing import List, Optional
 from server.knowledge_base.kb_summary.base import KBSummaryService
 from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
 from server.utils import wrap_done, get_ChatOpenAI, BaseResponse
-from configs import LLM_MODELS, TEMPERATURE
 from server.knowledge_base.model.kb_document_model import DocumentWithVSId

 def recreate_summary_vector_store(
@@ -19,8 +18,8 @@ def recreate_summary_vector_store(
         vs_type: str = Body(DEFAULT_VS_TYPE),
         embed_model: str = Body(EMBEDDING_MODEL),
         file_description: str = Body(''),
-        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
-        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+        model_name: str = Body(None, description="LLM 模型名称。"),
+        temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
         max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
 ):
     """
@@ -100,8 +99,8 @@ def summary_file_to_vector_store(
         vs_type: str = Body(DEFAULT_VS_TYPE),
         embed_model: str = Body(EMBEDDING_MODEL),
         file_description: str = Body(''),
-        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
-        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+        model_name: str = Body(None, description="LLM 模型名称。"),
+        temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
         max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
 ):
     """
@@ -172,8 +171,8 @@ def summary_doc_ids_to_vector_store(
         vs_type: str = Body(DEFAULT_VS_TYPE),
         embed_model: str = Body(EMBEDDING_MODEL),
         file_description: str = Body(''),
-        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
-        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+        model_name: str = Body(None, description="LLM 模型名称。"),
+        temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
         max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
 ) -> BaseResponse:
     """
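With model_name now defaulting to None on the summary endpoints, the server is expected to fall back to a configured model. A hypothetical helper illustrating that fallback pattern; the real resolution lives in the server code, this is only a sketch reusing the next(iter(...)) idiom introduced elsewhere in this commit:

from configs import LLM_MODEL_CONFIG

def resolve_model_name(model_name=None):
    # If the caller did not choose a model, fall back to the first entry
    # of the "llm_model" section of LLM_MODEL_CONFIG.
    if model_name:
        return model_name
    return next(iter(LLM_MODEL_CONFIG["llm_model"]))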
@@ -7,7 +7,6 @@ from configs import (
     logger,
     log_verbose,
     text_splitter_dict,
-    LLM_MODELS,
     TEXT_SPLITTER_NAME,
 )
 import importlib
@@ -187,10 +186,10 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):


 def make_text_splitter(
-        splitter_name: str = TEXT_SPLITTER_NAME,
-        chunk_size: int = CHUNK_SIZE,
-        chunk_overlap: int = OVERLAP_SIZE,
-        llm_model: str = LLM_MODELS[0],
+        splitter_name,
+        chunk_size,
+        chunk_overlap,
+        llm_model,
 ):
     """
     根据参数获取特定的分词器
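With the keyword defaults stripped from make_text_splitter, every call site now supplies all four arguments explicitly. A usage sketch; the model name is an assumed placeholder and the other values come from the existing config names:

from configs import CHUNK_SIZE, OVERLAP_SIZE, TEXT_SPLITTER_NAME
from server.knowledge_base.utils import make_text_splitter

text_splitter = make_text_splitter(
    splitter_name=TEXT_SPLITTER_NAME,
    chunk_size=CHUNK_SIZE,
    chunk_overlap=OVERLAP_SIZE,
    llm_model="chatglm3-6b",  # placeholder; use whichever model you have configured
)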
@@ -1,5 +1,5 @@
 from fastapi import Body
-from configs import logger, log_verbose, LLM_MODELS, HTTPX_DEFAULT_TIMEOUT
+from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT
 from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
                           get_httpx_client, get_model_worker_config)
 from typing import List
@@ -62,7 +62,7 @@ def get_model_config(


 def stop_llm_model(
-    model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODELS[0]]),
+    model_name: str = Body(..., description="要停止的LLM模型名称", examples=[]),
     controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
 ) -> BaseResponse:
     '''
@@ -86,8 +86,8 @@ def stop_llm_model(


 def change_llm_model(
-    model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODELS[0]]),
-    new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODELS[0]]),
+    model_name: str = Body(..., description="当前运行模型"),
+    new_model_name: str = Body(..., description="要切换的新模型"),
     controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
 ):
     '''
@@ -108,9 +108,3 @@ def change_llm_model(
         return BaseResponse(
             code=500,
             msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
-
-
-def list_search_engines() -> BaseResponse:
-    from server.chat.search_engine_chat import SEARCH_ENGINES
-
-    return BaseResponse(data=list(SEARCH_ENGINES))
@@ -1,5 +1,5 @@
 from fastchat.conversation import Conversation
-from configs import LOG_PATH, TEMPERATURE
+from configs import LOG_PATH
 import fastchat.constants
 fastchat.constants.LOGDIR = LOG_PATH
 from fastchat.serve.base_model_worker import BaseModelWorker
@@ -63,7 +63,7 @@ class ApiModelParams(ApiConfigParams):
     deployment_name: Optional[str] = None  # for azure
     resource_name: Optional[str] = None  # for azure

-    temperature: float = TEMPERATURE
+    temperature: float = 0.9
     max_tokens: Optional[int] = None
     top_p: Optional[float] = 1.0
@@ -1,11 +1,6 @@
-import json
 import sys
-
 from fastchat.conversation import Conversation
-from configs import TEMPERATURE
-from http import HTTPStatus
 from typing import List, Literal, Dict
-
 from fastchat import conversation as conv
 from server.model_workers.base import *
 from server.model_workers.base import ApiEmbeddingsParams
@@ -4,7 +4,7 @@ from typing import List
 from fastapi import FastAPI
 from pathlib import Path
 import asyncio
-from configs import (LLM_MODELS, LLM_DEVICE, EMBEDDING_DEVICE,
+from configs import (LLM_MODEL_CONFIG, LLM_DEVICE, EMBEDDING_DEVICE,
                      MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose,
                      FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
 import os
@@ -402,8 +402,8 @@ def fschat_controller_address() -> str:
     return f"http://{host}:{port}"


-def fschat_model_worker_address(model_name: str = LLM_MODELS[0]) -> str:
-    if model := get_model_worker_config(model_name):
+def fschat_model_worker_address(model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model']))) -> str:
+    if model := get_model_worker_config(model_name):  # TODO: depends fastchat
         host = model["host"]
         if host == "0.0.0.0":
             host = "127.0.0.1"
@@ -443,7 +443,7 @@ def webui_address() -> str:
 def get_prompt_template(type: str, name: str) -> Optional[str]:
     '''
     从prompt_config中加载模板内容
-    type: "llm_chat","agent_chat","knowledge_base_chat","search_engine_chat"的其中一种,如果有新功能,应该进行加入。
+    type: "llm_chat","knowledge_base_chat","search_engine_chat"的其中一种,如果有新功能,应该进行加入。
     '''

     from configs import prompt_config
@@ -617,26 +617,6 @@ def get_server_configs() -> Dict:
     '''
     获取configs中的原始配置项,供前端使用
     '''
-    from configs.kb_config import (
-        DEFAULT_KNOWLEDGE_BASE,
-        DEFAULT_SEARCH_ENGINE,
-        DEFAULT_VS_TYPE,
-        CHUNK_SIZE,
-        OVERLAP_SIZE,
-        SCORE_THRESHOLD,
-        VECTOR_SEARCH_TOP_K,
-        SEARCH_ENGINE_TOP_K,
-        ZH_TITLE_ENHANCE,
-        text_splitter_dict,
-        TEXT_SPLITTER_NAME,
-    )
-    from configs.model_config import (
-        LLM_MODELS,
-        HISTORY_LEN,
-        TEMPERATURE,
-    )
-    from configs.prompt_config import PROMPT_TEMPLATES
-
     _custom = {
         "controller_address": fschat_controller_address(),
         "openai_api_address": fschat_openai_api_address(),
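A short sketch of how the new worker-address default resolves: next(iter(...)) simply returns the first key of the "llm_model" section, so whichever model is listed first in LLM_MODEL_CONFIG becomes the implicit default (dict order follows the config file on Python 3.7+):

from configs import LLM_MODEL_CONFIG

# The first configured local model becomes the implicit default worker model.
default_model = next(iter(LLM_MODEL_CONFIG["llm_model"]))
print(default_model)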
startup.py

@@ -8,6 +8,7 @@ from datetime import datetime
 from pprint import pprint
 from langchain_core._api import deprecated

+# 设置numexpr最大线程数,默认为CPU核心数
 try:
     import numexpr
@@ -21,7 +22,7 @@ from configs import (
     LOG_PATH,
     log_verbose,
     logger,
-    LLM_MODELS,
+    LLM_MODEL_CONFIG,
     EMBEDDING_MODEL,
     TEXT_SPLITTER_NAME,
     FSCHAT_CONTROLLER,
@@ -32,13 +33,21 @@ from configs import (
     HTTPX_DEFAULT_TIMEOUT,
 )
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
-                          fschat_openai_api_address, get_httpx_client, get_model_worker_config,
+                          fschat_openai_api_address, get_httpx_client,
+                          get_model_worker_config,
                           MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
 from server.knowledge_base.migrate import create_tables
 import argparse
 from typing import List, Dict
 from configs import VERSION

+all_model_names = set()
+for model_category in LLM_MODEL_CONFIG.values():
+    for model_name in model_category.keys():
+        if model_name not in all_model_names:
+            all_model_names.add(model_name)
+
+all_model_names_list = list(all_model_names)
+
 @deprecated(
     since="0.3.0",
@@ -109,9 +118,9 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
         import fastchat.serve.vllm_worker
         from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
         from vllm import AsyncLLMEngine
-        from vllm.engine.arg_utils import AsyncEngineArgs
+        from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs

-        args.tokenizer = args.model_path
+        args.tokenizer = args.model_path  # 如果tokenizer与model_path不一致在此处添加
         args.tokenizer_mode = 'auto'
         args.trust_remote_code = True
         args.download_dir = None
@@ -130,7 +139,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
         args.conv_template = None
         args.limit_worker_concurrency = 5
         args.no_register = False
-        args.num_gpus = 1  # vllm worker的切分是tensor并行,这里填写显卡的数量
+        args.num_gpus = 4  # vllm worker的切分是tensor并行,这里填写显卡的数量
         args.engine_use_ray = False
         args.disable_log_requests = False

@@ -365,7 +374,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):


 def run_model_worker(
-        model_name: str = LLM_MODELS[0],
+        model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model'])),
         controller_address: str = "",
         log_level: str = "INFO",
         q: mp.Queue = None,
@@ -502,7 +511,7 @@ def parse_args() -> argparse.ArgumentParser:
         "--model-worker",
         action="store_true",
         help="run fastchat's model_worker server with specified model name. "
-             "specify --model-name if not using default LLM_MODELS",
+             "specify --model-name if not using default llm models",
         dest="model_worker",
     )
     parser.add_argument(
@@ -510,7 +519,7 @@ def parse_args() -> argparse.ArgumentParser:
         "--model-name",
         type=str,
         nargs="+",
-        default=LLM_MODELS,
+        default=all_model_names_list,
         help="specify model name for model worker. "
              "add addition names with space seperated to start multiple model workers.",
         dest="model_name",
@@ -574,7 +583,7 @@ def dump_server_info(after_start=False, args=None):
     print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
     print("\n")

-    models = LLM_MODELS
+    models = list(LLM_MODEL_CONFIG['llm_model'].keys())
     if args and args.model_name:
         models = args.model_name

@@ -769,17 +778,17 @@ async def start_main_server():
         if p := processes.get("api"):
             p.start()
             p.name = f"{p.name} ({p.pid})"
-            api_started.wait()
+            api_started.wait()  # 等待api.py启动完成

         if p := processes.get("webui"):
             p.start()
             p.name = f"{p.name} ({p.pid})"
-            webui_started.wait()
+            webui_started.wait()  # 等待webui.py启动完成

         dump_server_info(after_start=True, args=args)

         while True:
-            cmd = queue.get()
+            cmd = queue.get()  # 收到切换模型的消息
             e = manager.Event()
             if isinstance(cmd, list):
                 model_name, cmd, new_model_name = cmd
@@ -877,20 +886,5 @@ if __name__ == "__main__":
     loop = asyncio.new_event_loop()

     asyncio.set_event_loop(loop)

     loop.run_until_complete(start_main_server())
-
-    # 服务启动后接口调用示例:
-    # import openai
-    # openai.api_key = "EMPTY" # Not support yet
-    # openai.api_base = "http://localhost:8888/v1"
-
-    # model = "chatglm3-6b"
-
-    # # create a chat completion
-    # completion = openai.ChatCompletion.create(
-    #   model=model,
-    #   messages=[{"role": "user", "content": "Hello! What is your name?"}]
-    # )
-    # # print the completion
-    # print(completion.choices[0].message.content)
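The name-collection loop added above can also be written as a single set comprehension; behaviour is unchanged, this is only a more compact equivalent of the same logic:

all_model_names_list = list({
    name
    for category in LLM_MODEL_CONFIG.values()
    for name in category.keys()
})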
@@ -29,15 +29,7 @@ def test_server_configs():
     assert len(configs) > 0


-def test_list_search_engines():
-    engines = api.list_search_engines()
-    pprint(engines)
-
-    assert isinstance(engines, list)
-    assert len(engines) > 0
-
-
-@pytest.mark.parametrize("type", ["llm_chat", "agent_chat"])
+@pytest.mark.parametrize("type", ["llm_chat"])
 def test_get_prompt_template(type):
     print(f"prompt template for: {type}")
     template = api.get_prompt_template(type=type)
@@ -85,29 +85,3 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"):
     assert "docs" in data and len(data["docs"]) > 0
     assert response.status_code == 200
-
-
-def test_search_engine_chat(api="/chat/search_engine_chat"):
-    global data
-
-    data["query"] = "室温超导最新进展是什么样?"
-
-    url = f"{api_base_url}{api}"
-    for se in ["bing", "duckduckgo"]:
-        data["search_engine_name"] = se
-        dump_input(data, api + f" by {se}")
-        response = requests.post(url, json=data, stream=True)
-        if se == "bing" and not BING_SUBSCRIPTION_KEY:
-            data = response.json()
-            assert data["code"] == 404
-            assert data["msg"] == f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`"
-
-        print("\n")
-        print("=" * 30 + api + f" by {se} output" + "="*30)
-        for line in response.iter_content(None, decode_unicode=True):
-            data = json.loads(line[6:])
-            if "answer" in data:
-                print(data["answer"], end="", flush=True)
-        assert "docs" in data and len(data["docs"]) > 0
-        pprint(data["docs"])
-        assert response.status_code == 200
@@ -57,14 +57,3 @@ def test_embeddings(worker):
     assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
     assert isinstance(embeddings[0][0], float)
     print("向量长度:", len(embeddings[0]))
-
-
-# @pytest.mark.parametrize("worker", workers)
-# def test_completion(worker):
-#     params = ApiCompletionParams(prompt="五十六个民族")
-
-#     print(f"\completion with {worker} \n")
-
-#     worker_class = get_model_worker_config(worker)["worker_class"]
-#     resp = worker_class().do_completion(params)
-#     pprint(resp)
@@ -6,8 +6,7 @@ from datetime import datetime
 import os
 import re
 import time
-from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES, LLM_MODELS,
-                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
+from configs import (TOOL_CONFIG, LLM_MODEL_CONFIG)
 from server.knowledge_base.utils import LOADER_DICT
 import uuid
 from typing import List, Dict
@@ -55,6 +54,7 @@ def parse_command(text: str, modal: Modal) -> bool:
     /new {session_name}。如果未提供名称,默认为“会话X”
     /del {session_name}。如果未提供名称,在会话数量>1的情况下,删除当前会话。
     /clear {session_name}。如果未提供名称,默认清除当前会话
+    /stop {session_name}。如果未提供名称,默认停止当前会话
     /help。查看命令帮助
     返回值:输入的是命令返回True,否则返回False
     '''
@@ -117,36 +117,38 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             st.write("\n\n".join(cmds))

     with st.sidebar:
-        # 多会话
         conv_names = list(st.session_state["conversation_ids"].keys())
         index = 0

+        tools = list(TOOL_CONFIG.keys())
+        selected_tool_configs = {}
+
+        with st.expander("工具栏"):
+            for tool in tools:
+                is_selected = st.checkbox(tool, value=TOOL_CONFIG[tool]["use"], key=tool)
+                if is_selected:
+                    selected_tool_configs[tool] = TOOL_CONFIG[tool]
+
         if st.session_state.get("cur_conv_name") in conv_names:
             index = conv_names.index(st.session_state.get("cur_conv_name"))
-        conversation_name = st.selectbox("当前会话:", conv_names, index=index)
+        conversation_name = st.selectbox("当前会话", conv_names, index=index)
         chat_box.use_chat_name(conversation_name)
         conversation_id = st.session_state["conversation_ids"][conversation_name]

-        def on_mode_change():
-            mode = st.session_state.dialogue_mode
-            text = f"已切换到 {mode} 模式。"
-            if mode == "知识库问答":
-                cur_kb = st.session_state.get("selected_kb")
-                if cur_kb:
-                    text = f"{text} 当前知识库: `{cur_kb}`。"
-            st.toast(text)
-
-        dialogue_modes = ["LLM 对话",
-                          "知识库问答",
-                          "文件对话",
-                          "搜索引擎问答",
-                          "自定义Agent问答",
-                          ]
-        dialogue_mode = st.selectbox("请选择对话模式:",
-                                     dialogue_modes,
-                                     index=0,
-                                     on_change=on_mode_change,
-                                     key="dialogue_mode",
-                                     )
+        # def on_mode_change():
+        #     mode = st.session_state.dialogue_mode
+        #     text = f"已切换到 {mode} 模式。"
+        #     st.toast(text)
+
+        # dialogue_modes = ["智能对话",
+        #                   "文件对话",
+        #                   ]
+        # dialogue_mode = st.selectbox("请选择对话模式:",
+        #                              dialogue_modes,
+        #                              index=0,
+        #                              on_change=on_mode_change,
+        #                              key="dialogue_mode",
+        #                              )

         def on_llm_change():
             if llm_model:
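The expander above builds selected_tool_configs by copying every TOOL_CONFIG entry whose checkbox is ticked. An illustrative sketch of the resulting shape; the tool names and any keys other than "use" are hypothetical placeholders, not taken from this commit:

# Hypothetical example of what the sidebar may produce; real keys depend on
# how each tool is declared in TOOL_CONFIG.
selected_tool_configs = {
    "search_internet": {"use": True},  # hypothetical tool name
    "calculate": {"use": True},        # hypothetical tool name
}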
@ -164,7 +166,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
|||||||
available_models = []
|
available_models = []
|
||||||
config_models = api.list_config_models()
|
config_models = api.list_config_models()
|
||||||
if not is_lite:
|
if not is_lite:
|
||||||
for k, v in config_models.get("local", {}).items():
|
for k, v in config_models.get("local", {}).items(): # 列出配置了有效本地路径的模型
|
||||||
if (v.get("model_path_exists")
|
if (v.get("model_path_exists")
|
||||||
and k not in running_models):
|
and k not in running_models):
|
||||||
available_models.append(k)
|
available_models.append(k)
|
||||||
@ -177,103 +179,47 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
|||||||
index = llm_models.index(cur_llm_model)
|
index = llm_models.index(cur_llm_model)
|
||||||
else:
|
else:
|
||||||
index = 0
|
index = 0
|
||||||
llm_model = st.selectbox("选择LLM模型:",
|
llm_model = st.selectbox("选择LLM模型",
|
||||||
llm_models,
|
llm_models,
|
||||||
index,
|
index,
|
||||||
format_func=llm_model_format_func,
|
format_func=llm_model_format_func,
|
||||||
on_change=on_llm_change,
|
on_change=on_llm_change,
|
||||||
key="llm_model",
|
key="llm_model",
|
||||||
)
|
)
|
||||||
if (st.session_state.get("prev_llm_model") != llm_model
|
|
||||||
and not is_lite
|
|
||||||
and not llm_model in config_models.get("online", {})
|
|
||||||
and not llm_model in config_models.get("langchain", {})
|
|
||||||
and llm_model not in running_models):
|
|
||||||
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
|
|
||||||
prev_model = st.session_state.get("prev_llm_model")
|
|
||||||
r = api.change_llm_model(prev_model, llm_model)
|
|
||||||
if msg := check_error_msg(r):
|
|
||||||
st.error(msg)
|
|
||||||
elif msg := check_success_msg(r):
|
|
||||||
st.success(msg)
|
|
||||||
st.session_state["prev_llm_model"] = llm_model
|
|
||||||
|
|
||||||
index_prompt = {
|
|
||||||
"LLM 对话": "llm_chat",
|
|
||||||
"自定义Agent问答": "agent_chat",
|
|
||||||
"搜索引擎问答": "search_engine_chat",
|
|
||||||
"知识库问答": "knowledge_base_chat",
|
|
||||||
"文件对话": "knowledge_base_chat",
|
|
||||||
}
|
|
||||||
prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys())
|
|
||||||
prompt_template_name = prompt_templates_kb_list[0]
|
|
||||||
if "prompt_template_select" not in st.session_state:
|
|
||||||
st.session_state.prompt_template_select = prompt_templates_kb_list[0]
|
|
||||||
|
|
||||||
def prompt_change():
|
# 传入后端的内容
|
||||||
text = f"已切换为 {prompt_template_name} 模板。"
|
model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
|
||||||
st.toast(text)
|
|
||||||
|
|
||||||
prompt_template_select = st.selectbox(
|
for key in LLM_MODEL_CONFIG:
|
||||||
"请选择Prompt模板:",
|
if key == 'llm_model':
|
||||||
prompt_templates_kb_list,
|
continue
|
||||||
index=0,
|
if LLM_MODEL_CONFIG[key]:
|
||||||
on_change=prompt_change,
|
first_key = next(iter(LLM_MODEL_CONFIG[key]))
|
||||||
key="prompt_template_select",
|
model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
|
||||||
)
|
|
||||||
prompt_template_name = st.session_state.prompt_template_select
|
|
||||||
temperature = st.slider("Temperature:", 0.0, 2.0, TEMPERATURE, 0.05)
|
|
||||||
history_len = st.number_input("历史对话轮数:", 0, 20, HISTORY_LEN)
|
|
||||||
|
|
||||||
def on_kb_change():
|
if llm_model is not None:
|
||||||
st.toast(f"已加载知识库: {st.session_state.selected_kb}")
|
model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]
|
||||||
|
|
||||||
if dialogue_mode == "知识库问答":
|
print(model_config)
|
||||||
with st.expander("知识库配置", True):
|
files = st.file_uploader("上传附件",
|
||||||
kb_list = api.list_knowledge_bases()
|
type=[i for ls in LOADER_DICT.values() for i in ls],
|
||||||
index = 0
|
accept_multiple_files=True)
|
||||||
if DEFAULT_KNOWLEDGE_BASE in kb_list:
|
|
||||||
index = kb_list.index(DEFAULT_KNOWLEDGE_BASE)
|
|
||||||
selected_kb = st.selectbox(
|
|
||||||
"请选择知识库:",
|
|
||||||
kb_list,
|
|
||||||
index=index,
|
|
||||||
on_change=on_kb_change,
|
|
||||||
key="selected_kb",
|
|
||||||
)
|
|
||||||
kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
|
|
||||||
|
|
||||||
## Bge 模型会超过1
|
|
||||||
score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
|
|
||||||
elif dialogue_mode == "文件对话":
|
|
||||||
with st.expander("文件对话配置", True):
|
|
||||||
files = st.file_uploader("上传知识文件:",
|
|
||||||
[i for ls in LOADER_DICT.values() for i in ls],
|
|
||||||
accept_multiple_files=True,
|
|
||||||
)
|
|
||||||
kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
|
|
||||||
|
|
||||||
## Bge 模型会超过1
|
|
||||||
score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
|
|
||||||
if st.button("开始上传", disabled=len(files) == 0):
|
|
||||||
st.session_state["file_chat_id"] = upload_temp_docs(files, api)
|
|
||||||
elif dialogue_mode == "搜索引擎问答":
|
|
||||||
search_engine_list = api.list_search_engines()
|
|
||||||
if DEFAULT_SEARCH_ENGINE in search_engine_list:
|
|
||||||
index = search_engine_list.index(DEFAULT_SEARCH_ENGINE)
|
|
||||||
else:
|
|
||||||
index = search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0
|
|
||||||
with st.expander("搜索引擎配置", True):
|
|
||||||
search_engine = st.selectbox(
|
|
||||||
label="请选择搜索引擎",
|
|
||||||
options=search_engine_list,
|
|
||||||
index=index,
|
|
||||||
)
|
|
||||||
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)
|
|
||||||
|
|
||||||
|
# if dialogue_mode == "文件对话":
|
||||||
|
# with st.expander("文件对话配置", True):
|
||||||
|
# files = st.file_uploader("上传知识文件:",
|
||||||
|
# [i for ls in LOADER_DICT.values() for i in ls],
|
||||||
|
# accept_multiple_files=True,
|
||||||
|
# )
|
||||||
|
# kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
|
||||||
|
# score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
|
||||||
|
# if st.button("开始上传", disabled=len(files) == 0):
|
||||||
|
# st.session_state["file_chat_id"] = upload_temp_docs(files, api)
|
||||||
# Display chat messages from history on app rerun
|
# Display chat messages from history on app rerun
|
||||||
chat_box.output_messages()
|
|
||||||
|
|
||||||
|
|
||||||
|
chat_box.output_messages()
|
||||||
chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "
|
chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "
|
||||||
|
|
||||||
def on_feedback(
|
def on_feedback(
|
||||||
@ -297,140 +243,76 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
|||||||
if parse_command(text=prompt, modal=modal): # 用户输入自定义命令
|
if parse_command(text=prompt, modal=modal): # 用户输入自定义命令
|
||||||
st.rerun()
|
st.rerun()
|
||||||
else:
|
else:
|
||||||
history = get_messages_history(history_len)
|
history = get_messages_history(
|
||||||
|
model_config["llm_model"][next(iter(model_config["llm_model"]))]["history_len"])
|
||||||
chat_box.user_say(prompt)
|
chat_box.user_say(prompt)
|
||||||
if dialogue_mode == "LLM 对话":
|
|
||||||
chat_box.ai_say("正在思考...")
|
|
||||||
text = ""
|
|
||||||
message_id = ""
|
|
||||||
r = api.chat_chat(prompt,
|
|
||||||
history=history,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
model=llm_model,
|
|
||||||
prompt_name=prompt_template_name,
|
|
||||||
temperature=temperature)
|
|
||||||
for t in r:
|
|
||||||
if error_msg := check_error_msg(t): # check whether error occured
|
|
||||||
st.error(error_msg)
|
|
||||||
break
|
|
||||||
text += t.get("text", "")
|
|
||||||
chat_box.update_msg(text)
|
|
||||||
message_id = t.get("message_id", "")
|
|
||||||
|
|
||||||
|
chat_box.ai_say("正在思考...")
|
||||||
|
text = ""
|
||||||
|
message_id = ""
|
||||||
|
element_index = 0
|
||||||
|
for d in api.chat_chat(query=prompt,
|
||||||
|
history=history,
|
||||||
|
model_config=model_config,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
tool_config=selected_tool_configs,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
d = json.loads(d)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
message_id = d.get("message_id", "")
|
||||||
metadata = {
|
metadata = {
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
}
|
}
|
||||||
chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标
|
if error_msg := check_error_msg(d):
|
||||||
chat_box.show_feedback(**feedback_kwargs,
|
st.error(error_msg)
|
||||||
key=message_id,
|
if chunk := d.get("agent_action"):
|
||||||
on_submit=on_feedback,
|
chat_box.insert_msg(Markdown("...", in_expander=True, title="Tools", state="complete"))
|
||||||
kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
|
element_index = 1
|
||||||
|
formatted_data = {
|
||||||
elif dialogue_mode == "自定义Agent问答":
|
"action": chunk["tool_name"],
|
||||||
if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
|
"action_input": chunk["tool_input"]
|
||||||
chat_box.ai_say([
|
}
|
||||||
f"正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!</span>\n\n\n",
|
formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
|
||||||
Markdown("...", in_expander=True, title="思考过程", state="complete"),
|
text += f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"
|
||||||
|
chat_box.update_msg(text, element_index=element_index, metadata=metadata)
|
||||||
])
|
if chunk := d.get("text"):
|
||||||
else:
|
text += chunk
|
||||||
chat_box.ai_say([
|
chat_box.update_msg(text, element_index=element_index, metadata=metadata)
|
||||||
f"正在思考...",
|
if chunk := d.get("agent_finish"):
|
||||||
Markdown("...", in_expander=True, title="思考过程", state="complete"),
|
element_index = 0
|
||||||
|
text = chunk
|
||||||
])
|
chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
|
||||||
text = ""
|
chat_box.show_feedback(**feedback_kwargs,
|
||||||
ans = ""
|
key=message_id,
|
||||||
for d in api.agent_chat(prompt,
|
on_submit=on_feedback,
|
||||||
history=history,
|
kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
|
||||||
model=llm_model,
|
|
||||||
prompt_name=prompt_template_name,
|
|
||||||
temperature=temperature,
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
d = json.loads(d)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
if error_msg := check_error_msg(d): # check whether error occured
|
|
||||||
st.error(error_msg)
|
|
||||||
if chunk := d.get("answer"):
|
|
||||||
text += chunk
|
|
||||||
chat_box.update_msg(text, element_index=1)
|
|
||||||
if chunk := d.get("final_answer"):
|
|
||||||
ans += chunk
|
|
||||||
chat_box.update_msg(ans, element_index=0)
|
|
||||||
if chunk := d.get("tools"):
|
|
||||||
text += "\n\n".join(d.get("tools", []))
|
|
||||||
chat_box.update_msg(text, element_index=1)
|
|
||||||
chat_box.update_msg(ans, element_index=0, streaming=False)
|
|
||||||
chat_box.update_msg(text, element_index=1, streaming=False)
|
|
||||||
elif dialogue_mode == "知识库问答":
|
|
||||||
chat_box.ai_say([
|
|
||||||
f"正在查询知识库 `{selected_kb}` ...",
|
|
||||||
Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
|
|
||||||
])
|
|
||||||
text = ""
|
|
||||||
for d in api.knowledge_base_chat(prompt,
|
|
||||||
knowledge_base_name=selected_kb,
|
|
||||||
top_k=kb_top_k,
|
|
||||||
score_threshold=score_threshold,
|
|
||||||
history=history,
|
|
||||||
model=llm_model,
|
|
||||||
prompt_name=prompt_template_name,
|
|
||||||
temperature=temperature):
|
|
||||||
if error_msg := check_error_msg(d): # check whether error occured
|
|
||||||
st.error(error_msg)
|
|
||||||
elif chunk := d.get("answer"):
|
|
||||||
text += chunk
|
|
||||||
chat_box.update_msg(text, element_index=0)
|
|
||||||
chat_box.update_msg(text, element_index=0, streaming=False)
|
|
||||||
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
|
|
||||||
elif dialogue_mode == "文件对话":
|
|
||||||
if st.session_state["file_chat_id"] is None:
|
|
||||||
st.error("请先上传文件再进行对话")
|
|
||||||
st.stop()
|
|
||||||
chat_box.ai_say([
|
|
||||||
f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
|
|
||||||
Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
|
|
||||||
])
|
|
||||||
text = ""
|
|
||||||
for d in api.file_chat(prompt,
|
|
||||||
knowledge_id=st.session_state["file_chat_id"],
|
|
||||||
top_k=kb_top_k,
|
|
||||||
score_threshold=score_threshold,
|
|
||||||
history=history,
|
|
||||||
model=llm_model,
|
|
||||||
prompt_name=prompt_template_name,
|
|
||||||
temperature=temperature):
|
|
||||||
if error_msg := check_error_msg(d): # check whether error occured
|
|
||||||
st.error(error_msg)
|
|
||||||
elif chunk := d.get("answer"):
|
|
||||||
text += chunk
|
|
||||||
chat_box.update_msg(text, element_index=0)
|
|
||||||
chat_box.update_msg(text, element_index=0, streaming=False)
|
|
||||||
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
|
|
||||||
elif dialogue_mode == "搜索引擎问答":
|
|
||||||
chat_box.ai_say([
|
|
||||||
f"正在执行 `{search_engine}` 搜索...",
|
|
||||||
Markdown("...", in_expander=True, title="网络搜索结果", state="complete"),
|
|
||||||
])
|
|
||||||
text = ""
|
|
||||||
for d in api.search_engine_chat(prompt,
|
|
||||||
search_engine_name=search_engine,
|
|
||||||
top_k=se_top_k,
|
|
||||||
history=history,
|
|
||||||
model=llm_model,
|
|
||||||
prompt_name=prompt_template_name,
|
|
||||||
temperature=temperature,
|
|
||||||
split_result=se_top_k > 1):
|
|
||||||
if error_msg := check_error_msg(d): # check whether error occured
|
|
||||||
st.error(error_msg)
|
|
||||||
elif chunk := d.get("answer"):
|
|
||||||
text += chunk
|
|
||||||
chat_box.update_msg(text, element_index=0)
|
|
||||||
chat_box.update_msg(text, element_index=0, streaming=False)
|
|
||||||
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
|
|
||||||
|
|
||||||
|
# elif dialogue_mode == "文件对话":
|
||||||
|
# if st.session_state["file_chat_id"] is None:
|
||||||
|
# st.error("请先上传文件再进行对话")
|
||||||
|
# st.stop()
|
||||||
|
# chat_box.ai_say([
|
||||||
|
# f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
|
||||||
|
# Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
|
||||||
|
# ])
|
||||||
|
# text = ""
|
||||||
|
# for d in api.file_chat(prompt,
|
||||||
|
# knowledge_id=st.session_state["file_chat_id"],
|
||||||
|
# top_k=kb_top_k,
|
||||||
|
# score_threshold=score_threshold,
|
||||||
|
# history=history,
|
||||||
|
# model=llm_model,
|
||||||
|
# prompt_name=prompt_template_name,
|
||||||
|
# temperature=temperature):
|
||||||
|
# if error_msg := check_error_msg(d):
|
||||||
|
# st.error(error_msg)
|
||||||
|
# elif chunk := d.get("answer"):
|
||||||
|
# text += chunk
|
||||||
|
# chat_box.update_msg(text, element_index=0)
|
||||||
|
# chat_box.update_msg(text, element_index=0, streaming=False)
|
||||||
|
# chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
|
||||||
if st.session_state.get("need_rerun"):
|
if st.session_state.get("need_rerun"):
|
||||||
st.session_state["need_rerun"] = False
|
st.session_state["need_rerun"] = False
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
|||||||
@ -3,18 +3,15 @@
|
|||||||
|
|
||||||
from typing import *
|
from typing import *
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
# 此处导入的配置为发起请求(如WEBUI)机器上的配置,主要用于为前端设置默认值。分布式部署时可以与服务器上的不同
|
|
||||||
from configs import (
|
from configs import (
|
||||||
EMBEDDING_MODEL,
|
EMBEDDING_MODEL,
|
||||||
DEFAULT_VS_TYPE,
|
DEFAULT_VS_TYPE,
|
||||||
LLM_MODELS,
|
LLM_MODEL_CONFIG,
|
||||||
TEMPERATURE,
|
|
||||||
SCORE_THRESHOLD,
|
SCORE_THRESHOLD,
|
||||||
CHUNK_SIZE,
|
CHUNK_SIZE,
|
||||||
OVERLAP_SIZE,
|
OVERLAP_SIZE,
|
||||||
ZH_TITLE_ENHANCE,
|
ZH_TITLE_ENHANCE,
|
||||||
VECTOR_SEARCH_TOP_K,
|
VECTOR_SEARCH_TOP_K,
|
||||||
SEARCH_ENGINE_TOP_K,
|
|
||||||
HTTPX_DEFAULT_TIMEOUT,
|
HTTPX_DEFAULT_TIMEOUT,
|
||||||
logger, log_verbose,
|
logger, log_verbose,
|
||||||
)
|
)
|
||||||
@ -26,7 +23,6 @@ from io import BytesIO
|
|||||||
from server.utils import set_httpx_config, api_address, get_httpx_client
|
from server.utils import set_httpx_config, api_address, get_httpx_client
|
||||||
|
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
from langchain_core._api import deprecated
|
|
||||||
|
|
||||||
set_httpx_config()
|
set_httpx_config()
|
||||||
|
|
||||||
@ -247,10 +243,6 @@ class ApiRequest:
         response = self.post("/server/configs", **kwargs)
         return self._get_response_value(response, as_json=True)

-    def list_search_engines(self, **kwargs) -> List:
-        response = self.post("/server/list_search_engines", **kwargs)
-        return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"])

     def get_prompt_template(
         self,
         type: str = "llm_chat",
@ -272,10 +264,8 @@ class ApiRequest:
         history_len: int = -1,
         history: List[Dict] = [],
         stream: bool = True,
-        model: str = LLM_MODELS[0],
+        model_config: Dict = None,
-        temperature: float = TEMPERATURE,
+        tool_config: Dict = None,
-        max_tokens: int = None,
-        prompt_name: str = "default",
         **kwargs,
     ):
         '''
@ -287,10 +277,8 @@ class ApiRequest:
             "history_len": history_len,
             "history": history,
             "stream": stream,
-            "model_name": model,
+            "model_config": model_config,
-            "temperature": temperature,
+            "tool_config": tool_config,
-            "max_tokens": max_tokens,
-            "prompt_name": prompt_name,
         }

         # print(f"received input message:")
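For reference, a minimal sketch of a request against the reworked `/chat/chat` endpoint issued directly over HTTP. The top-level field names follow the payload built above; the `query` field, the base URL, and the inner structure of `model_config` / `tool_config` are assumptions, not shown in this hunk:

```python
import httpx

data = {
    "query": "你好",                     # assumed field name (not shown in this hunk)
    "history": [],
    "history_len": -1,
    "stream": True,
    "model_config": {                     # per-model settings; inner keys are assumed
        "chatglm3-6b": {"temperature": 0.9, "max_tokens": 4096, "prompt_name": "default"},
    },
    "tool_config": {},                    # tool settings; structure assumed
}

# stream the answer line by line; 127.0.0.1:7861 is only an assumed API address
with httpx.stream("POST", "http://127.0.0.1:7861/chat/chat", json=data, timeout=None) as resp:
    for line in resp.iter_lines():
        if line:
            print(line)
```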
@ -299,78 +287,6 @@ class ApiRequest:
         response = self.post("/chat/chat", json=data, stream=True, **kwargs)
         return self._httpx_stream2generator(response, as_json=True)

-    @deprecated(
-        since="0.3.0",
-        message="自定义Agent问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
-        removal="0.3.0")
-    def agent_chat(
-            self,
-            query: str,
-            history: List[Dict] = [],
-            stream: bool = True,
-            model: str = LLM_MODELS[0],
-            temperature: float = TEMPERATURE,
-            max_tokens: int = None,
-            prompt_name: str = "default",
-    ):
-        '''
-        对应api.py/chat/agent_chat 接口
-        '''
-        data = {
-            "query": query,
-            "history": history,
-            "stream": stream,
-            "model_name": model,
-            "temperature": temperature,
-            "max_tokens": max_tokens,
-            "prompt_name": prompt_name,
-        }
-
-        # print(f"received input message:")
-        # pprint(data)
-
-        response = self.post("/chat/agent_chat", json=data, stream=True)
-        return self._httpx_stream2generator(response, as_json=True)
-
-    def knowledge_base_chat(
-            self,
-            query: str,
-            knowledge_base_name: str,
-            top_k: int = VECTOR_SEARCH_TOP_K,
-            score_threshold: float = SCORE_THRESHOLD,
-            history: List[Dict] = [],
-            stream: bool = True,
-            model: str = LLM_MODELS[0],
-            temperature: float = TEMPERATURE,
-            max_tokens: int = None,
-            prompt_name: str = "default",
-    ):
-        '''
-        对应api.py/chat/knowledge_base_chat接口
-        '''
-        data = {
-            "query": query,
-            "knowledge_base_name": knowledge_base_name,
-            "top_k": top_k,
-            "score_threshold": score_threshold,
-            "history": history,
-            "stream": stream,
-            "model_name": model,
-            "temperature": temperature,
-            "max_tokens": max_tokens,
-            "prompt_name": prompt_name,
-        }
-
-        # print(f"received input message:")
-        # pprint(data)
-
-        response = self.post(
-            "/chat/knowledge_base_chat",
-            json=data,
-            stream=True,
-        )
-        return self._httpx_stream2generator(response, as_json=True)
-
     def upload_temp_docs(
         self,
         files: List[Union[str, Path, bytes]],
@ -416,8 +332,8 @@ class ApiRequest:
         score_threshold: float = SCORE_THRESHOLD,
         history: List[Dict] = [],
         stream: bool = True,
-        model: str = LLM_MODELS[0],
+        model: str = None,
-        temperature: float = TEMPERATURE,
+        temperature: float = 0.9,
         max_tokens: int = None,
         prompt_name: str = "default",
     ):
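Since `model` now defaults to `None`, callers that relied on `LLM_MODELS[0]` have to pass the model name explicitly. A hedged usage sketch, mirroring the commented-out webui call above; the knowledge id and model name are placeholders, and `api` is assumed to be an `ApiRequest` instance pointing at the running server:

```python
text = ""
for d in api.file_chat("这份文件讲了什么?",
                       knowledge_id="<temp_knowledge_id>",   # placeholder, assumed to come from upload_temp_docs
                       model="chatglm3-6b",                  # must now be passed explicitly
                       temperature=0.9):
    if chunk := d.get("answer"):
        text += chunk
print(text)
```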
@ -444,50 +360,6 @@ class ApiRequest:
         )
         return self._httpx_stream2generator(response, as_json=True)

-    @deprecated(
-        since="0.3.0",
-        message="搜索引擎问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
-        removal="0.3.0"
-    )
-    def search_engine_chat(
-            self,
-            query: str,
-            search_engine_name: str,
-            top_k: int = SEARCH_ENGINE_TOP_K,
-            history: List[Dict] = [],
-            stream: bool = True,
-            model: str = LLM_MODELS[0],
-            temperature: float = TEMPERATURE,
-            max_tokens: int = None,
-            prompt_name: str = "default",
-            split_result: bool = False,
-    ):
-        '''
-        对应api.py/chat/search_engine_chat接口
-        '''
-        data = {
-            "query": query,
-            "search_engine_name": search_engine_name,
-            "top_k": top_k,
-            "history": history,
-            "stream": stream,
-            "model_name": model,
-            "temperature": temperature,
-            "max_tokens": max_tokens,
-            "prompt_name": prompt_name,
-            "split_result": split_result,
-        }
-
-        # print(f"received input message:")
-        # pprint(data)
-
-        response = self.post(
-            "/chat/search_engine_chat",
-            json=data,
-            stream=True,
-        )
-        return self._httpx_stream2generator(response, as_json=True)
-
     # 知识库相关操作

     def list_knowledge_bases(
@ -769,7 +641,7 @@ class ApiRequest:
     def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]:
         '''
         从服务器上获取当前运行的LLM模型。
-        当 local_first=True 时,优先返回运行中的本地模型,否则优先按LLM_MODELS配置顺序返回。
+        当 local_first=True 时,优先返回运行中的本地模型,否则优先按 LLM_MODEL_CONFIG['llm_model']配置顺序返回。
         返回类型为(model_name, is_local_model)
         '''

@ -779,7 +651,7 @@ class ApiRequest:
                 return "", False

             model = ""
-            for m in LLM_MODELS:
+            for m in LLM_MODEL_CONFIG['llm_model']:
                 if m not in running_models:
                     continue
                 is_local = not running_models[m].get("online_api")
@ -789,7 +661,7 @@ class ApiRequest:
                 model = m
                 break

-            if not model:  # LLM_MODELS中配置的模型都不在running_models里
+            if not model:  # LLM_MODEL_CONFIG['llm_model']中配置的模型都不在running_models里
                 model = list(running_models)[0]
                 is_local = not running_models[model].get("online_api")
             return model, is_local
@ -800,7 +672,7 @@ class ApiRequest:
                 return "", False

             model = ""
-            for m in LLM_MODELS:
+            for m in LLM_MODEL_CONFIG['llm_model']:
                 if m not in running_models:
                     continue
                 is_local = not running_models[m].get("online_api")
@ -810,7 +682,7 @@ class ApiRequest:
                 model = m
                 break

-            if not model:  # LLM_MODELS中配置的模型都不在running_models里
+            if not model:  # LLM_MODEL_CONFIG['llm_model']中配置的模型都不在running_models里
                 model = list(running_models)[0]
                 is_local = not running_models[model].get("online_api")
             return model, is_local
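The same selection logic, condensed into a standalone sketch. Here `running_models` stands for the dict the server returns (model name mapped to its config, with an `online_api` key for online models), and the `local_first` filtering step, which falls in the part of the loop elided between the hunks above, is an assumption:

```python
from configs import LLM_MODEL_CONFIG

def pick_default_model(running_models: dict, local_first: bool = True):
    """Return (model_name, is_local_model), preferring config order."""
    model, is_local = "", False
    for m in LLM_MODEL_CONFIG["llm_model"]:            # iterate in config order
        if m not in running_models:
            continue
        is_local = not running_models[m].get("online_api")
        if local_first and not is_local:               # assumed: prefer local models first
            continue
        model = m
        break
    if not model and running_models:                   # none of the configured models is running
        model = list(running_models)[0]
        is_local = not running_models[model].get("online_api")
    return model, is_local
```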
@ -852,15 +724,6 @@ class ApiRequest:
         )
         return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

-    def list_search_engines(self) -> List[str]:
-        '''
-        获取服务器支持的搜索引擎
-        '''
-        response = self.post(
-            "/server/list_search_engines",
-        )
-        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
-
     def stop_llm_model(
         self,
         model_name: str,