Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
Synced 2026-02-02 20:53:13 +08:00

Dev (#2280)

* Fix the bug where max_tokens was not set for Azure
* Rewrite the agent:
  1. Changed the Agent implementation to support multiple parameters; only ChatGLM3-6B and OpenAI GPT-4 remain supported for now, and the remaining models will temporarily lack Agent functionality
  2. Removed agent_chat and merged it into llm_chat
  3. Rewrote most of the tools to fit the new Agent
* Updated the architecture
* Removed web_chat; it is merged in automatically
* Removed the separate chat modes; everything is now controlled by the Agent
* Updated the configuration files
* Updated the configuration templates and prompts
* Fixed a parameter selection bug
This commit is contained in:
parent 808ca227c5
commit 253168a187
@ -1,22 +0,0 @@
from server.utils import get_ChatOpenAI
from configs.model_config import LLM_MODELS, TEMPERATURE
from langchain.chains import LLMChain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)

model = get_ChatOpenAI(model_name=LLM_MODELS[0], temperature=TEMPERATURE)

human_prompt = "{input}"
human_message_template = HumanMessagePromptTemplate.from_template(human_prompt)

chat_prompt = ChatPromptTemplate.from_messages(
    [("human", "我们来玩成语接龙,我先来,生龙活虎"),
     ("ai", "虎头虎脑"),
     ("human", "{input}")])

chain = LLMChain(prompt=chat_prompt, llm=model, verbose=True)
print(chain({"input": "恼羞成怒"}))
@ -1,14 +1,6 @@
|
||||
import os
|
||||
|
||||
# 可以指定一个绝对路径,统一存放所有的Embedding和LLM模型。
|
||||
# 每个模型可以是一个单独的目录,也可以是某个目录下的二级子目录。
|
||||
# 如果模型目录名称和 MODEL_PATH 中的 key 或 value 相同,程序会自动检测加载,无需修改 MODEL_PATH 中的路径。
|
||||
MODEL_ROOT_PATH = ""
|
||||
|
||||
# 选用的 Embedding 名称
|
||||
EMBEDDING_MODEL = "bge-large-zh-v1.5"
|
||||
|
||||
# Embedding 模型运行设备。设为 "auto" 会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
|
||||
EMBEDDING_MODEL = "bge-large-zh-v1.5" # bge-large-zh
|
||||
EMBEDDING_DEVICE = "auto"
|
||||
|
||||
# 选用的reranker模型
|
||||
@ -21,65 +13,170 @@ RERANKER_MAX_LENGTH = 1024
|
||||
EMBEDDING_KEYWORD_FILE = "keywords.txt"
|
||||
EMBEDDING_MODEL_OUTPUT_PATH = "output"
|
||||
|
||||
# 要运行的 LLM 名称,可以包括本地模型和在线模型。列表中本地模型将在启动项目时全部加载。
|
||||
# 列表中第一个模型将作为 API 和 WEBUI 的默认模型。
|
||||
# 在这里,我们使用目前主流的两个离线模型,其中,chatglm3-6b 为默认加载模型。
|
||||
# 如果你的显存不足,可使用 Qwen-1_8B-Chat, 该模型 FP16 仅需 3.8G显存。
|
||||
LLM_MODEL_CONFIG = {
|
||||
# 意图识别不需要输出,模型后台知道就行
|
||||
"preprocess_model": {
|
||||
"zhipu-api": {
|
||||
"temperature": 0.4,
|
||||
"max_tokens": 2048,
|
||||
"history_len": 100,
|
||||
"prompt_name": "default",
|
||||
"callbacks": False
|
||||
},
|
||||
},
|
||||
"llm_model": {
|
||||
"chatglm3-6b": {
|
||||
"temperature": 0.9,
|
||||
"max_tokens": 4096,
|
||||
"history_len": 3,
|
||||
"prompt_name": "default",
|
||||
"callbacks": True
|
||||
},
|
||||
"zhipu-api": {
|
||||
"temperature": 0.9,
|
||||
"max_tokens": 4000,
|
||||
"history_len": 5,
|
||||
"prompt_name": "default",
|
||||
"callbacks": True
|
||||
},
|
||||
"Qwen-1_8B-Chat": {
|
||||
"temperature": 0.4,
|
||||
"max_tokens": 2048,
|
||||
"history_len": 100,
|
||||
"prompt_name": "default",
|
||||
"callbacks": False
|
||||
},
|
||||
},
|
||||
"action_model": {
|
||||
"chatglm3-6b": {
|
||||
"temperature": 0.01,
|
||||
"max_tokens": 4096,
|
||||
"prompt_name": "ChatGLM3",
|
||||
"callbacks": True
|
||||
},
|
||||
"zhipu-api": {
|
||||
"temperature": 0.01,
|
||||
"max_tokens": 2096,
|
||||
"history_len": 5,
|
||||
"prompt_name": "ChatGLM3",
|
||||
"callbacks": True
|
||||
},
|
||||
},
|
||||
"postprocess_model": {
|
||||
"chatglm3-6b": {
|
||||
"temperature": 0.01,
|
||||
"max_tokens": 4096,
|
||||
"prompt_name": "default",
|
||||
"callbacks": True
|
||||
}
|
||||
},
|
||||
|
||||
LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"]
|
||||
Agent_MODEL = None
|
||||
}
|
||||
|
||||
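A sketch of how a caller might pick per-role settings out of this structure (the `get_model_settings` helper is hypothetical and only illustrates the lookup; the role and model keys come from the dictionary above):

```python
# Sketch only: reading the per-role model settings configured above.
# get_model_settings is a hypothetical helper, not part of this commit.
from configs.model_config import LLM_MODEL_CONFIG

def get_model_settings(role: str, model_name: str) -> dict:
    return LLM_MODEL_CONFIG[role][model_name]

settings = get_model_settings("action_model", "chatglm3-6b")
print(settings["temperature"], settings["prompt_name"])  # 0.01 ChatGLM3
```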
TOOL_CONFIG = {
|
||||
"search_local_knowledgebase": {
|
||||
"use": True,
|
||||
"top_k": 10,
|
||||
"score_threshold": 1,
|
||||
"conclude_prompt": {
|
||||
"with_result":
|
||||
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题",'
|
||||
'不允许在答案中添加编造成分,答案请使用中文。 </指令>\n'
|
||||
'<已知信息>{{ context }}</已知信息>\n'
|
||||
'<问题>{{ question }}</问题>\n',
|
||||
"without_result":
|
||||
'请你根据我的提问回答我的问题:\n'
|
||||
'{{ question }}\n'
|
||||
'请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n',
|
||||
}
|
||||
},
|
||||
"search_internet": {
|
||||
"use": True,
|
||||
"search_engine_name": "bing",
|
||||
"search_engine_config":
|
||||
{
|
||||
"bing": {
|
||||
"result_len": 3,
|
||||
"bing_search_url": "https://api.bing.microsoft.com/v7.0/search",
|
||||
"bing_key": "",
|
||||
},
|
||||
"metaphor": {
|
||||
"result_len": 3,
|
||||
"metaphor_api_key": "",
|
||||
"split_result": False,
|
||||
"chunk_size": 500,
|
||||
"chunk_overlap": 0,
|
||||
},
|
||||
"duckduckgo": {
|
||||
"result_len": 3
|
||||
}
|
||||
},
|
||||
"top_k": 10,
|
||||
"verbose": "Origin",
|
||||
"conclude_prompt":
|
||||
"<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 "
|
||||
"</指令>\n<已知信息>{{ context }}</已知信息>\n"
|
||||
"<问题>\n"
|
||||
"{{ question }}\n"
|
||||
"</问题>\n"
|
||||
},
|
||||
"arxiv": {
|
||||
"use": True,
|
||||
},
|
||||
"shell": {
|
||||
"use": True,
|
||||
},
|
||||
"weather_check": {
|
||||
"use": True,
|
||||
"api-key": "",
|
||||
},
|
||||
"search_youtube": {
|
||||
"use": False,
|
||||
},
|
||||
"wolfram": {
|
||||
"use": False,
|
||||
},
|
||||
"calculate": {
|
||||
"use": False,
|
||||
},
|
||||
|
||||
}
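Each tool reads its own section of TOOL_CONFIG at call time. A minimal sketch of that lookup, mirroring the search_internet tool added later in this diff (the `_resolve_search_config` helper is illustrative only):

```python
# Sketch: how a tool resolves its settings from TOOL_CONFIG at call time
# (mirrors server/agent/tools_factory/search_internet.py in this commit).
from configs import TOOL_CONFIG

def _resolve_search_config() -> dict:
    tool_config = TOOL_CONFIG["search_internet"]
    engine_name = tool_config["search_engine_name"]  # "bing" by default
    # per-engine settings, e.g. {"result_len": 3, "bing_search_url": ..., "bing_key": ...}
    return tool_config["search_engine_config"][engine_name]
```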
|
||||
|
||||
# LLM 模型运行设备。设为"auto"会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
|
||||
LLM_DEVICE = "auto"
|
||||
|
||||
HISTORY_LEN = 3
|
||||
|
||||
MAX_TOKENS = 2048
|
||||
|
||||
TEMPERATURE = 0.7
|
||||
|
||||
ONLINE_LLM_MODEL = {
|
||||
"openai-api": {
|
||||
"model_name": "gpt-4",
|
||||
"model_name": "gpt-4-1106-preview",
|
||||
"api_base_url": "https://api.openai.com/v1",
|
||||
"api_key": "",
|
||||
"openai_proxy": "",
|
||||
},
|
||||
|
||||
# 智谱AI API,具体注册及api key获取请前往 http://open.bigmodel.cn
|
||||
"zhipu-api": {
|
||||
"api_key": "",
|
||||
"version": "glm-4",
|
||||
"version": "chatglm_turbo",
|
||||
"provider": "ChatGLMWorker",
|
||||
},
|
||||
|
||||
# 具体注册及api key获取请前往 https://api.minimax.chat/
|
||||
"minimax-api": {
|
||||
"group_id": "",
|
||||
"api_key": "",
|
||||
"is_pro": False,
|
||||
"provider": "MiniMaxWorker",
|
||||
},
|
||||
|
||||
# 具体注册及api key获取请前往 https://xinghuo.xfyun.cn/
|
||||
"xinghuo-api": {
|
||||
"APPID": "",
|
||||
"APISecret": "",
|
||||
"api_key": "",
|
||||
"version": "v3.5", # 你使用的讯飞星火大模型版本,可选包括 "v3.5","v3.0", "v2.0", "v1.5"
|
||||
"version": "v3.0",
|
||||
"provider": "XingHuoWorker",
|
||||
},
|
||||
|
||||
# 百度千帆 API,申请方式请参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf
|
||||
"qianfan-api": {
|
||||
"version": "ERNIE-Bot", # 注意大小写。当前支持 "ERNIE-Bot" 或 "ERNIE-Bot-turbo", 更多的见官方文档。
|
||||
"version_url": "", # 也可以不填写version,直接填写在千帆申请模型发布的API地址
|
||||
"version": "ernie-bot-4",
|
||||
"version_url": "",
|
||||
"api_key": "",
|
||||
"secret_key": "",
|
||||
"provider": "QianFanWorker",
|
||||
},
|
||||
|
||||
# 火山方舟 API,文档参考 https://www.volcengine.com/docs/82379
|
||||
"fangzhou-api": {
|
||||
"version": "chatglm-6b-model",
|
||||
"version_url": "",
|
||||
@ -87,28 +184,22 @@ ONLINE_LLM_MODEL = {
|
||||
"secret_key": "",
|
||||
"provider": "FangZhouWorker",
|
||||
},
|
||||
|
||||
# 阿里云通义千问 API,文档参考 https://help.aliyun.com/zh/dashscope/developer-reference/api-details
|
||||
"qwen-api": {
|
||||
"version": "qwen-max",
|
||||
"api_key": "",
|
||||
"provider": "QwenWorker",
|
||||
"embed_model": "text-embedding-v1" # embedding 模型名称
|
||||
},
|
||||
|
||||
# 百川 API,申请方式请参考 https://www.baichuan-ai.com/home#api-enter
|
||||
"baichuan-api": {
|
||||
"version": "Baichuan2-53B",
|
||||
"api_key": "",
|
||||
"secret_key": "",
|
||||
"provider": "BaiChuanWorker",
|
||||
},
|
||||
|
||||
# Azure API
|
||||
"azure-api": {
|
||||
"deployment_name": "", # 部署容器的名字
|
||||
"resource_name": "", # https://{resource_name}.openai.azure.com/openai/ 填写resource_name的部分,其他部分不要填写
|
||||
"api_version": "", # API的版本,不是模型版本
|
||||
"deployment_name": "",
|
||||
"resource_name": "",
|
||||
"api_version": "2023-07-01-preview",
|
||||
"api_key": "",
|
||||
"provider": "AzureWorker",
|
||||
},
|
||||
@ -153,50 +244,33 @@ MODEL_PATH = {
|
||||
|
||||
"bge-small-zh": "BAAI/bge-small-zh",
|
||||
"bge-base-zh": "BAAI/bge-base-zh",
|
||||
"bge-large-zh": "BAAI/bge-large-zh",
|
||||
"bge-large-zh": "/media/zr/Data/Models/Embedding/bge-large-zh",
|
||||
"bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
|
||||
"bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5",
|
||||
"bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5",
|
||||
|
||||
"bge-m3": "BAAI/bge-m3",
|
||||
|
||||
"bge-large-zh-v1.5": "/Models/bge-large-zh-v1.5",
|
||||
"piccolo-base-zh": "sensenova/piccolo-base-zh",
|
||||
"piccolo-large-zh": "sensenova/piccolo-large-zh",
|
||||
"nlp_gte_sentence-embedding_chinese-large": "damo/nlp_gte_sentence-embedding_chinese-large",
|
||||
"text-embedding-ada-002": "your OPENAI_API_KEY",
|
||||
"nlp_gte_sentence-embedding_chinese-large": "/Models/nlp_gte_sentence-embedding_chinese-large",
|
||||
"text-embedding-ada-002": "Just write your OpenAI key like "sk-o3IGBhC9g8AiFvTGWVKsT*****" ",
|
||||
},
|
||||
|
||||
"llm_model": {
|
||||
"chatglm2-6b": "THUDM/chatglm2-6b",
|
||||
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
|
||||
"chatglm3-6b": "THUDM/chatglm3-6b",
|
||||
"chatglm3-6b": "/Models/chatglm3-6b",
|
||||
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
|
||||
|
||||
"Orion-14B-Chat": "OrionStarAI/Orion-14B-Chat",
|
||||
"Orion-14B-Chat-Plugin": "OrionStarAI/Orion-14B-Chat-Plugin",
|
||||
"Orion-14B-LongChat": "OrionStarAI/Orion-14B-LongChat",
|
||||
"Yi-34B-Chat": "/data/share/models/Yi-34B-Chat",
|
||||
"BlueLM-7B-Chat": "/Models/BlueLM-7B-Chat",
|
||||
|
||||
"Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
|
||||
"Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
|
||||
"baichuan2-13b": "/media/zr/Data/Models/LLM/Baichuan2-13B-Chat",
|
||||
"baichuan2-7b": "/media/zr/Data/Models/LLM/Baichuan2-7B-Chat",
|
||||
"baichuan-7b": "baichuan-inc/Baichuan-7B",
|
||||
"baichuan-13b": "baichuan-inc/Baichuan-13B",
|
||||
'baichuan-13b-chat': 'baichuan-inc/Baichuan-13B-Chat',
|
||||
|
||||
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
|
||||
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
|
||||
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
|
||||
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
|
||||
|
||||
# Qwen1.5 模型 VLLM可能出现问题
|
||||
"Qwen1.5-0.5B-Chat": "Qwen/Qwen1.5-0.5B-Chat",
|
||||
"Qwen1.5-1.8B-Chat": "Qwen/Qwen1.5-1.8B-Chat",
|
||||
"Qwen1.5-4B-Chat": "Qwen/Qwen1.5-4B-Chat",
|
||||
"Qwen1.5-7B-Chat": "Qwen/Qwen1.5-7B-Chat",
|
||||
"Qwen1.5-14B-Chat": "Qwen/Qwen1.5-14B-Chat",
|
||||
"Qwen1.5-72B-Chat": "Qwen/Qwen1.5-72B-Chat",
|
||||
|
||||
"baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
|
||||
"baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
|
||||
"baichuan2-7b-chat": "baichuan-inc/Baichuan2-7B-Chat",
|
||||
"baichuan2-13b-chat": "baichuan-inc/Baichuan2-13B-Chat",
|
||||
"aquila-7b": "BAAI/Aquila-7B",
|
||||
"aquilachat-7b": "BAAI/AquilaChat-7B",
|
||||
|
||||
"internlm-7b": "internlm/internlm-7b",
|
||||
"internlm-chat-7b": "internlm/internlm-chat-7b",
|
||||
@ -235,42 +309,43 @@ MODEL_PATH = {
|
||||
"oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
|
||||
"dolly-v2-12b": "databricks/dolly-v2-12b",
|
||||
"stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b",
|
||||
|
||||
"Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf",
|
||||
"Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
|
||||
"open_llama_13b": "openlm-research/open_llama_13b",
|
||||
"vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3",
|
||||
"koala": "young-geng/koala",
|
||||
|
||||
"mpt-7b": "mosaicml/mpt-7b",
|
||||
"mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter",
|
||||
"mpt-30b": "mosaicml/mpt-30b",
|
||||
"opt-66b": "facebook/opt-66b",
|
||||
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
|
||||
|
||||
"Qwen-1_8B-Chat":"Qwen/Qwen-1_8B-Chat"
|
||||
"Qwen-7B": "Qwen/Qwen-7B",
|
||||
"Qwen-14B": "Qwen/Qwen-14B",
|
||||
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
|
||||
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
|
||||
"Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8", # 确保已经安装了auto-gptq optimum flash-attn
|
||||
"Qwen-14B-Chat-Int4": "/media/zr/Data/Models/LLM/Qwen-14B-Chat-Int4", # 确保已经安装了auto-gptq optimum flash-attn
|
||||
},
|
||||
|
||||
"reranker": {
|
||||
"bge-reranker-large": "BAAI/bge-reranker-large",
|
||||
"bge-reranker-base": "BAAI/bge-reranker-base",
|
||||
}
|
||||
}
|
||||
|
||||
# 通常情况下不需要更改以下内容
|
||||
|
||||
# nltk 模型存储路径
|
||||
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
|
||||
|
||||
# 使用VLLM可能导致模型推理能力下降,无法完成Agent任务
|
||||
VLLM_MODEL_DICT = {
|
||||
"chatglm2-6b": "THUDM/chatglm2-6b",
|
||||
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
|
||||
"chatglm3-6b": "THUDM/chatglm3-6b",
|
||||
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
|
||||
"aquila-7b": "BAAI/Aquila-7B",
|
||||
"aquilachat-7b": "BAAI/AquilaChat-7B",
|
||||
|
||||
"Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
|
||||
"Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
|
||||
"baichuan-7b": "baichuan-inc/Baichuan-7B",
|
||||
"baichuan-13b": "baichuan-inc/Baichuan-13B",
|
||||
'baichuan-13b-chat': 'baichuan-inc/Baichuan-13B-Chat',
|
||||
|
||||
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
|
||||
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
|
||||
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
|
||||
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
|
||||
|
||||
"baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
|
||||
"baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
|
||||
"baichuan2-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
|
||||
"baichuan2-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
|
||||
|
||||
"BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat",
|
||||
"BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k",
|
||||
'chatglm2-6b': 'THUDM/chatglm2-6b',
|
||||
'chatglm2-6b-32k': 'THUDM/chatglm2-6b-32k',
|
||||
'chatglm3-6b': 'THUDM/chatglm3-6b',
|
||||
'chatglm3-6b-32k': 'THUDM/chatglm3-6b-32k',
|
||||
|
||||
"internlm-7b": "internlm/internlm-7b",
|
||||
"internlm-chat-7b": "internlm/internlm-chat-7b",
|
||||
@ -301,14 +376,13 @@ VLLM_MODEL_DICT = {
|
||||
"opt-66b": "facebook/opt-66b",
|
||||
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
|
||||
|
||||
}
|
||||
"Qwen-7B": "Qwen/Qwen-7B",
|
||||
"Qwen-14B": "Qwen/Qwen-14B",
|
||||
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
|
||||
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
|
||||
|
||||
SUPPORT_AGENT_MODEL = [
|
||||
"openai-api", # GPT4 模型
|
||||
"qwen-api", # Qwen Max模型
|
||||
"zhipu-api", # 智谱AI GLM4模型
|
||||
"Qwen", # 所有Qwen系列本地模型
|
||||
"chatglm3-6b",
|
||||
"internlm2-chat-20b",
|
||||
"Orion-14B-Chat-Plugin",
|
||||
]
|
||||
"agentlm-7b": "THUDM/agentlm-7b",
|
||||
"agentlm-13b": "THUDM/agentlm-13b",
|
||||
"agentlm-70b": "THUDM/agentlm-70b",
|
||||
|
||||
}
|
||||
|
||||
@ -1,26 +1,19 @@
# prompt模板使用Jinja2语法,简单点就是用双大括号代替f-string的单大括号
# 本配置文件支持热加载,修改prompt模板后无需重启服务。

# LLM对话支持的变量:
# - input: 用户输入内容

# 知识库和搜索引擎对话支持的变量:
# - context: 从检索结果拼接的知识文本
# - question: 用户提出的问题

# Agent对话支持的变量:
# - tools: 可用的工具列表
# - tool_names: 可用的工具名称列表
# - history: 用户和Agent的对话历史
# - input: 用户输入内容
# - agent_scratchpad: Agent的思维记录
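Because these templates use Jinja2 double braces rather than f-string braces, they can be rendered with LangChain's PromptTemplate by passing template_format="jinja2". A minimal sketch (illustrative only; assumes the jinja2 package is installed, and the template text is an abbreviated version of the knowledge-base prompt below):

```python
# Sketch only: rendering a Jinja2-style template as used in this config file.
from langchain.prompts import PromptTemplate

template = (
    '<指令>根据已知信息,简洁和专业的来回答问题。</指令>\n'
    '<已知信息>{{ context }}</已知信息>\n'
    '<问题>{{ question }}</问题>\n'
)

prompt = PromptTemplate.from_template(template, template_format="jinja2")
print(prompt.format(context="...", question="大数据的就业情况如何"))
```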
PROMPT_TEMPLATES = {
|
||||
"llm_chat": {
|
||||
"preprocess_model": {
|
||||
"default":
|
||||
'请你根据我的描述和我们对话的历史,来判断本次跟我交流是否需要使用工具,还是可以直接凭借你的知识或者历史记录跟我对话。你只要回答一个数字。1 或者 0,1代表需要使用工具,0代表不需要使用工具。\n'
|
||||
'以下几种情况要使用工具,请返回1\n'
|
||||
'1. 实时性的问题,例如天气,日期,地点等信息\n'
|
||||
'2. 需要数学计算的问题\n'
|
||||
'3. 需要查询数据,地点等精确数据\n'
|
||||
'4. 需要行业知识的问题\n'
|
||||
'<question>'
|
||||
'{input}'
|
||||
'</question>'
|
||||
},
|
||||
"llm_model": {
|
||||
"default":
|
||||
'{{ input }}',
|
||||
|
||||
"with_history":
|
||||
'The following is a friendly conversation between a human and an AI. '
|
||||
'The AI is talkative and provides lots of specific details from its context. '
|
||||
@ -29,72 +22,42 @@ PROMPT_TEMPLATES = {
|
||||
'{history}\n'
|
||||
'Human: {input}\n'
|
||||
'AI:',
|
||||
|
||||
"py":
|
||||
'你是一个聪明的代码助手,请你给我写出简单的py代码。 \n'
|
||||
'{{ input }}',
|
||||
},
|
||||
|
||||
|
||||
"knowledge_base_chat": {
|
||||
"default":
|
||||
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,'
|
||||
'不允许在答案中添加编造成分,答案请使用中文。 </指令>\n'
|
||||
'<已知信息>{{ context }}</已知信息>\n'
|
||||
'<问题>{{ question }}</问题>\n',
|
||||
|
||||
"text":
|
||||
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>\n'
|
||||
'<已知信息>{{ context }}</已知信息>\n'
|
||||
'<问题>{{ question }}</问题>\n',
|
||||
|
||||
"empty": # 搜不到知识库的时候使用
|
||||
'请你回答我的问题:\n'
|
||||
'{{ question }}\n\n',
|
||||
},
|
||||
|
||||
|
||||
"search_engine_chat": {
|
||||
"default":
|
||||
'<指令>这是我搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。'
|
||||
'如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 </指令>\n'
|
||||
'<已知信息>{{ context }}</已知信息>\n'
|
||||
'<问题>{{ question }}</问题>\n',
|
||||
|
||||
"search":
|
||||
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>\n'
|
||||
'<已知信息>{{ context }}</已知信息>\n'
|
||||
'<问题>{{ question }}</问题>\n',
|
||||
},
|
||||
|
||||
|
||||
"agent_chat": {
|
||||
"default":
|
||||
'Answer the following questions as best you can. If it is in order, you can use some tools appropriately. '
|
||||
'You have access to the following tools:\n\n'
|
||||
'{tools}\n\n'
|
||||
'Use the following format:\n'
|
||||
'Question: the input question you must answer\n'
|
||||
'Thought: you should always think about what to do and what tools to use.\n'
|
||||
'Action: the action to take, should be one of [{tool_names}]\n'
|
||||
'Action Input: the input to the action\n'
|
||||
"action_model": {
|
||||
"GPT-4":
|
||||
'Answer the following questions as best you can. You have access to the following tools:\n'
|
||||
'The way you use the tools is by specifying a json blob.\n'
|
||||
'Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here).\n'
|
||||
'The only values that should be in the "action" field are: {tool_names}\n'
|
||||
'The $JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB:\n'
|
||||
'```\n\n'
|
||||
'{{{{\n'
|
||||
' "action": $TOOL_NAME,\n'
|
||||
' "action_input": $INPUT\n'
|
||||
'}}}}\n'
|
||||
'```\n\n'
|
||||
'ALWAYS use the following format:\n'
|
||||
'Question: the input question you must answer\n'
|
||||
'Thought: you should always think about what to do\n'
|
||||
'Action:\n'
|
||||
'```\n\n'
|
||||
'$JSON_BLOB'
|
||||
'```\n\n'
|
||||
'Observation: the result of the action\n'
|
||||
'... (this Thought/Action/Action Input/Observation can be repeated zero or more times)\n'
|
||||
'... (this Thought/Action/Observation can repeat N times)\n'
|
||||
'Thought: I now know the final answer\n'
|
||||
'Final Answer: the final answer to the original input question\n'
|
||||
'Begin!\n\n'
|
||||
'history: {history}\n\n'
|
||||
'Question: {input}\n\n'
|
||||
'Thought: {agent_scratchpad}\n',
|
||||
|
||||
'Begin! Reminder to always use the exact characters `Final Answer` when responding.\n'
|
||||
'history: {history}\n'
|
||||
'Question:{input}\n'
|
||||
'Thought:{agent_scratchpad}\n',
|
||||
"ChatGLM3":
|
||||
'You can answer using the tools, or answer directly using your knowledge without using the tools. '
|
||||
'Respond to the human as helpfully and accurately as possible.\n'
|
||||
'You can answer using the tools.Respond to the human as helpfully and accurately as possible.\n'
|
||||
'You have access to the following tools:\n'
|
||||
'{tools}\n'
|
||||
'Use a json blob to specify a tool by providing an action key (tool name) '
|
||||
'Use a json blob to specify a tool by providing an action key (tool name)\n'
|
||||
'and an action_input key (tool input).\n'
|
||||
'Valid "action" values: "Final Answer" or [{tool_names}]'
|
||||
'Valid "action" values: "Final Answer" or [{tool_names}]\n'
|
||||
'Provide only ONE action per $JSON_BLOB, as shown:\n\n'
|
||||
'```\n'
|
||||
'{{{{\n'
|
||||
@ -118,10 +81,13 @@ PROMPT_TEMPLATES = {
|
||||
' "action": "Final Answer",\n'
|
||||
' "action_input": "Final response to human"\n'
|
||||
'}}}}\n'
|
||||
'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. '
|
||||
'Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary.\n'
|
||||
'Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.\n'
|
||||
'history: {history}\n\n'
|
||||
'Question: {input}\n\n'
|
||||
'Thought: {agent_scratchpad}',
|
||||
'Thought: {agent_scratchpad}\n',
|
||||
},
|
||||
"postprocess_model": {
|
||||
"default": "{{input}}",
|
||||
}
|
||||
}
|
||||
|
||||
@ -57,7 +57,7 @@ FSCHAT_MODEL_WORKERS = {
|
||||
# "awq_ckpt": None,
|
||||
# "awq_wbits": 16,
|
||||
# "awq_groupsize": -1,
|
||||
# "model_names": LLM_MODELS,
|
||||
# "model_names": None,
|
||||
# "conv_template": None,
|
||||
# "limit_worker_concurrency": 5,
|
||||
# "stream_interval": 2,
|
||||
|
||||
@ -1,29 +0,0 @@
# 实现基于ES的数据插入、检索、删除、更新
```shell
author: 唐国梁Tommy
e-mail: flytang186@qq.com

如果遇到任何问题,可以与我联系,我这边部署后服务是没有问题的。
```

## 第1步:ES docker部署
```shell
docker network create elastic
docker run -id --name elasticsearch --net elastic -p 9200:9200 -p 9300:9300 -e "discovery.type=single-node" -e "xpack.security.enabled=false" -e "xpack.security.http.ssl.enabled=false" -t docker.elastic.co/elasticsearch/elasticsearch:8.8.2
```

### 第2步:Kibana docker部署
**注意:Kibana版本与ES保持一致**
```shell
docker pull docker.elastic.co/kibana/kibana:{version}
docker run --name kibana --net elastic -p 5601:5601 docker.elastic.co/kibana/kibana:{version}
```

### 第3步:核心代码
```shell
1. 核心代码路径
server/knowledge_base/kb_service/es_kb_service.py

2. 需要在 configs/model_config.py 中 配置 ES参数(IP, PORT)等;
```
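As a quick sanity check after the Docker steps above, the single-node instance can be queried from Python before wiring it into es_kb_service.py. A sketch (assumes the elasticsearch client package is installed and security is disabled, as in the docker run command above):

```python
# Sketch: verify the ES container from step 1 is reachable.
# Host and port are the defaults exposed by the docker run command above.
from elasticsearch import Elasticsearch

es = Elasticsearch("http://127.0.0.1:9200")
print(es.info())  # prints cluster name / version metadata if the node is healthy
```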
@ -1,11 +1,7 @@
|
||||
torch==2.1.2
|
||||
torchvision==0.16.2
|
||||
torchaudio==2.1.2
|
||||
xformers==0.0.23.post1
|
||||
transformers==4.37.2
|
||||
sentence_transformers==2.2.2
|
||||
langchain==0.0.354
|
||||
langchain-experimental==0.0.47
|
||||
# API requirements
|
||||
|
||||
langchain>=0.0.346
|
||||
langchain-experimental>=0.0.42
|
||||
pydantic==1.10.13
|
||||
fschat==0.2.35
|
||||
openai==1.9.0
|
||||
|
||||
@ -1,11 +1,7 @@
|
||||
torch~=2.1.2
|
||||
torchvision~=0.16.2
|
||||
torchaudio~=2.1.2
|
||||
xformers>=0.0.23.post1
|
||||
transformers==4.37.2
|
||||
sentence_transformers==2.2.2
|
||||
langchain==0.0.354
|
||||
langchain-experimental==0.0.47
|
||||
# API requirements
|
||||
|
||||
langchain>=0.0.346
|
||||
langchain-experimental>=0.0.42
|
||||
pydantic==1.10.13
|
||||
fschat==0.2.35
|
||||
openai~=1.9.0
|
||||
|
||||
@ -1,4 +0,0 @@
|
||||
from .model_contain import *
|
||||
from .callbacks import *
|
||||
from .custom_template import *
|
||||
from .tools import *
|
||||
1
server/agent/agent_factory/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .glm3_agent import initialize_glm3_agent
|
||||
@ -3,6 +3,14 @@ This file is a modified version for ChatGLM3-6B the original glm3_agent.py file
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import yaml
|
||||
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
|
||||
from langchain.memory import ConversationBufferWindowMemory
|
||||
from typing import Any, List, Sequence, Tuple, Optional, Union
|
||||
import os
|
||||
from langchain.agents.agent import Agent
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, List, Sequence, Tuple, Optional, Union
|
||||
@ -22,6 +30,7 @@ from langchain.agents.agent import AgentExecutor
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools.base import BaseTool
|
||||
from pydantic.schema import model_schema
|
||||
|
||||
HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -42,8 +51,9 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
|
||||
if "tool_call" in text:
|
||||
action_end = text.find("```")
|
||||
action = text[:action_end].strip()
|
||||
|
||||
params_str_start = text.find("(") + 1
|
||||
params_str_end = text.rfind(")")
|
||||
params_str_end = text.find(")")
|
||||
params_str = text[params_str_start:params_str_end]
|
||||
|
||||
params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param]
|
||||
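For reference, a sketch of the slicing this parser performs on a ChatGLM3-style tool call (the sample text is hypothetical and only illustrates the shape the parser expects):

```python
# Illustrative only: slicing a ChatGLM3-style tool call the way the parser above does.
text = 'weather_check\n```python\ntool_call(location="北京")\n```'

action = text[:text.find("```")].strip()               # "weather_check"
params_str = text[text.find("(") + 1:text.rfind(")")]  # 'location="北京"'
params_pairs = [p.split("=") for p in params_str.split(",") if "=" in p]
print(action, dict(params_pairs))  # weather_check {'location': '"北京"'}
```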
@ -225,4 +235,4 @@ def initialize_glm3_agent(
|
||||
memory=memory,
|
||||
tags=tags_,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
@ -1,67 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from langchain.agents import Tool, AgentOutputParser
|
||||
from langchain.prompts import StringPromptTemplate
|
||||
from typing import List
|
||||
from langchain.schema import AgentAction, AgentFinish
|
||||
|
||||
from configs import SUPPORT_AGENT_MODEL
|
||||
from server.agent import model_container
|
||||
class CustomPromptTemplate(StringPromptTemplate):
|
||||
template: str
|
||||
tools: List[Tool]
|
||||
|
||||
def format(self, **kwargs) -> str:
|
||||
intermediate_steps = kwargs.pop("intermediate_steps")
|
||||
thoughts = ""
|
||||
for action, observation in intermediate_steps:
|
||||
thoughts += action.log
|
||||
thoughts += f"\nObservation: {observation}\nThought: "
|
||||
kwargs["agent_scratchpad"] = thoughts
|
||||
kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
|
||||
kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
|
||||
return self.template.format(**kwargs)
|
||||
|
||||
class CustomOutputParser(AgentOutputParser):
|
||||
begin: bool = False
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.begin = True
|
||||
|
||||
def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
|
||||
if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin:
|
||||
self.begin = False
|
||||
stop_words = ["Observation:"]
|
||||
min_index = len(llm_output)
|
||||
for stop_word in stop_words:
|
||||
index = llm_output.find(stop_word)
|
||||
if index != -1 and index < min_index:
|
||||
min_index = index
|
||||
llm_output = llm_output[:min_index]
|
||||
|
||||
if "Final Answer:" in llm_output:
|
||||
self.begin = True
|
||||
return AgentFinish(
|
||||
return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()},
|
||||
log=llm_output,
|
||||
)
|
||||
parts = llm_output.split("Action:")
|
||||
if len(parts) < 2:
|
||||
return AgentFinish(
|
||||
return_values={"output": f"调用agent工具失败,该回答为大模型自身能力的回答:\n\n `{llm_output}`"},
|
||||
log=llm_output,
|
||||
)
|
||||
|
||||
action = parts[1].split("Action Input:")[0].strip()
|
||||
action_input = parts[1].split("Action Input:")[1].strip()
|
||||
try:
|
||||
ans = AgentAction(
|
||||
tool=action,
|
||||
tool_input=action_input.strip(" ").strip('"'),
|
||||
log=llm_output
|
||||
)
|
||||
return ans
|
||||
except:
|
||||
return AgentFinish(
|
||||
return_values={"output": f"调用agent失败: `{llm_output}`"},
|
||||
log=llm_output,
|
||||
)
|
||||
@ -1,6 +0,0 @@
|
||||
class ModelContainer:
|
||||
def __init__(self):
|
||||
self.MODEL = None
|
||||
self.DATABASE = None
|
||||
|
||||
model_container = ModelContainer()
|
||||
@ -1,11 +0,0 @@
|
||||
## 导入所有的工具类
|
||||
from .search_knowledgebase_simple import search_knowledgebase_simple
|
||||
from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput
|
||||
from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput
|
||||
from .calculate import calculate, CalculatorInput
|
||||
from .weather_check import weathercheck, WeatherInput
|
||||
from .shell import shell, ShellInput
|
||||
from .search_internet import search_internet, SearchInternetInput
|
||||
from .wolfram import wolfram, WolframInput
|
||||
from .search_youtube import search_youtube, YoutubeInput
|
||||
from .arxiv import arxiv, ArxivInput
|
||||
@ -1,76 +0,0 @@
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.chains import LLMMathChain
|
||||
from server.agent import model_container
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
_PROMPT_TEMPLATE = """
|
||||
将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。
|
||||
问题: ${{包含数学问题的问题。}}
|
||||
```text
|
||||
${{解决问题的单行数学表达式}}
|
||||
```
|
||||
...numexpr.evaluate(query)...
|
||||
```output
|
||||
${{运行代码的输出}}
|
||||
```
|
||||
答案: ${{答案}}
|
||||
|
||||
这是两个例子:
|
||||
|
||||
问题: 37593 * 67是多少?
|
||||
```text
|
||||
37593 * 67
|
||||
```
|
||||
...numexpr.evaluate("37593 * 67")...
|
||||
```output
|
||||
2518731
|
||||
|
||||
答案: 2518731
|
||||
|
||||
问题: 37593的五次方根是多少?
|
||||
```text
|
||||
37593**(1/5)
|
||||
```
|
||||
...numexpr.evaluate("37593**(1/5)")...
|
||||
```output
|
||||
8.222831614237718
|
||||
|
||||
答案: 8.222831614237718
|
||||
|
||||
|
||||
问题: 2的平方是多少?
|
||||
```text
|
||||
2 ** 2
|
||||
```
|
||||
...numexpr.evaluate("2 ** 2")...
|
||||
```output
|
||||
4
|
||||
|
||||
答案: 4
|
||||
|
||||
|
||||
现在,这是我的问题:
|
||||
问题: {question}
|
||||
"""
|
||||
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["question"],
|
||||
template=_PROMPT_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
class CalculatorInput(BaseModel):
|
||||
query: str = Field()
|
||||
|
||||
def calculate(query: str):
|
||||
model = model_container.MODEL
|
||||
llm_math = LLMMathChain.from_llm(model, verbose=True, prompt=PROMPT)
|
||||
ans = llm_math.run(query)
|
||||
return ans
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = calculate("2的三次方")
|
||||
print("答案:",result)
|
||||
|
||||
|
||||
|
||||
@ -1,37 +0,0 @@
|
||||
import json
|
||||
from server.chat.search_engine_chat import search_engine_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS
|
||||
import asyncio
|
||||
from server.agent import model_container
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
async def search_engine_iter(query: str):
|
||||
response = await search_engine_chat(query=query,
|
||||
search_engine_name="bing", # 这里切换搜索引擎
|
||||
model_name=model_container.MODEL.model_name,
|
||||
temperature=0.01, # Agent 搜索互联网的时候,温度设置为0.01
|
||||
history=[],
|
||||
top_k = VECTOR_SEARCH_TOP_K,
|
||||
max_tokens= MAX_TOKENS,
|
||||
prompt_name = "default",
|
||||
stream=False)
|
||||
|
||||
contents = ""
|
||||
|
||||
async for data in response.body_iterator: # 这里的data是一个json字符串
|
||||
data = json.loads(data)
|
||||
contents = data["answer"]
|
||||
docs = data["docs"]
|
||||
|
||||
return contents
|
||||
|
||||
def search_internet(query: str):
|
||||
return asyncio.run(search_engine_iter(query))
|
||||
|
||||
class SearchInternetInput(BaseModel):
|
||||
location: str = Field(description="Query for Internet search")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = search_internet("今天星期几")
|
||||
print("答案:",result)
|
||||
@ -1,287 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import re
|
||||
import warnings
|
||||
from typing import Dict
|
||||
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.pydantic_v1 import Extra, root_validator
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from typing import List, Any, Optional
|
||||
from langchain.prompts import PromptTemplate
|
||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
|
||||
import asyncio
|
||||
from server.agent import model_container
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
async def search_knowledge_base_iter(database: str, query: str) -> str:
|
||||
response = await knowledge_base_chat(query=query,
|
||||
knowledge_base_name=database,
|
||||
model_name=model_container.MODEL.model_name,
|
||||
temperature=0.01,
|
||||
history=[],
|
||||
top_k=VECTOR_SEARCH_TOP_K,
|
||||
max_tokens=MAX_TOKENS,
|
||||
prompt_name="default",
|
||||
score_threshold=SCORE_THRESHOLD,
|
||||
stream=False)
|
||||
|
||||
contents = ""
|
||||
async for data in response.body_iterator: # 这里的data是一个json字符串
|
||||
data = json.loads(data)
|
||||
contents += data["answer"]
|
||||
docs = data["docs"]
|
||||
return contents
|
||||
|
||||
|
||||
async def search_knowledge_multiple(queries) -> List[str]:
|
||||
# queries 应该是一个包含多个 (database, query) 元组的列表
|
||||
tasks = [search_knowledge_base_iter(database, query) for database, query in queries]
|
||||
results = await asyncio.gather(*tasks)
|
||||
# 结合每个查询结果,并在每个查询结果前添加一个自定义的消息
|
||||
combined_results = []
|
||||
for (database, _), result in zip(queries, results):
|
||||
message = f"\n查询到 {database} 知识库的相关信息:\n{result}"
|
||||
combined_results.append(message)
|
||||
|
||||
return combined_results
|
||||
|
||||
|
||||
def search_knowledge(queries) -> str:
|
||||
responses = asyncio.run(search_knowledge_multiple(queries))
|
||||
# 输出每个整合的查询结果
|
||||
contents = ""
|
||||
for response in responses:
|
||||
contents += response + "\n\n"
|
||||
return contents
|
||||
|
||||
|
||||
_PROMPT_TEMPLATE = """
|
||||
用户会提出一个需要你查询知识库的问题,你应该对问题进行理解和拆解,并在知识库中查询相关的内容。
|
||||
|
||||
对于每个知识库,你输出的内容应该是一个一行的字符串,这行字符串包含知识库名称和查询内容,中间用逗号隔开,不要有多余的文字和符号。你可以同时查询多个知识库,下面这个例子就是同时查询两个知识库的内容。
|
||||
|
||||
例子:
|
||||
|
||||
robotic,机器人男女比例是多少
|
||||
bigdata,大数据的就业情况如何
|
||||
|
||||
|
||||
这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能,你应该参考他们的功能来帮助你思考
|
||||
|
||||
|
||||
{database_names}
|
||||
|
||||
你的回答格式应该按照下面的内容,请注意```text 等标记都必须输出,这是我用来提取答案的标记。
|
||||
不要输出中文的逗号,不要输出引号。
|
||||
|
||||
Question: ${{用户的问题}}
|
||||
|
||||
```text
|
||||
${{知识库名称,查询问题,不要带有任何除了,之外的符号,比如不要输出中文的逗号,不要输出引号}}
|
||||
|
||||
```output
|
||||
数据库查询的结果
|
||||
|
||||
现在,我们开始作答
|
||||
问题: {question}
|
||||
"""
|
||||
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["question", "database_names"],
|
||||
template=_PROMPT_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
class LLMKnowledgeChain(LLMChain):
|
||||
llm_chain: LLMChain
|
||||
llm: Optional[BaseLanguageModel] = None
|
||||
"""[Deprecated] LLM wrapper to use."""
|
||||
prompt: BasePromptTemplate = PROMPT
|
||||
"""[Deprecated] Prompt to use to translate to python if necessary."""
|
||||
database_names: Dict[str, str] = None
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
if "llm" in values:
|
||||
warnings.warn(
|
||||
"Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
|
||||
"Please instantiate with llm_chain argument or using the from_llm "
|
||||
"class method."
|
||||
)
|
||||
if "llm_chain" not in values and values["llm"] is not None:
|
||||
prompt = values.get("prompt", PROMPT)
|
||||
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
|
||||
return values
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Expect output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _evaluate_expression(self, queries) -> str:
|
||||
try:
|
||||
output = search_knowledge(queries)
|
||||
except Exception as e:
|
||||
output = "输入的信息有误或不存在知识库,错误信息如下:\n"
|
||||
return output + str(e)
|
||||
return output
|
||||
|
||||
def _process_llm_result(
|
||||
self,
|
||||
llm_output: str,
|
||||
run_manager: CallbackManagerForChainRun
|
||||
) -> Dict[str, str]:
|
||||
|
||||
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
||||
|
||||
llm_output = llm_output.strip()
|
||||
# text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
|
||||
text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
|
||||
if text_match:
|
||||
expression = text_match.group(1).strip()
|
||||
cleaned_input_str = (expression.replace("\"", "").replace("“", "").
|
||||
replace("”", "").replace("```", "").strip())
|
||||
lines = cleaned_input_str.split("\n")
|
||||
# 使用逗号分割每一行,然后形成一个(数据库,查询)元组的列表
|
||||
|
||||
try:
|
||||
queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
|
||||
except:
|
||||
queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
|
||||
run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", verbose=self.verbose)
|
||||
output = self._evaluate_expression(queries)
|
||||
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
answer = "Answer: " + output
|
||||
elif llm_output.startswith("Answer:"):
|
||||
answer = llm_output
|
||||
elif "Answer:" in llm_output:
|
||||
answer = llm_output.split("Answer:")[-1]
|
||||
else:
|
||||
return {self.output_key: f"输入的格式不对:\n {llm_output}"}
|
||||
return {self.output_key: answer}
|
||||
|
||||
async def _aprocess_llm_result(
|
||||
self,
|
||||
llm_output: str,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> Dict[str, str]:
|
||||
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
||||
llm_output = llm_output.strip()
|
||||
text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
|
||||
if text_match:
|
||||
|
||||
expression = text_match.group(1).strip()
|
||||
cleaned_input_str = (
|
||||
expression.replace("\"", "").replace("“", "").replace("”", "").replace("```", "").strip())
|
||||
lines = cleaned_input_str.split("\n")
|
||||
try:
|
||||
queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
|
||||
except:
|
||||
queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
|
||||
await run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue",
|
||||
verbose=self.verbose)
|
||||
|
||||
output = self._evaluate_expression(queries)
|
||||
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
answer = "Answer: " + output
|
||||
elif llm_output.startswith("Answer:"):
|
||||
answer = llm_output
|
||||
elif "Answer:" in llm_output:
|
||||
answer = "Answer: " + llm_output.split("Answer:")[-1]
|
||||
else:
|
||||
raise ValueError(f"unknown format from LLM: {llm_output}")
|
||||
return {self.output_key: answer}
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
_run_manager.on_text(inputs[self.input_key])
|
||||
self.database_names = model_container.DATABASE
|
||||
data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
|
||||
llm_output = self.llm_chain.predict(
|
||||
database_names=data_formatted_str,
|
||||
question=inputs[self.input_key],
|
||||
stop=["```output"],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
return self._process_llm_result(llm_output, _run_manager)
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
await _run_manager.on_text(inputs[self.input_key])
|
||||
self.database_names = model_container.DATABASE
|
||||
data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
|
||||
llm_output = await self.llm_chain.apredict(
|
||||
database_names=data_formatted_str,
|
||||
question=inputs[self.input_key],
|
||||
stop=["```output"],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "llm_knowledge_chain"
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate = PROMPT,
|
||||
**kwargs: Any,
|
||||
) -> LLMKnowledgeChain:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(llm_chain=llm_chain, **kwargs)
|
||||
|
||||
|
||||
def search_knowledgebase_complex(query: str):
|
||||
model = model_container.MODEL
|
||||
llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT)
|
||||
ans = llm_knowledge.run(query)
|
||||
return ans
|
||||
|
||||
class KnowledgeSearchInput(BaseModel):
|
||||
location: str = Field(description="The query to be searched")
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = search_knowledgebase_complex("机器人和大数据在代码教学上有什么区别")
|
||||
print(result)
|
||||
|
||||
# 这是一个正常的切割
|
||||
# queries = [
|
||||
# ("bigdata", "大数据专业的男女比例"),
|
||||
# ("robotic", "机器人专业的优势")
|
||||
# ]
|
||||
# result = search_knowledge(queries)
|
||||
# print(result)
|
||||
@ -1,234 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import re
|
||||
import warnings
|
||||
from typing import Dict
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.pydantic_v1 import Extra, root_validator
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from typing import List, Any, Optional
|
||||
from langchain.prompts import PromptTemplate
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
|
||||
|
||||
import asyncio
|
||||
from server.agent import model_container
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
async def search_knowledge_base_iter(database: str, query: str):
|
||||
response = await knowledge_base_chat(query=query,
|
||||
knowledge_base_name=database,
|
||||
model_name=model_container.MODEL.model_name,
|
||||
temperature=0.01,
|
||||
history=[],
|
||||
top_k=VECTOR_SEARCH_TOP_K,
|
||||
max_tokens=MAX_TOKENS,
|
||||
prompt_name="knowledge_base_chat",
|
||||
score_threshold=SCORE_THRESHOLD,
|
||||
stream=False)
|
||||
|
||||
contents = ""
|
||||
async for data in response.body_iterator: # 这里的data是一个json字符串
|
||||
data = json.loads(data)
|
||||
contents += data["answer"]
|
||||
docs = data["docs"]
|
||||
return contents
|
||||
|
||||
|
||||
_PROMPT_TEMPLATE = """
|
||||
用户会提出一个需要你查询知识库的问题,你应该按照我提供的思想进行思考
|
||||
Question: ${{用户的问题}}
|
||||
这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能:
|
||||
|
||||
{database_names}
|
||||
|
||||
你的回答格式应该按照下面的内容,请注意,格式内的```text 等标记都必须输出,这是我用来提取答案的标记。
|
||||
```text
|
||||
${{知识库的名称}}
|
||||
```
|
||||
```output
|
||||
数据库查询的结果
|
||||
```
|
||||
答案: ${{答案}}
|
||||
|
||||
现在,这是我的问题:
|
||||
问题: {question}
|
||||
|
||||
"""
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["question", "database_names"],
|
||||
template=_PROMPT_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
class LLMKnowledgeChain(LLMChain):
|
||||
llm_chain: LLMChain
|
||||
llm: Optional[BaseLanguageModel] = None
|
||||
"""[Deprecated] LLM wrapper to use."""
|
||||
prompt: BasePromptTemplate = PROMPT
|
||||
"""[Deprecated] Prompt to use to translate to python if necessary."""
|
||||
database_names: Dict[str, str] = model_container.DATABASE
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
if "llm" in values:
|
||||
warnings.warn(
|
||||
"Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
|
||||
"Please instantiate with llm_chain argument or using the from_llm "
|
||||
"class method."
|
||||
)
|
||||
if "llm_chain" not in values and values["llm"] is not None:
|
||||
prompt = values.get("prompt", PROMPT)
|
||||
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
|
||||
return values
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Expect output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _evaluate_expression(self, dataset, query) -> str:
|
||||
try:
|
||||
output = asyncio.run(search_knowledge_base_iter(dataset, query))
|
||||
except Exception as e:
|
||||
output = "输入的信息有误或不存在知识库"
|
||||
return output
|
||||
return output
|
||||
|
||||
def _process_llm_result(
|
||||
self,
|
||||
llm_output: str,
|
||||
llm_input: str,
|
||||
run_manager: CallbackManagerForChainRun
|
||||
) -> Dict[str, str]:
|
||||
|
||||
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
||||
|
||||
llm_output = llm_output.strip()
|
||||
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
|
||||
if text_match:
|
||||
database = text_match.group(1).strip()
|
||||
output = self._evaluate_expression(database, llm_input)
|
||||
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
answer = "Answer: " + output
|
||||
elif llm_output.startswith("Answer:"):
|
||||
answer = llm_output
|
||||
elif "Answer:" in llm_output:
|
||||
answer = "Answer: " + llm_output.split("Answer:")[-1]
|
||||
else:
|
||||
return {self.output_key: f"输入的格式不对: {llm_output}"}
|
||||
return {self.output_key: answer}
|
||||
|
||||
async def _aprocess_llm_result(
|
||||
self,
|
||||
llm_output: str,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
) -> Dict[str, str]:
|
||||
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
|
||||
llm_output = llm_output.strip()
|
||||
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
|
||||
if text_match:
|
||||
expression = text_match.group(1)
|
||||
output = self._evaluate_expression(expression)
|
||||
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||
answer = "Answer: " + output
|
||||
elif llm_output.startswith("Answer:"):
|
||||
answer = llm_output
|
||||
elif "Answer:" in llm_output:
|
||||
answer = "Answer: " + llm_output.split("Answer:")[-1]
|
||||
else:
|
||||
raise ValueError(f"unknown format from LLM: {llm_output}")
|
||||
return {self.output_key: answer}
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
_run_manager.on_text(inputs[self.input_key])
|
||||
data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
|
||||
llm_output = self.llm_chain.predict(
|
||||
database_names=data_formatted_str,
|
||||
question=inputs[self.input_key],
|
||||
stop=["```output"],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
return self._process_llm_result(llm_output, inputs[self.input_key], _run_manager)
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
await _run_manager.on_text(inputs[self.input_key])
|
||||
data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
|
||||
llm_output = await self.llm_chain.apredict(
|
||||
database_names=data_formatted_str,
|
||||
question=inputs[self.input_key],
|
||||
stop=["```output"],
|
||||
callbacks=_run_manager.get_child(),
|
||||
)
|
||||
return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "llm_knowledge_chain"
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate = PROMPT,
|
||||
**kwargs: Any,
|
||||
) -> LLMKnowledgeChain:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(llm_chain=llm_chain, **kwargs)
|
||||
|
||||
|
||||
def search_knowledgebase_once(query: str):
|
||||
model = model_container.MODEL
|
||||
llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT)
|
||||
ans = llm_knowledge.run(query)
|
||||
return ans
|
||||
|
||||
|
||||
class KnowledgeSearchInput(BaseModel):
|
||||
location: str = Field(description="The query to be searched")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = search_knowledgebase_once("大数据的男女比例")
|
||||
print(result)
|
||||
@ -1,32 +0,0 @@
|
||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
|
||||
import json
|
||||
import asyncio
|
||||
from server.agent import model_container
|
||||
|
||||
async def search_knowledge_base_iter(database: str, query: str) -> str:
|
||||
response = await knowledge_base_chat(query=query,
|
||||
knowledge_base_name=database,
|
||||
model_name=model_container.MODEL.model_name,
|
||||
temperature=0.01,
|
||||
history=[],
|
||||
top_k=VECTOR_SEARCH_TOP_K,
|
||||
max_tokens=MAX_TOKENS,
|
||||
prompt_name="knowledge_base_chat",
|
||||
score_threshold=SCORE_THRESHOLD,
|
||||
stream=False)
|
||||
|
||||
contents = ""
|
||||
async for data in response.body_iterator: # 这里的data是一个json字符串
|
||||
data = json.loads(data)
|
||||
contents = data["answer"]
|
||||
docs = data["docs"]
|
||||
return contents
|
||||
|
||||
def search_knowledgebase_simple(query: str):
|
||||
return asyncio.run(search_knowledge_base_iter(query))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = search_knowledgebase_simple("大数据男女比例")
|
||||
print("答案:",result)
|
||||
8
server/agent/tools_factory/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
from .search_local_knowledgebase import search_local_knowledgebase, SearchKnowledgeInput
|
||||
from .calculate import calculate, CalculatorInput
|
||||
from .weather_check import weather_check, WeatherInput
|
||||
from .shell import shell, ShellInput
|
||||
from .search_internet import search_internet, SearchInternetInput
|
||||
from .wolfram import wolfram, WolframInput
|
||||
from .search_youtube import search_youtube, YoutubeInput
|
||||
from .arxiv import arxiv, ArxivInput
|
||||
23
server/agent/tools_factory/calculate.py
Normal file
@ -0,0 +1,23 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
def calculate(a: float, b: float, operator: str) -> float:
|
||||
if operator == "+":
|
||||
return a + b
|
||||
elif operator == "-":
|
||||
return a - b
|
||||
elif operator == "*":
|
||||
return a * b
|
||||
elif operator == "/":
|
||||
if b != 0:
|
||||
return a / b
|
||||
else:
|
||||
return float('inf') # 防止除以零
|
||||
elif operator == "^":
|
||||
return a ** b
|
||||
else:
|
||||
raise ValueError("Unsupported operator")
|
||||
|
||||
class CalculatorInput(BaseModel):
|
||||
a: float = Field(description="first number")
|
||||
b: float = Field(description="second number")
|
||||
operator: str = Field(description="operator to use (e.g., +, -, *, /, ^)")
|
||||
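A usage sketch showing how this multi-argument tool gets exposed to the agent through StructuredTool, mirroring tools_registry.py below (the direct .run call is just for illustration):

```python
# Sketch: wiring the new multi-argument calculate tool up for the agent.
# Mirrors server/agent/tools_factory/tools_registry.py in this commit.
from langchain_core.tools import StructuredTool

calc_tool = StructuredTool.from_function(
    func=calculate,
    name="calculate",
    description="Useful for when you need to answer questions about simple calculations",
    args_schema=CalculatorInput,
)
print(calc_tool.run({"a": 2.0, "b": 3.0, "operator": "^"}))  # 8.0
```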
99
server/agent/tools_factory/search_internet.py
Normal file
@ -0,0 +1,99 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.utilities.bing_search import BingSearchAPIWrapper
|
||||
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
|
||||
from configs import TOOL_CONFIG
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from typing import List, Dict
|
||||
from langchain.docstore.document import Document
|
||||
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
|
||||
from markdownify import markdownify
|
||||
|
||||
|
||||
def bing_search(text, config):
|
||||
search = BingSearchAPIWrapper(bing_subscription_key=config["bing_key"],
|
||||
bing_search_url=config["bing_search_url"])
|
||||
return search.results(text, config["result_len"])
|
||||
|
||||
|
||||
def duckduckgo_search(text, config):
|
||||
search = DuckDuckGoSearchAPIWrapper()
|
||||
return search.results(text, config["result_len"])
|
||||
|
||||
|
||||
def metaphor_search(
|
||||
text: str,
|
||||
config: dict,
|
||||
) -> List[Dict]:
|
||||
from metaphor_python import Metaphor
|
||||
client = Metaphor(config["metaphor_api_key"])
|
||||
search = client.search(text, num_results=config["result_len"], use_autoprompt=True)
|
||||
contents = search.get_contents().contents
|
||||
for x in contents:
|
||||
x.extract = markdownify(x.extract)
|
||||
if config["split_result"]:
|
||||
docs = [Document(page_content=x.extract,
|
||||
metadata={"link": x.url, "title": x.title})
|
||||
for x in contents]
|
||||
text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
|
||||
chunk_size=config["chunk_size"],
|
||||
chunk_overlap=config["chunk_overlap"])
|
||||
splitted_docs = text_splitter.split_documents(docs)
|
||||
if len(splitted_docs) > config["result_len"]:
|
||||
normal = NormalizedLevenshtein()
|
||||
for x in splitted_docs:
|
||||
x.metadata["score"] = normal.similarity(text, x.page_content)
|
||||
splitted_docs.sort(key=lambda x: x.metadata["score"], reverse=True)
|
||||
splitted_docs = splitted_docs[:config["result_len"]]
|
||||
|
||||
docs = [{"snippet": x.page_content,
|
||||
"link": x.metadata["link"],
|
||||
"title": x.metadata["title"]}
|
||||
for x in splitted_docs]
|
||||
else:
|
||||
docs = [{"snippet": x.extract,
|
||||
"link": x.url,
|
||||
"title": x.title}
|
||||
for x in contents]
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
SEARCH_ENGINES = {"bing": bing_search,
|
||||
"duckduckgo": duckduckgo_search,
|
||||
"metaphor": metaphor_search,
|
||||
}
|
||||
|
||||
|
||||
def search_result2docs(search_results):
|
||||
docs = []
|
||||
for result in search_results:
|
||||
doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "",
|
||||
metadata={"source": result["link"] if "link" in result.keys() else "",
|
||||
"filename": result["title"] if "title" in result.keys() else ""})
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
|
||||
def search_engine(query: str,
|
||||
config: dict):
|
||||
search_engine_use = SEARCH_ENGINES[config["search_engine_name"]]
|
||||
results = search_engine_use(text=query,
|
||||
config=config["search_engine_config"][
|
||||
config["search_engine_name"]])
|
||||
docs = search_result2docs(results)
|
||||
context = ""
|
||||
docs = [
|
||||
f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
|
||||
for inum, doc in enumerate(docs)
|
||||
]
|
||||
|
||||
for doc in docs:
|
||||
context += doc + "\n"
|
||||
return context
|
||||
def search_internet(query: str):
|
||||
tool_config = TOOL_CONFIG["search_internet"]
|
||||
return search_engine(query=query, config=tool_config)
|
||||
|
||||
class SearchInternetInput(BaseModel):
|
||||
query: str = Field(description="Query for Internet search")
|
||||
|
||||
40
server/agent/tools_factory/search_local_knowledgebase.py
Normal file
@ -0,0 +1,40 @@
from urllib.parse import urlencode

from pydantic import BaseModel, Field

from server.knowledge_base.kb_doc_api import search_docs
from configs import TOOL_CONFIG


def search_knowledgebase(query: str, database: str, config: dict):
    docs = search_docs(
        query=query,
        knowledge_base_name=database,
        top_k=config["top_k"],
        score_threshold=config["score_threshold"])
    context = ""
    source_documents = []
    for inum, doc in enumerate(docs):
        filename = doc.metadata.get("source")
        parameters = urlencode({"knowledge_base_name": database, "file_name": filename})
        url = f"download_doc?" + parameters
        text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
        source_documents.append(text)

    if len(source_documents) == 0:
        context = "没有找到相关文档,请更换关键词重试"
    else:
        for doc in source_documents:
            context += doc + "\n"

    return context


class SearchKnowledgeInput(BaseModel):
    database: str = Field(description="Database for Knowledge Search")
    query: str = Field(description="Query for Knowledge Search")


def search_local_knowledgebase(database: str, query: str):
    tool_config = TOOL_CONFIG["search_local_knowledgebase"]
    return search_knowledgebase(query=query, database=database, config=tool_config)
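Because the target knowledge base is now a separate `database` argument instead of being baked into the tool, the agent has to fill in both fields of SearchKnowledgeInput. A minimal sketch of a matching TOOL_CONFIG entry and a direct call follows; the values and the "samples" knowledge base name are illustrative only.

# Illustrative sketch, not part of this commit. "top_k" and "score_threshold"
# mirror the config[...] lookups in search_knowledgebase() above.
expected_tool_config = {
    "search_local_knowledgebase": {
        "top_k": 3,
        "score_threshold": 1.0,
    },
}

# With such an entry in configs.TOOL_CONFIG, a direct call looks like:
# print(search_local_knowledgebase(database="samples", query="如何新建知识库"))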
@ -6,4 +6,4 @@ def search_youtube(query: str):
    return tool.run(tool_input=query)

class YoutubeInput(BaseModel):
    location: str = Field(description="Query for Videos search")
    query: str = Field(description="Query for Videos search")
@ -6,4 +6,4 @@ def shell(query: str):
    return tool.run(tool_input=query)

class ShellInput(BaseModel):
    query: str = Field(description="一个能在Linux命令行运行的Shell命令")
    query: str = Field(description="The command to execute")
server/agent/tools_factory/tools_registry.py (new file, 59 lines)
@ -0,0 +1,59 @@
from langchain_core.tools import StructuredTool
from server.agent.tools_factory import *
from configs import KB_INFO

template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool."
KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
template_knowledge = template.format(KB_info=KB_info_str)

all_tools = [
    StructuredTool.from_function(
        func=calculate,
        name="calculate",
        description="Useful for when you need to answer questions about simple calculations",
        args_schema=CalculatorInput,
    ),
    StructuredTool.from_function(
        func=arxiv,
        name="arxiv",
        description="A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.",
        args_schema=ArxivInput,
    ),
    StructuredTool.from_function(
        func=shell,
        name="shell",
        description="Use Shell to execute Linux commands",
        args_schema=ShellInput,
    ),
    StructuredTool.from_function(
        func=wolfram,
        name="wolfram",
        description="Useful for when you need to calculate difficult formulas",
        args_schema=WolframInput,
    ),
    StructuredTool.from_function(
        func=search_youtube,
        name="search_youtube",
        description="use this tools_factory to search youtube videos",
        args_schema=YoutubeInput,
    ),
    StructuredTool.from_function(
        func=weather_check,
        name="weather_check",
        description="Use this tool to check the weather at a specific location",
        args_schema=WeatherInput,
    ),
    StructuredTool.from_function(
        func=search_internet,
        name="search_internet",
        description="Use this tool to use bing search engine to search the internet and get information",
        args_schema=SearchInternetInput,
    ),
    StructuredTool.from_function(
        func=search_local_knowledgebase,
        name="search_local_knowledgebase",
        description=template_knowledge,
        args_schema=SearchKnowledgeInput,
    ),
]
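all_tools is consumed later in this diff by server/chat/chat.py, which filters it by the tool names enabled in the request's tool_config and binds the survivors to a structured-chat agent. A minimal sketch of that wiring follows; the enabled-tool set is hypothetical and `llm` stands in for any chat model obtained from get_ChatOpenAI.

# Illustrative sketch, not part of this commit.
from langchain.agents import initialize_agent, AgentType

enabled = {"search_internet", "weather_check"}          # hypothetical tool_config keys
tools = [t for t in all_tools if t.name in enabled]     # same filter chat.py applies

agent_executor = initialize_agent(
    llm=llm,                                            # assumed: a chat model from get_ChatOpenAI
    tools=tools,
    agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)
# agent_executor.run("北京今天适合户外跑步吗?")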
@ -1,10 +1,8 @@
"""
更简单的单参数输入工具实现,用于查询现在天气的情况
简单的单参数输入工具实现,用于查询现在天气的情况
"""
from pydantic import BaseModel, Field
import requests
from configs.kb_config import SENIVERSE_API_KEY


def weather(location: str, api_key: str):
    url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
@ -21,9 +19,7 @@ def weather(location: str, api_key: str):
            f"Failed to retrieve weather: {response.status_code}")


def weathercheck(location: str):
    return weather(location, SENIVERSE_API_KEY)


def weather_check(location: str):
    return weather(location, "S8vrB4U_-c5mvAMiK")
class WeatherInput(BaseModel):
    location: str = Field(description="City name,include city and county")
    location: str = Field(description="City name,include city and county,like '厦门'")
@ -8,4 +8,4 @@ def wolfram(query: str):
    return ans

class WolframInput(BaseModel):
    location: str = Field(description="需要运算的具体问题")
    formula: str = Field(description="The formula to be calculated")
@ -1,55 +0,0 @@
from langchain.tools import Tool
from server.agent.tools import *

tools = [
    Tool.from_function(
        func=calculate,
        name="calculate",
        description="Useful for when you need to answer questions about simple calculations",
        args_schema=CalculatorInput,
    ),
    Tool.from_function(
        func=arxiv,
        name="arxiv",
        description="A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.",
        args_schema=ArxivInput,
    ),
    Tool.from_function(
        func=weathercheck,
        name="weather_check",
        description="",
        args_schema=WeatherInput,
    ),
    Tool.from_function(
        func=shell,
        name="shell",
        description="Use Shell to execute Linux commands",
        args_schema=ShellInput,
    ),
    Tool.from_function(
        func=search_knowledgebase_complex,
        name="search_knowledgebase_complex",
        description="Use Use this tool to search local knowledgebase and get information",
        args_schema=KnowledgeSearchInput,
    ),
    Tool.from_function(
        func=search_internet,
        name="search_internet",
        description="Use this tool to use bing search engine to search the internet",
        args_schema=SearchInternetInput,
    ),
    Tool.from_function(
        func=wolfram,
        name="Wolfram",
        description="Useful for when you need to calculate difficult formulas",
        args_schema=WolframInput,
    ),
    Tool.from_function(
        func=search_youtube,
        name="search_youtube",
        description="use this tools to search youtube videos",
        args_schema=YoutubeInput,
    ),
]

tool_names = [tool.name for tool in tools]
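The deleted registry above (the old tools_select module that agent_chat imported) was built on Tool.from_function, which only forwards a single string input; that is why the old knowledge-base tool had to be a one-argument search_knowledgebase_complex. StructuredTool.from_function in the new tools_registry validates and passes every field of the args_schema separately. A minimal sketch of the difference, with hypothetical names:

# Illustrative sketch, not part of this commit.
from pydantic import BaseModel, Field
from langchain.tools import Tool, StructuredTool

class LookupInput(BaseModel):
    database: str = Field(description="which knowledge base to query")
    query: str = Field(description="what to look for")

def lookup(database: str, query: str) -> str:
    return f"searched {database} for: {query}"

# Old style: one string in, so extra arguments must be hard-coded or packed into the string.
single_arg_tool = Tool.from_function(func=lambda q: lookup("samples", q),
                                     name="lookup",
                                     description="single-input lookup")

# New style: every schema field becomes a real argument the agent can fill in.
multi_arg_tool = StructuredTool.from_function(func=lookup,
                                              name="lookup",
                                              description="multi-input lookup",
                                              args_schema=LookupInput)

print(single_arg_tool.run("向量库"))
print(multi_arg_tool.run({"database": "samples", "query": "向量库"}))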
@ -13,13 +13,12 @@ from fastapi import Body
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.responses import RedirectResponse
|
||||
from server.chat.chat import chat
|
||||
from server.chat.search_engine_chat import search_engine_chat
|
||||
from server.chat.completion import completion
|
||||
from server.chat.feedback import chat_feedback
|
||||
from server.embeddings_api import embed_texts_endpoint
|
||||
from server.llm_api import (list_running_models, list_config_models,
|
||||
change_llm_model, stop_llm_model,
|
||||
get_model_config, list_search_engines)
|
||||
get_model_config)
|
||||
from server.utils import (BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline,
|
||||
get_server_configs, get_prompt_template)
|
||||
from typing import List, Literal
|
||||
@ -63,11 +62,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
|
||||
summary="与llm模型对话(通过LLMChain)",
|
||||
)(chat)
|
||||
|
||||
app.post("/chat/search_engine_chat",
|
||||
tags=["Chat"],
|
||||
summary="与搜索引擎对话",
|
||||
)(search_engine_chat)
|
||||
|
||||
app.post("/chat/feedback",
|
||||
tags=["Chat"],
|
||||
summary="返回llm模型对话评分",
|
||||
@ -110,16 +104,12 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
|
||||
summary="获取服务器原始配置信息",
|
||||
)(get_server_configs)
|
||||
|
||||
app.post("/server/list_search_engines",
|
||||
tags=["Server State"],
|
||||
summary="获取服务器支持的搜索引擎",
|
||||
)(list_search_engines)
|
||||
|
||||
@app.post("/server/get_prompt_template",
|
||||
tags=["Server State"],
|
||||
summary="获取服务区配置的 prompt 模板")
|
||||
def get_server_prompt_template(
|
||||
type: Literal["llm_chat", "knowledge_base_chat", "search_engine_chat", "agent_chat"]=Body("llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat,search_engine_chat,agent_chat"),
|
||||
type: Literal["llm_chat", "knowledge_base_chat"]=Body("llm_chat", description="模板类型,可选值:llm_chat,knowledge_base_chat"),
|
||||
name: str = Body("default", description="模板名称"),
|
||||
) -> str:
|
||||
return get_prompt_template(type=type, name=name)
|
||||
@ -139,7 +129,6 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
|
||||
def mount_knowledge_routes(app: FastAPI):
|
||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||
from server.chat.file_chat import upload_temp_docs, file_chat
|
||||
from server.chat.agent_chat import agent_chat
|
||||
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
||||
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
|
||||
update_docs, download_doc, recreate_vector_store,
|
||||
@ -155,11 +144,6 @@ def mount_knowledge_routes(app: FastAPI):
|
||||
summary="文件对话"
|
||||
)(file_chat)
|
||||
|
||||
app.post("/chat/agent_chat",
|
||||
tags=["Chat"],
|
||||
summary="与agent对话")(agent_chat)
|
||||
|
||||
# Tag: Knowledge Base Management
|
||||
app.get("/knowledge_base/list_knowledge_bases",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=ListResponse,
|
||||
|
||||
@ -1,13 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from uuid import UUID
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.schema import AgentFinish, AgentAction
|
||||
from langchain.schema.output import LLMResult
|
||||
|
||||
import asyncio
|
||||
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast, Optional
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain.callbacks.base import AsyncCallbackHandler
|
||||
|
||||
def dumps(obj: Dict) -> str:
|
||||
return json.dumps(obj, ensure_ascii=False)
|
||||
@ -23,7 +21,7 @@ class Status:
|
||||
tool_finish: int = 7
|
||||
|
||||
|
||||
class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.queue = asyncio.Queue()
|
||||
@ -31,40 +29,29 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
self.cur_tool = {}
|
||||
self.out = True
|
||||
|
||||
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
|
||||
parent_run_id: UUID | None = None, tags: List[str] | None = None,
|
||||
metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
|
||||
|
||||
# 对于截断不能自理的大模型,我来帮他截断
|
||||
stop_words = ["Observation:", "Thought","\"","(", "\n","\t"]
|
||||
for stop_word in stop_words:
|
||||
index = input_str.find(stop_word)
|
||||
if index != -1:
|
||||
input_str = input_str[:index]
|
||||
break
|
||||
|
||||
self.cur_tool = {
|
||||
"tool_name": serialized["name"],
|
||||
"input_str": input_str,
|
||||
"output_str": "",
|
||||
"status": Status.agent_action,
|
||||
"run_id": run_id.hex,
|
||||
"llm_token": "",
|
||||
"final_answer": "",
|
||||
"error": "",
|
||||
}
|
||||
# print("\nInput Str:",self.cur_tool["input_str"])
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
|
||||
tags: List[str] | None = None, **kwargs: Any) -> None:
|
||||
self.out = True ## 重置输出
|
||||
self.cur_tool.update(
|
||||
status=Status.tool_finish,
|
||||
output_str=output.replace("Answer:", ""),
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
async def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
print("on_tool_start")
|
||||
|
||||
async def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
print("on_tool_end")
|
||||
async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
|
||||
parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
|
||||
self.cur_tool.update(
|
||||
@ -73,23 +60,6 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
# async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
# if "Action" in token: ## 减少重复输出
|
||||
# before_action = token.split("Action")[0]
|
||||
# self.cur_tool.update(
|
||||
# status=Status.running,
|
||||
# llm_token=before_action + "\n",
|
||||
# )
|
||||
# self.queue.put_nowait(dumps(self.cur_tool))
|
||||
#
|
||||
# self.out = False
|
||||
#
|
||||
# if token and self.out:
|
||||
# self.cur_tool.update(
|
||||
# status=Status.running,
|
||||
# llm_token=token,
|
||||
# )
|
||||
# self.queue.put_nowait(dumps(self.cur_tool))
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
special_tokens = ["Action", "<|observation|>"]
|
||||
for stoken in special_tokens:
|
||||
@ -103,7 +73,7 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
self.out = False
|
||||
break
|
||||
|
||||
if token and self.out:
|
||||
if token is not None and token != "" and self.out:
|
||||
self.cur_tool.update(
|
||||
status=Status.running,
|
||||
llm_token=token,
|
||||
@ -116,16 +86,17 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
llm_token="",
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.cur_tool.update(
|
||||
status=Status.start,
|
||||
@ -136,8 +107,9 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
self.cur_tool.update(
|
||||
status=Status.complete,
|
||||
llm_token="\n",
|
||||
llm_token="",
|
||||
)
|
||||
self.out = True
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
|
||||
@ -147,15 +119,44 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_agent_action(
|
||||
self,
|
||||
action: AgentAction,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.cur_tool.update(
|
||||
status=Status.agent_action,
|
||||
tool_name=action.tool,
|
||||
tool_input=action.tool_input,
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
async def on_agent_finish(
|
||||
self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# 返回最终答案
|
||||
self.cur_tool.update(
|
||||
status=Status.agent_finish,
|
||||
final_answer=finish.return_values["output"],
|
||||
agent_finish=finish.return_values["output"],
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
self.cur_tool = {}
|
||||
|
||||
async def aiter(self) -> AsyncIterator[str]:
|
||||
while not self.queue.empty() or not self.done.is_set():
|
||||
done, other = await asyncio.wait(
|
||||
[
|
||||
asyncio.ensure_future(self.queue.get()),
|
||||
asyncio.ensure_future(self.done.wait()),
|
||||
],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
if other:
|
||||
other.pop().cancel()
|
||||
token_or_done = cast(Union[str, Literal[True]], done.pop().result())
|
||||
if token_or_done is True:
|
||||
break
|
||||
yield token_or_done
|
||||
@ -1,178 +0,0 @@
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
from fastapi import Body
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
|
||||
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.memory import ConversationBufferWindowMemory
|
||||
from langchain.agents import LLMSingleActionAgent, AgentExecutor
|
||||
from typing import AsyncIterable, Optional, List
|
||||
|
||||
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
|
||||
from server.knowledge_base.kb_service.base import get_kb_details
|
||||
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
|
||||
from server.agent.tools_select import tools, tool_names
|
||||
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
|
||||
from server.chat.utils import History
|
||||
from server.agent import model_container
|
||||
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
|
||||
|
||||
|
||||
async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
|
||||
history: List[History] = Body([],
|
||||
description="历史对话",
|
||||
examples=[[
|
||||
{"role": "user", "content": "请使用知识库工具查询今天北京天气"},
|
||||
{"role": "assistant",
|
||||
"content": "使用天气查询工具查询到今天北京多云,10-14摄氏度,东北风2级,易感冒"}]]
|
||||
),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
prompt_name: str = Body("default",
|
||||
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
):
|
||||
history = [History.from_data(h) for h in history]
|
||||
|
||||
async def agent_chat_iterator(
|
||||
query: str,
|
||||
history: Optional[List[History]],
|
||||
model_name: str = LLM_MODELS[0],
|
||||
prompt_name: str = prompt_name,
|
||||
) -> AsyncIterable[str]:
|
||||
nonlocal max_tokens
|
||||
callback = CustomAsyncIteratorCallbackHandler()
|
||||
if isinstance(max_tokens, int) and max_tokens <= 0:
|
||||
max_tokens = None
|
||||
|
||||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[callback],
|
||||
)
|
||||
|
||||
kb_list = {x["kb_name"]: x for x in get_kb_details()}
|
||||
model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}
|
||||
|
||||
if Agent_MODEL:
|
||||
model_agent = get_ChatOpenAI(
|
||||
model_name=Agent_MODEL,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[callback],
|
||||
)
|
||||
model_container.MODEL = model_agent
|
||||
else:
|
||||
model_container.MODEL = model
|
||||
|
||||
prompt_template = get_prompt_template("agent_chat", prompt_name)
|
||||
prompt_template_agent = CustomPromptTemplate(
|
||||
template=prompt_template,
|
||||
tools=tools,
|
||||
input_variables=["input", "intermediate_steps", "history"]
|
||||
)
|
||||
output_parser = CustomOutputParser()
|
||||
llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)
|
||||
memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
|
||||
for message in history:
|
||||
if message.role == 'user':
|
||||
memory.chat_memory.add_user_message(message.content)
|
||||
else:
|
||||
memory.chat_memory.add_ai_message(message.content)
|
||||
if "chatglm3" in model_container.MODEL.model_name or "zhipu-api" in model_container.MODEL.model_name:
|
||||
agent_executor = initialize_glm3_agent(
|
||||
llm=model,
|
||||
tools=tools,
|
||||
callback_manager=None,
|
||||
prompt=prompt_template,
|
||||
input_variables=["input", "intermediate_steps", "history"],
|
||||
memory=memory,
|
||||
verbose=True,
|
||||
)
|
||||
else:
|
||||
agent = LLMSingleActionAgent(
|
||||
llm_chain=llm_chain,
|
||||
output_parser=output_parser,
|
||||
stop=["\nObservation:", "Observation"],
|
||||
allowed_tools=tool_names,
|
||||
)
|
||||
agent_executor = AgentExecutor.from_agent_and_tools(agent=agent,
|
||||
tools=tools,
|
||||
verbose=True,
|
||||
memory=memory,
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
task = asyncio.create_task(wrap_done(
|
||||
agent_executor.acall(query, callbacks=[callback], include_run_info=True),
|
||||
callback.done))
|
||||
break
|
||||
except:
|
||||
pass
|
||||
|
||||
if stream:
|
||||
async for chunk in callback.aiter():
|
||||
tools_use = []
|
||||
# Use server-sent-events to stream the response
|
||||
data = json.loads(chunk)
|
||||
if data["status"] == Status.start or data["status"] == Status.complete:
|
||||
continue
|
||||
elif data["status"] == Status.error:
|
||||
tools_use.append("\n```\n")
|
||||
tools_use.append("工具名称: " + data["tool_name"])
|
||||
tools_use.append("工具状态: " + "调用失败")
|
||||
tools_use.append("错误信息: " + data["error"])
|
||||
tools_use.append("重新开始尝试")
|
||||
tools_use.append("\n```\n")
|
||||
yield json.dumps({"tools": tools_use}, ensure_ascii=False)
|
||||
elif data["status"] == Status.tool_finish:
|
||||
tools_use.append("\n```\n")
|
||||
tools_use.append("工具名称: " + data["tool_name"])
|
||||
tools_use.append("工具状态: " + "调用成功")
|
||||
tools_use.append("工具输入: " + data["input_str"])
|
||||
tools_use.append("工具输出: " + data["output_str"])
|
||||
tools_use.append("\n```\n")
|
||||
yield json.dumps({"tools": tools_use}, ensure_ascii=False)
|
||||
elif data["status"] == Status.agent_finish:
|
||||
yield json.dumps({"final_answer": data["final_answer"]}, ensure_ascii=False)
|
||||
else:
|
||||
yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)
|
||||
|
||||
|
||||
else:
|
||||
answer = ""
|
||||
final_answer = ""
|
||||
async for chunk in callback.aiter():
|
||||
data = json.loads(chunk)
|
||||
if data["status"] == Status.start or data["status"] == Status.complete:
|
||||
continue
|
||||
if data["status"] == Status.error:
|
||||
answer += "\n```\n"
|
||||
answer += "工具名称: " + data["tool_name"] + "\n"
|
||||
answer += "工具状态: " + "调用失败" + "\n"
|
||||
answer += "错误信息: " + data["error"] + "\n"
|
||||
answer += "\n```\n"
|
||||
if data["status"] == Status.tool_finish:
|
||||
answer += "\n```\n"
|
||||
answer += "工具名称: " + data["tool_name"] + "\n"
|
||||
answer += "工具状态: " + "调用成功" + "\n"
|
||||
answer += "工具输入: " + data["input_str"] + "\n"
|
||||
answer += "工具输出: " + data["output_str"] + "\n"
|
||||
answer += "\n```\n"
|
||||
if data["status"] == Status.agent_finish:
|
||||
final_answer = data["final_answer"]
|
||||
else:
|
||||
answer += data["llm_token"]
|
||||
|
||||
yield json.dumps({"answer": answer, "final_answer": final_answer}, ensure_ascii=False)
|
||||
await task
|
||||
|
||||
return EventSourceResponse(agent_chat_iterator(query=query,
|
||||
history=history,
|
||||
model_name=model_name,
|
||||
prompt_name=prompt_name),
|
||||
)
|
||||
@ -1,20 +1,43 @@
|
||||
from fastapi import Body
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from configs import LLM_MODELS, TEMPERATURE
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langchain.agents import initialize_agent, AgentType
|
||||
from langchain_core.callbacks import AsyncCallbackManager, BaseCallbackManager
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.runnables import RunnableBranch
|
||||
from server.agent.agent_factory import initialize_glm3_agent
|
||||
from server.agent.tools_factory.tools_registry import all_tools
|
||||
from server.utils import wrap_done, get_ChatOpenAI
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from typing import AsyncIterable
|
||||
from typing import AsyncIterable, Dict
|
||||
import asyncio
|
||||
import json
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Union
|
||||
from server.chat.utils import History
|
||||
from langchain.prompts import PromptTemplate
|
||||
from server.utils import get_prompt_template
|
||||
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
|
||||
from server.db.repository import add_message_to_db
|
||||
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
|
||||
from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
|
||||
|
||||
|
||||
def create_models_from_config(configs: dict = {}, callbacks: list = []):
|
||||
models = {}
|
||||
prompts = {}
|
||||
for model_type, model_configs in configs.items():
|
||||
for model_name, params in model_configs.items():
|
||||
callback = callbacks if params.get('callbacks', False) else None
|
||||
model_instance = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=params.get('temperature', 0.5),
|
||||
max_tokens=params.get('max_tokens', 1000),
|
||||
callbacks=callback
|
||||
)
|
||||
models[model_type] = model_instance
|
||||
prompt_name = params.get('prompt_name', 'default')
|
||||
prompt_template = get_prompt_template(type=model_type, name=prompt_name)
|
||||
prompts[model_type] = prompt_template
|
||||
return models, prompts
|
||||
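create_models_from_config turns a two-level mapping (model role, then model name, then parameters) into one ChatOpenAI instance and one prompt template per role; later in this file the result is indexed with the roles preprocess_model, llm_model and action_model. A minimal sketch of a call follows; the model names and parameter values are placeholders, not the project defaults.

# Illustrative sketch, not part of this commit.
demo_config = {
    "preprocess_model": {
        "my-small-model": {"temperature": 0.1, "max_tokens": 1024,
                           "prompt_name": "default", "callbacks": False},
    },
    "llm_model": {
        "my-chat-model": {"temperature": 0.7, "max_tokens": 2048,
                          "prompt_name": "default", "callbacks": True},
    },
    "action_model": {
        "my-chat-model": {"temperature": 0.05, "max_tokens": 2048,
                          "prompt_name": "default", "callbacks": True},
    },
}

callback = CustomAsyncIteratorCallbackHandler()
models, prompts = create_models_from_config(configs=demo_config, callbacks=[callback])
# models["llm_model"] is a ChatOpenAI instance for that role;
# prompts["action_model"] is the template resolved via get_prompt_template().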
|
||||
|
||||
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
|
||||
@ -28,76 +51,106 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
|
||||
{"role": "assistant", "content": "虎头虎脑"}]]
|
||||
),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
|
||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
|
||||
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
model_config: Dict = Body({}, description="LLM 模型配置。"),
|
||||
tool_config: Dict = Body({}, description="工具配置"),
|
||||
):
|
||||
async def chat_iterator() -> AsyncIterable[str]:
|
||||
nonlocal history, max_tokens
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
callbacks = [callback]
|
||||
nonlocal history
|
||||
memory = None
|
||||
message_id = None
|
||||
chat_prompt = None
|
||||
|
||||
# 负责保存llm response到message db
|
||||
message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
|
||||
conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
|
||||
chat_type="llm_chat",
|
||||
query=query)
|
||||
callbacks.append(conversation_callback)
|
||||
callback = CustomAsyncIteratorCallbackHandler()
|
||||
callbacks = [callback]
|
||||
models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
|
||||
|
||||
if isinstance(max_tokens, int) and max_tokens <= 0:
|
||||
max_tokens = None
|
||||
if conversation_id:
|
||||
message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
|
||||
tools = [tool for tool in all_tools if tool.name in tool_config]
|
||||
|
||||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
if history: # 优先使用前端传入的历史消息
|
||||
if history:
|
||||
history = [History.from_data(h) for h in history]
|
||||
prompt_template = get_prompt_template("llm_chat", prompt_name)
|
||||
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
|
||||
input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_template() for i in history] + [input_msg])
|
||||
elif conversation_id and history_len > 0: # 前端要求从数据库取历史消息
|
||||
# 使用memory 时必须 prompt 必须含有memory.memory_key 对应的变量
|
||||
prompt = get_prompt_template("llm_chat", "with_history")
|
||||
chat_prompt = PromptTemplate.from_template(prompt)
|
||||
# 根据conversation_id 获取message 列表进而拼凑 memory
|
||||
memory = ConversationBufferDBMemory(conversation_id=conversation_id,
|
||||
llm=model,
|
||||
elif conversation_id and history_len > 0:
|
||||
memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=models["llm_model"],
|
||||
message_limit=history_len)
|
||||
else:
|
||||
prompt_template = get_prompt_template("llm_chat", prompt_name)
|
||||
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
|
||||
input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages([input_msg])
|
||||
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
|
||||
|
||||
# Begin a task that runs in the background.
|
||||
task = asyncio.create_task(wrap_done(
|
||||
chain.acall({"input": query}),
|
||||
callback.done),
|
||||
chain = LLMChain(prompt=chat_prompt, llm=models["llm_model"], memory=memory)
|
||||
classifier_chain = (
|
||||
PromptTemplate.from_template(prompts["preprocess_model"])
|
||||
| models["preprocess_model"]
|
||||
| StrOutputParser()
|
||||
)
|
||||
|
||||
if stream:
|
||||
async for token in callback.aiter():
|
||||
# Use server-sent-events to stream the response
|
||||
yield json.dumps(
|
||||
{"text": token, "message_id": message_id},
|
||||
ensure_ascii=False)
|
||||
if "chatglm3" in models["action_model"].model_name.lower():
|
||||
agent_executor = initialize_glm3_agent(
|
||||
llm=models["action_model"],
|
||||
tools=tools,
|
||||
prompt=prompts["action_model"],
|
||||
input_variables=["input", "intermediate_steps", "history"],
|
||||
memory=memory,
|
||||
callback_manager=BaseCallbackManager(handlers=callbacks),
|
||||
verbose=True,
|
||||
)
|
||||
else:
|
||||
answer = ""
|
||||
async for token in callback.aiter():
|
||||
answer += token
|
||||
yield json.dumps(
|
||||
{"text": answer, "message_id": message_id},
|
||||
ensure_ascii=False)
|
||||
|
||||
agent_executor = initialize_agent(
|
||||
llm=models["action_model"],
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
|
||||
memory=memory,
|
||||
verbose=True,
|
||||
)
|
||||
branch = RunnableBranch(
|
||||
(lambda x: "1" in x["topic"].lower(), agent_executor),
|
||||
chain
|
||||
)
|
||||
full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
|
||||
task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
|
||||
if stream:
|
||||
async for chunk in callback.aiter():
|
||||
data = json.loads(chunk)
|
||||
if data["status"] == Status.start:
|
||||
continue
|
||||
elif data["status"] == Status.agent_action:
|
||||
tool_info = {
|
||||
"tool_name": data["tool_name"],
|
||||
"tool_input": data["tool_input"]
|
||||
}
|
||||
yield json.dumps({"agent_action": tool_info, "message_id": message_id}, ensure_ascii=False)
|
||||
elif data["status"] == Status.agent_finish:
|
||||
yield json.dumps({"agent_finish": data["agent_finish"], "message_id": message_id},
|
||||
ensure_ascii=False)
|
||||
else:
|
||||
yield json.dumps({"text": data["llm_token"], "message_id": message_id}, ensure_ascii=False)
|
||||
else:
|
||||
text = ""
|
||||
agent_finish = ""
|
||||
tool_info = None
|
||||
async for chunk in callback.aiter():
|
||||
# Use server-sent-events to stream the response
|
||||
data = json.loads(chunk)
|
||||
if data["status"] == Status.agent_action:
|
||||
tool_info = {
|
||||
"tool_name": data["tool_name"],
|
||||
"tool_input": data["tool_input"]
|
||||
}
|
||||
if data["status"] == Status.agent_finish:
|
||||
agent_finish = data["agent_finish"]
|
||||
else:
|
||||
text += data["llm_token"]
|
||||
if tool_info:
|
||||
yield json.dumps(
|
||||
{"text": text, "agent_action": tool_info, "agent_finish": agent_finish, "message_id": message_id},
|
||||
ensure_ascii=False)
|
||||
else:
|
||||
yield json.dumps(
|
||||
{"text": text, "message_id": message_id},
|
||||
ensure_ascii=False)
|
||||
await task
|
||||
|
||||
return EventSourceResponse(chat_iterator())
|
||||
|
||||
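With agent_chat and search_engine_chat removed, this single chat endpoint carries both plain LLM conversation and tool calling, and the caller now supplies model_config and tool_config in the request body. A hedged sketch of a streaming client follows; the /chat/chat path and the 7861 port are assumptions about the default deployment, and the payload keys mirror the Body(...) parameters above.

# Illustrative sketch, not part of this commit.
import json
import requests

payload = {
    "query": "北京今天天气怎么样?",
    "stream": True,
    "history": [],
    # Same role -> model -> params shape consumed by create_models_from_config();
    # "my-chat-model" is a placeholder for a model actually served by the project.
    "model_config": {
        "preprocess_model": {"my-chat-model": {"temperature": 0.1, "prompt_name": "default", "callbacks": False}},
        "llm_model": {"my-chat-model": {"temperature": 0.7, "prompt_name": "default", "callbacks": True}},
        "action_model": {"my-chat-model": {"temperature": 0.05, "prompt_name": "default", "callbacks": True}},
    },
    # Keys are tool names from all_tools; only these tools are handed to the agent.
    "tool_config": {"weather_check": {}, "search_internet": {}},
}

with requests.post("http://127.0.0.1:7861/chat/chat", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data:"):
            print(json.loads(line[len("data:"):].strip()))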
@ -1,6 +1,6 @@
|
||||
from fastapi import Body
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from configs import LLM_MODELS, TEMPERATURE
|
||||
from fastapi.responses import StreamingResponse
|
||||
from configs import LLM_MODEL_CONFIG
|
||||
from server.utils import wrap_done, get_OpenAI
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
@ -14,8 +14,8 @@ from server.utils import get_prompt_template
|
||||
async def completion(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
echo: bool = Body(False, description="除了输出之外,还回显输入"),
|
||||
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
model_name: str = Body(None, description="LLM 模型名称。"),
|
||||
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(1024, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
|
||||
prompt_name: str = Body("default",
|
||||
@ -24,7 +24,7 @@ async def completion(query: str = Body(..., description="用户输入", examples
|
||||
|
||||
#todo 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理
|
||||
async def completion_iterator(query: str,
|
||||
model_name: str = LLM_MODELS[0],
|
||||
model_name: str = None,
|
||||
prompt_name: str = prompt_name,
|
||||
echo: bool = echo,
|
||||
) -> AsyncIterable[str]:
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from fastapi import Body, File, Form, UploadFile
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE,
|
||||
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
|
||||
from fastapi.responses import StreamingResponse
|
||||
from configs import (VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
|
||||
from server.utils import (wrap_done, get_ChatOpenAI,
|
||||
BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool)
|
||||
from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
|
||||
@ -15,8 +14,6 @@ from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _parse_files_in_thread(
|
||||
files: List[UploadFile],
|
||||
@ -102,8 +99,8 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
|
||||
"content": "虎头虎脑"}]]
|
||||
),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
model_name: str = Body(None, description="LLM 模型名称。"),
|
||||
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
):
|
||||
|
||||
@ -1,147 +0,0 @@
|
||||
from fastapi import Body, Request
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from configs import (LLM_MODELS,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
SCORE_THRESHOLD,
|
||||
TEMPERATURE,
|
||||
USE_RERANKER,
|
||||
RERANKER_MODEL,
|
||||
RERANKER_MAX_LENGTH,
|
||||
MODEL_PATH)
|
||||
from server.utils import wrap_done, get_ChatOpenAI
|
||||
from server.utils import BaseResponse, get_prompt_template
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from typing import AsyncIterable, List, Optional
|
||||
import asyncio
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from server.chat.utils import History
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
import json
|
||||
from urllib.parse import urlencode
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
from server.reranker.reranker import LangchainReranker
|
||||
from server.utils import embedding_device
|
||||
async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
|
||||
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||
score_threshold: float = Body(
|
||||
SCORE_THRESHOLD,
|
||||
description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右",
|
||||
ge=0,
|
||||
le=2
|
||||
),
|
||||
history: List[History] = Body(
|
||||
[],
|
||||
description="历史对话",
|
||||
examples=[[
|
||||
{"role": "user",
|
||||
"content": "我们来玩成语接龙,我先来,生龙活虎"},
|
||||
{"role": "assistant",
|
||||
"content": "虎头虎脑"}]]
|
||||
),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(
|
||||
None,
|
||||
description="限制LLM生成Token数量,默认None代表模型最大值"
|
||||
),
|
||||
prompt_name: str = Body(
|
||||
"default",
|
||||
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
|
||||
),
|
||||
request: Request = None,
|
||||
):
|
||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||||
|
||||
history = [History.from_data(h) for h in history]
|
||||
|
||||
async def knowledge_base_chat_iterator(
|
||||
query: str,
|
||||
top_k: int,
|
||||
history: Optional[List[History]],
|
||||
model_name: str = model_name,
|
||||
prompt_name: str = prompt_name,
|
||||
) -> AsyncIterable[str]:
|
||||
nonlocal max_tokens
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
if isinstance(max_tokens, int) and max_tokens <= 0:
|
||||
max_tokens = None
|
||||
|
||||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[callback],
|
||||
)
|
||||
docs = await run_in_threadpool(search_docs,
|
||||
query=query,
|
||||
knowledge_base_name=knowledge_base_name,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold)
|
||||
|
||||
# 加入reranker
|
||||
if USE_RERANKER:
|
||||
reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
|
||||
print("-----------------model path------------------")
|
||||
print(reranker_model_path)
|
||||
reranker_model = LangchainReranker(top_n=top_k,
|
||||
device=embedding_device(),
|
||||
max_length=RERANKER_MAX_LENGTH,
|
||||
model_name_or_path=reranker_model_path
|
||||
)
|
||||
print(docs)
|
||||
docs = reranker_model.compress_documents(documents=docs,
|
||||
query=query)
|
||||
print("---------after rerank------------------")
|
||||
print(docs)
|
||||
context = "\n".join([doc.page_content for doc in docs])
|
||||
|
||||
if len(docs) == 0: # 如果没有找到相关文档,使用empty模板
|
||||
prompt_template = get_prompt_template("knowledge_base_chat", "empty")
|
||||
else:
|
||||
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
|
||||
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_template() for i in history] + [input_msg])
|
||||
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
|
||||
# Begin a task that runs in the background.
|
||||
task = asyncio.create_task(wrap_done(
|
||||
chain.acall({"context": context, "question": query}),
|
||||
callback.done),
|
||||
)
|
||||
|
||||
source_documents = []
|
||||
for inum, doc in enumerate(docs):
|
||||
filename = doc.metadata.get("source")
|
||||
parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
|
||||
base_url = request.base_url
|
||||
url = f"{base_url}knowledge_base/download_doc?" + parameters
|
||||
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
|
||||
source_documents.append(text)
|
||||
|
||||
if len(source_documents) == 0: # 没有找到相关文档
|
||||
source_documents.append(f"<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>")
|
||||
|
||||
if stream:
|
||||
async for token in callback.aiter():
|
||||
# Use server-sent-events to stream the response
|
||||
yield json.dumps({"answer": token}, ensure_ascii=False)
|
||||
yield json.dumps({"docs": source_documents}, ensure_ascii=False)
|
||||
else:
|
||||
answer = ""
|
||||
async for token in callback.aiter():
|
||||
answer += token
|
||||
yield json.dumps({"answer": answer,
|
||||
"docs": source_documents},
|
||||
ensure_ascii=False)
|
||||
await task
|
||||
|
||||
return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name))
|
||||
|
||||
@ -1,208 +0,0 @@
|
||||
from langchain.utilities.bing_search import BingSearchAPIWrapper
|
||||
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
|
||||
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
|
||||
LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE, OVERLAP_SIZE)
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain.docstore.document import Document
|
||||
from fastapi import Body
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from sse_starlette import EventSourceResponse
|
||||
from server.utils import wrap_done, get_ChatOpenAI
|
||||
from server.utils import BaseResponse, get_prompt_template
|
||||
from server.chat.utils import History
|
||||
from typing import AsyncIterable
|
||||
import asyncio
|
||||
import json
|
||||
from typing import List, Optional, Dict
|
||||
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
|
||||
from markdownify import markdownify
|
||||
|
||||
|
||||
def bing_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
|
||||
if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
|
||||
return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
|
||||
"title": "env info is not found",
|
||||
"link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
|
||||
search = BingSearchAPIWrapper(bing_subscription_key=BING_SUBSCRIPTION_KEY,
|
||||
bing_search_url=BING_SEARCH_URL)
|
||||
return search.results(text, result_len)
|
||||
|
||||
|
||||
def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
|
||||
search = DuckDuckGoSearchAPIWrapper()
|
||||
return search.results(text, result_len)
|
||||
|
||||
|
||||
def metaphor_search(
|
||||
text: str,
|
||||
result_len: int = SEARCH_ENGINE_TOP_K,
|
||||
split_result: bool = False,
|
||||
chunk_size: int = 500,
|
||||
chunk_overlap: int = OVERLAP_SIZE,
|
||||
) -> List[Dict]:
|
||||
from metaphor_python import Metaphor
|
||||
|
||||
if not METAPHOR_API_KEY:
|
||||
return []
|
||||
|
||||
client = Metaphor(METAPHOR_API_KEY)
|
||||
search = client.search(text, num_results=result_len, use_autoprompt=True)
|
||||
contents = search.get_contents().contents
|
||||
for x in contents:
|
||||
x.extract = markdownify(x.extract)
|
||||
|
||||
# metaphor 返回的内容都是长文本,需要分词再检索
|
||||
if split_result:
|
||||
docs = [Document(page_content=x.extract,
|
||||
metadata={"link": x.url, "title": x.title})
|
||||
for x in contents]
|
||||
text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap)
|
||||
splitted_docs = text_splitter.split_documents(docs)
|
||||
|
||||
# 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档
|
||||
if len(splitted_docs) > result_len:
|
||||
normal = NormalizedLevenshtein()
|
||||
for x in splitted_docs:
|
||||
x.metadata["score"] = normal.similarity(text, x.page_content)
|
||||
splitted_docs.sort(key=lambda x: x.metadata["score"], reverse=True)
|
||||
splitted_docs = splitted_docs[:result_len]
|
||||
|
||||
docs = [{"snippet": x.page_content,
|
||||
"link": x.metadata["link"],
|
||||
"title": x.metadata["title"]}
|
||||
for x in splitted_docs]
|
||||
else:
|
||||
docs = [{"snippet": x.extract,
|
||||
"link": x.url,
|
||||
"title": x.title}
|
||||
for x in contents]
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
SEARCH_ENGINES = {"bing": bing_search,
|
||||
"duckduckgo": duckduckgo_search,
|
||||
"metaphor": metaphor_search,
|
||||
}
|
||||
|
||||
|
||||
def search_result2docs(search_results):
|
||||
docs = []
|
||||
for result in search_results:
|
||||
doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "",
|
||||
metadata={"source": result["link"] if "link" in result.keys() else "",
|
||||
"filename": result["title"] if "title" in result.keys() else ""})
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
|
||||
async def lookup_search_engine(
|
||||
query: str,
|
||||
search_engine_name: str,
|
||||
top_k: int = SEARCH_ENGINE_TOP_K,
|
||||
split_result: bool = False,
|
||||
):
|
||||
search_engine = SEARCH_ENGINES[search_engine_name]
|
||||
results = await run_in_threadpool(search_engine, query, result_len=top_k, split_result=split_result)
|
||||
docs = search_result2docs(results)
|
||||
return docs
|
||||
|
||||
async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
|
||||
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
|
||||
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
|
||||
history: List[History] = Body([],
|
||||
description="历史对话",
|
||||
examples=[[
|
||||
{"role": "user",
|
||||
"content": "我们来玩成语接龙,我先来,生龙活虎"},
|
||||
{"role": "assistant",
|
||||
"content": "虎头虎脑"}]]
|
||||
),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(None,
|
||||
description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
prompt_name: str = Body("default",
|
||||
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
split_result: bool = Body(False,
|
||||
description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)")
|
||||
):
|
||||
if search_engine_name not in SEARCH_ENGINES.keys():
|
||||
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
|
||||
|
||||
if search_engine_name == "bing" and not BING_SUBSCRIPTION_KEY:
|
||||
return BaseResponse(code=404, msg=f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`")
|
||||
|
||||
history = [History.from_data(h) for h in history]
|
||||
|
||||
async def search_engine_chat_iterator(query: str,
|
||||
search_engine_name: str,
|
||||
top_k: int,
|
||||
history: Optional[List[History]],
|
||||
model_name: str = LLM_MODELS[0],
|
||||
prompt_name: str = prompt_name,
|
||||
) -> AsyncIterable[str]:
|
||||
nonlocal max_tokens
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
if isinstance(max_tokens, int) and max_tokens <= 0:
|
||||
max_tokens = None
|
||||
|
||||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[callback],
|
||||
)
|
||||
|
||||
docs = await lookup_search_engine(query, search_engine_name, top_k, split_result=split_result)
|
||||
context = "\n".join([doc.page_content for doc in docs])
|
||||
|
||||
prompt_template = get_prompt_template("search_engine_chat", prompt_name)
|
||||
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_template() for i in history] + [input_msg])
|
||||
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
|
||||
# Begin a task that runs in the background.
|
||||
task = asyncio.create_task(wrap_done(
|
||||
chain.acall({"context": context, "question": query}),
|
||||
callback.done),
|
||||
)
|
||||
|
||||
source_documents = [
|
||||
f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
|
||||
for inum, doc in enumerate(docs)
|
||||
]
|
||||
|
||||
if len(source_documents) == 0: # 没有找到相关资料(不太可能)
|
||||
source_documents.append(f"""<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>""")
|
||||
|
||||
if stream:
|
||||
async for token in callback.aiter():
|
||||
# Use server-sent-events to stream the response
|
||||
yield json.dumps({"answer": token}, ensure_ascii=False)
|
||||
yield json.dumps({"docs": source_documents}, ensure_ascii=False)
|
||||
else:
|
||||
answer = ""
|
||||
async for token in callback.aiter():
|
||||
answer += token
|
||||
yield json.dumps({"answer": answer,
|
||||
"docs": source_documents},
|
||||
ensure_ascii=False)
|
||||
await task
|
||||
|
||||
return EventSourceResponse(search_engine_chat_iterator(query=query,
|
||||
search_engine_name=search_engine_name,
|
||||
top_k=top_k,
|
||||
history=history,
|
||||
model_name=model_name,
|
||||
prompt_name=prompt_name),
|
||||
)
|
||||
@ -9,7 +9,6 @@ class ConversationModel(Base):
|
||||
__tablename__ = 'conversation'
|
||||
id = Column(String(32), primary_key=True, comment='对话框ID')
|
||||
name = Column(String(50), comment='对话框名称')
|
||||
# chat/agent_chat等
|
||||
chat_type = Column(String(50), comment='聊天类型')
|
||||
create_time = Column(DateTime, default=func.now(), comment='创建时间')
|
||||
|
||||
|
||||
@ -10,7 +10,6 @@ class MessageModel(Base):
|
||||
__tablename__ = 'message'
|
||||
id = Column(String(32), primary_key=True, comment='聊天记录ID')
|
||||
conversation_id = Column(String(32), default=None, index=True, comment='对话框ID')
|
||||
# chat/agent_chat等
|
||||
chat_type = Column(String(50), comment='聊天类型')
|
||||
query = Column(String(4096), comment='用户问题')
|
||||
response = Column(String(4096), comment='模型回答')
|
||||
|
||||
@ -10,7 +10,6 @@ from typing import List, Optional
|
||||
from server.knowledge_base.kb_summary.base import KBSummaryService
|
||||
from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
|
||||
from server.utils import wrap_done, get_ChatOpenAI, BaseResponse
|
||||
from configs import LLM_MODELS, TEMPERATURE
|
||||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
|
||||
def recreate_summary_vector_store(
|
||||
@ -19,8 +18,8 @@ def recreate_summary_vector_store(
|
||||
vs_type: str = Body(DEFAULT_VS_TYPE),
|
||||
embed_model: str = Body(EMBEDDING_MODEL),
|
||||
file_description: str = Body(''),
|
||||
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
model_name: str = Body(None, description="LLM 模型名称。"),
|
||||
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
):
|
||||
"""
|
||||
@ -100,8 +99,8 @@ def summary_file_to_vector_store(
|
||||
vs_type: str = Body(DEFAULT_VS_TYPE),
|
||||
embed_model: str = Body(EMBEDDING_MODEL),
|
||||
file_description: str = Body(''),
|
||||
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
model_name: str = Body(None, description="LLM 模型名称。"),
|
||||
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
):
|
||||
"""
|
||||
@ -172,8 +171,8 @@ def summary_doc_ids_to_vector_store(
|
||||
vs_type: str = Body(DEFAULT_VS_TYPE),
|
||||
embed_model: str = Body(EMBEDDING_MODEL),
|
||||
file_description: str = Body(''),
|
||||
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
model_name: str = Body(None, description="LLM 模型名称。"),
|
||||
temperature: float = Body(0.01, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
) -> BaseResponse:
|
||||
"""
|
||||
|
||||
@ -7,7 +7,6 @@ from configs import (
|
||||
logger,
|
||||
log_verbose,
|
||||
text_splitter_dict,
|
||||
LLM_MODELS,
|
||||
TEXT_SPLITTER_NAME,
|
||||
)
|
||||
import importlib
|
||||
@ -187,10 +186,10 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
|
||||
|
||||
|
||||
def make_text_splitter(
|
||||
splitter_name: str = TEXT_SPLITTER_NAME,
|
||||
chunk_size: int = CHUNK_SIZE,
|
||||
chunk_overlap: int = OVERLAP_SIZE,
|
||||
llm_model: str = LLM_MODELS[0],
|
||||
splitter_name,
|
||||
chunk_size,
|
||||
chunk_overlap,
|
||||
llm_model,
|
||||
):
|
||||
"""
|
||||
根据参数获取特定的分词器
|
||||
|
||||
@ -1,5 +1,5 @@
from fastapi import Body
from configs import logger, log_verbose, LLM_MODELS, HTTPX_DEFAULT_TIMEOUT
from configs import logger, log_verbose, HTTPX_DEFAULT_TIMEOUT
from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
get_httpx_client, get_model_worker_config)
from typing import List

@ -62,7 +62,7 @@ def get_model_config(

def stop_llm_model(
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODELS[0]]),
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''

@ -86,8 +86,8 @@ def stop_llm_model(

def change_llm_model(
model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODELS[0]]),
new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODELS[0]]),
model_name: str = Body(..., description="当前运行模型"),
new_model_name: str = Body(..., description="要切换的新模型"),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
):
'''

@ -108,9 +108,3 @@ def change_llm_model(
return BaseResponse(
code=500,
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")

def list_search_engines() -> BaseResponse:
from server.chat.search_engine_chat import SEARCH_ENGINES

return BaseResponse(data=list(SEARCH_ENGINES))

@ -1,5 +1,5 @@
from fastchat.conversation import Conversation
from configs import LOG_PATH, TEMPERATURE
from configs import LOG_PATH
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.base_model_worker import BaseModelWorker

@ -63,7 +63,7 @@ class ApiModelParams(ApiConfigParams):
deployment_name: Optional[str] = None # for azure
resource_name: Optional[str] = None # for azure

temperature: float = TEMPERATURE
temperature: float = 0.9
max_tokens: Optional[int] = None
top_p: Optional[float] = 1.0

@ -1,11 +1,6 @@
import json
import sys

from fastchat.conversation import Conversation
from configs import TEMPERATURE
from http import HTTPStatus
from typing import List, Literal, Dict

from fastchat import conversation as conv
from server.model_workers.base import *
from server.model_workers.base import ApiEmbeddingsParams

@ -4,7 +4,7 @@ from typing import List
from fastapi import FastAPI
from pathlib import Path
import asyncio
from configs import (LLM_MODELS, LLM_DEVICE, EMBEDDING_DEVICE,
from configs import (LLM_MODEL_CONFIG, LLM_DEVICE, EMBEDDING_DEVICE,
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose,
FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
import os

@ -402,8 +402,8 @@ def fschat_controller_address() -> str:
return f"http://{host}:{port}"

def fschat_model_worker_address(model_name: str = LLM_MODELS[0]) -> str:
if model := get_model_worker_config(model_name):
def fschat_model_worker_address(model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model']))) -> str:
if model := get_model_worker_config(model_name): # TODO: depends fastchat
host = model["host"]
if host == "0.0.0.0":
host = "127.0.0.1"
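Note: the new default expression picks the first key of LLM_MODEL_CONFIG['llm_model'], relying on dict insertion order (guaranteed in Python 3.7+), and it is evaluated once at import time, so later changes to the dict do not affect the default. A small sketch of the same lookup against a placeholder config:

# --- illustrative sketch, not part of this commit ---
# Placeholder mapping; the real one comes from configs.model_config.
LLM_MODEL_CONFIG = {"llm_model": {"chatglm3-6b": {}, "some-online-api": {}}}

default_model = next(iter(LLM_MODEL_CONFIG["llm_model"]))  # first configured key
assert default_model == "chatglm3-6b"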
@ -443,7 +443,7 @@ def webui_address() -> str:
def get_prompt_template(type: str, name: str) -> Optional[str]:
'''
从prompt_config中加载模板内容
type: "llm_chat","agent_chat","knowledge_base_chat","search_engine_chat"的其中一种,如果有新功能,应该进行加入。
type: "llm_chat","knowledge_base_chat","search_engine_chat"的其中一种,如果有新功能,应该进行加入。
'''

from configs import prompt_config

@ -617,26 +617,6 @@ def get_server_configs() -> Dict:
'''
获取configs中的原始配置项,供前端使用
'''
from configs.kb_config import (
DEFAULT_KNOWLEDGE_BASE,
DEFAULT_SEARCH_ENGINE,
DEFAULT_VS_TYPE,
CHUNK_SIZE,
OVERLAP_SIZE,
SCORE_THRESHOLD,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
ZH_TITLE_ENHANCE,
text_splitter_dict,
TEXT_SPLITTER_NAME,
)
from configs.model_config import (
LLM_MODELS,
HISTORY_LEN,
TEMPERATURE,
)
from configs.prompt_config import PROMPT_TEMPLATES

_custom = {
"controller_address": fschat_controller_address(),
"openai_api_address": fschat_openai_api_address(),
startup.py
@ -8,6 +8,7 @@ from datetime import datetime
from pprint import pprint
from langchain_core._api import deprecated

# 设置numexpr最大线程数,默认为CPU核心数
try:
import numexpr

@ -21,7 +22,7 @@ from configs import (
LOG_PATH,
log_verbose,
logger,
LLM_MODELS,
LLM_MODEL_CONFIG,
EMBEDDING_MODEL,
TEXT_SPLITTER_NAME,
FSCHAT_CONTROLLER,

@ -32,13 +33,21 @@ from configs import (
HTTPX_DEFAULT_TIMEOUT,
)
from server.utils import (fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, get_httpx_client, get_model_worker_config,
fschat_openai_api_address, get_httpx_client,
get_model_worker_config,
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
from server.knowledge_base.migrate import create_tables
import argparse
from typing import List, Dict
from configs import VERSION

all_model_names = set()
for model_category in LLM_MODEL_CONFIG.values():
for model_name in model_category.keys():
if model_name not in all_model_names:
all_model_names.add(model_name)

all_model_names_list = list(all_model_names)

@deprecated(
since="0.3.0",
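Note: the loop above just flattens every model name that appears in any category of LLM_MODEL_CONFIG; since all_model_names is a set, the membership check is redundant. A behaviour-equivalent one-liner (a sketch; like the original, it does not preserve any particular ordering):

# --- illustrative sketch, not part of this commit ---
all_model_names_list = list({
    name
    for category in LLM_MODEL_CONFIG.values()  # each category is a {model_name: params} dict
    for name in category
})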
@ -109,9 +118,9 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
import fastchat.serve.vllm_worker
from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs

args.tokenizer = args.model_path
args.tokenizer = args.model_path # 如果tokenizer与model_path不一致在此处添加
args.tokenizer_mode = 'auto'
args.trust_remote_code = True
args.download_dir = None

@ -130,7 +139,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
args.conv_template = None
args.limit_worker_concurrency = 5
args.no_register = False
args.num_gpus = 1 # vllm worker的切分是tensor并行,这里填写显卡的数量
args.num_gpus = 4 # vllm worker的切分是tensor并行,这里填写显卡的数量
args.engine_use_ray = False
args.disable_log_requests = False
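Note: hard-coding args.num_gpus = 4 ties the vllm worker to a 4-GPU host. If the intent is simply "use whatever is available for tensor parallelism", one hedged alternative (assuming torch is already a dependency of this worker) would be to derive the count at runtime:

# --- illustrative sketch, not part of this commit ---
import torch

# Size the vllm tensor-parallel split from the visible GPUs instead of a fixed number.
args.num_gpus = max(1, torch.cuda.device_count())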
@ -365,7 +374,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):

def run_model_worker(
model_name: str = LLM_MODELS[0],
model_name: str = next(iter(LLM_MODEL_CONFIG['llm_model'])),
controller_address: str = "",
log_level: str = "INFO",
q: mp.Queue = None,

@ -502,7 +511,7 @@ def parse_args() -> argparse.ArgumentParser:
"--model-worker",
action="store_true",
help="run fastchat's model_worker server with specified model name. "
"specify --model-name if not using default LLM_MODELS",
"specify --model-name if not using default llm models",
dest="model_worker",
)
parser.add_argument(

@ -510,7 +519,7 @@ def parse_args() -> argparse.ArgumentParser:
"--model-name",
type=str,
nargs="+",
default=LLM_MODELS,
default=all_model_names_list,
help="specify model name for model worker. "
"add addition names with space seperated to start multiple model workers.",
dest="model_name",

@ -574,7 +583,7 @@ def dump_server_info(after_start=False, args=None):
print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
print("\n")

models = LLM_MODELS
models = list(LLM_MODEL_CONFIG['llm_model'].keys())
if args and args.model_name:
models = args.model_name

@ -769,17 +778,17 @@ async def start_main_server():
if p := processes.get("api"):
p.start()
p.name = f"{p.name} ({p.pid})"
api_started.wait()
api_started.wait() # 等待api.py启动完成

if p := processes.get("webui"):
p.start()
p.name = f"{p.name} ({p.pid})"
webui_started.wait()
webui_started.wait() # 等待webui.py启动完成

dump_server_info(after_start=True, args=args)

while True:
cmd = queue.get()
cmd = queue.get() # 收到切换模型的消息
e = manager.Event()
if isinstance(cmd, list):
model_name, cmd, new_model_name = cmd

@ -877,20 +886,5 @@ if __name__ == "__main__":
loop = asyncio.new_event_loop()

asyncio.set_event_loop(loop)

loop.run_until_complete(start_main_server())

# 服务启动后接口调用示例:
# import openai
# openai.api_key = "EMPTY" # Not support yet
# openai.api_base = "http://localhost:8888/v1"

# model = "chatglm3-6b"

# # create a chat completion
# completion = openai.ChatCompletion.create(
# model=model,
# messages=[{"role": "user", "content": "Hello! What is your name?"}]
# )
# # print the completion
# print(completion.choices[0].message.content)

@ -29,15 +29,7 @@ def test_server_configs():
assert len(configs) > 0

def test_list_search_engines():
engines = api.list_search_engines()
pprint(engines)

assert isinstance(engines, list)
assert len(engines) > 0

@pytest.mark.parametrize("type", ["llm_chat", "agent_chat"])
@pytest.mark.parametrize("type", ["llm_chat"])
def test_get_prompt_template(type):
print(f"prompt template for: {type}")
template = api.get_prompt_template(type=type)

@ -85,29 +85,3 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"):
assert "docs" in data and len(data["docs"]) > 0
assert response.status_code == 200

def test_search_engine_chat(api="/chat/search_engine_chat"):
global data

data["query"] = "室温超导最新进展是什么样?"

url = f"{api_base_url}{api}"
for se in ["bing", "duckduckgo"]:
data["search_engine_name"] = se
dump_input(data, api + f" by {se}")
response = requests.post(url, json=data, stream=True)
if se == "bing" and not BING_SUBSCRIPTION_KEY:
data = response.json()
assert data["code"] == 404
assert data["msg"] == f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`"

print("\n")
print("=" * 30 + api + f" by {se} output" + "="*30)
for line in response.iter_content(None, decode_unicode=True):
data = json.loads(line[6:])
if "answer" in data:
print(data["answer"], end="", flush=True)
assert "docs" in data and len(data["docs"]) > 0
pprint(data["docs"])
assert response.status_code == 200

@ -56,15 +56,4 @@ def test_embeddings(worker):
assert isinstance(embeddings, list) and len(embeddings) > 0
assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0
assert isinstance(embeddings[0][0], float)
print("向量长度:", len(embeddings[0]))

# @pytest.mark.parametrize("worker", workers)
# def test_completion(worker):
# params = ApiCompletionParams(prompt="五十六个民族")

# print(f"\completion with {worker} \n")

# worker_class = get_model_worker_config(worker)["worker_class"]
# resp = worker_class().do_completion(params)
# pprint(resp)
print("向量长度:", len(embeddings[0]))

@ -6,8 +6,7 @@ from datetime import datetime
import os
import re
import time
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES, LLM_MODELS,
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
from configs import (TOOL_CONFIG, LLM_MODEL_CONFIG)
from server.knowledge_base.utils import LOADER_DICT
import uuid
from typing import List, Dict

@ -55,6 +54,7 @@ def parse_command(text: str, modal: Modal) -> bool:
/new {session_name}。如果未提供名称,默认为“会话X”
/del {session_name}。如果未提供名称,在会话数量>1的情况下,删除当前会话。
/clear {session_name}。如果未提供名称,默认清除当前会话
/stop {session_name}。如果未提供名称,默认停止当前会话
/help。查看命令帮助
返回值:输入的是命令返回True,否则返回False
'''

@ -117,36 +117,38 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
st.write("\n\n".join(cmds))

with st.sidebar:
# 多会话
conv_names = list(st.session_state["conversation_ids"].keys())
index = 0

tools = list(TOOL_CONFIG.keys())
selected_tool_configs = {}

with st.expander("工具栏"):
for tool in tools:
is_selected = st.checkbox(tool, value=TOOL_CONFIG[tool]["use"], key=tool)
if is_selected:
selected_tool_configs[tool] = TOOL_CONFIG[tool]

if st.session_state.get("cur_conv_name") in conv_names:
index = conv_names.index(st.session_state.get("cur_conv_name"))
conversation_name = st.selectbox("当前会话:", conv_names, index=index)
conversation_name = st.selectbox("当前会话", conv_names, index=index)
chat_box.use_chat_name(conversation_name)
conversation_id = st.session_state["conversation_ids"][conversation_name]

def on_mode_change():
mode = st.session_state.dialogue_mode
text = f"已切换到 {mode} 模式。"
if mode == "知识库问答":
cur_kb = st.session_state.get("selected_kb")
if cur_kb:
text = f"{text} 当前知识库: `{cur_kb}`。"
st.toast(text)
# def on_mode_change():
# mode = st.session_state.dialogue_mode
# text = f"已切换到 {mode} 模式。"
# st.toast(text)

dialogue_modes = ["LLM 对话",
"知识库问答",
"文件对话",
"搜索引擎问答",
"自定义Agent问答",
]
dialogue_mode = st.selectbox("请选择对话模式:",
dialogue_modes,
index=0,
on_change=on_mode_change,
key="dialogue_mode",
)
# dialogue_modes = ["智能对话",
# "文件对话",
# ]
# dialogue_mode = st.selectbox("请选择对话模式:",
# dialogue_modes,
# index=0,
# on_change=on_mode_change,
# key="dialogue_mode",
# )

def on_llm_change():
if llm_model:
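Note: the checkbox loop above leaves selected_tool_configs holding the full config entry for every tool the user keeps ticked, with TOOL_CONFIG[tool]["use"] as the initial state. A sketch of the resulting shape against a made-up TOOL_CONFIG (the real keys and fields come from the project config; names here are placeholders):

# --- illustrative sketch, not part of this commit ---
TOOL_CONFIG = {  # hypothetical config, only to show the shape
    "search_internet": {"use": True, "top_k": 3},
    "calculate": {"use": False},
}

selected_tool_configs = {}
for tool in TOOL_CONFIG:
    if TOOL_CONFIG[tool]["use"]:  # same default the checkbox starts from
        selected_tool_configs[tool] = TOOL_CONFIG[tool]

# -> {"search_internet": {"use": True, "top_k": 3}}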
@ -164,7 +166,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
available_models = []
config_models = api.list_config_models()
if not is_lite:
for k, v in config_models.get("local", {}).items():
for k, v in config_models.get("local", {}).items(): # 列出配置了有效本地路径的模型
if (v.get("model_path_exists")
and k not in running_models):
available_models.append(k)

@ -177,103 +179,47 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
index = llm_models.index(cur_llm_model)
else:
index = 0
llm_model = st.selectbox("选择LLM模型:",
llm_model = st.selectbox("选择LLM模型",
llm_models,
index,
format_func=llm_model_format_func,
on_change=on_llm_change,
key="llm_model",
)
if (st.session_state.get("prev_llm_model") != llm_model
and not is_lite
and not llm_model in config_models.get("online", {})
and not llm_model in config_models.get("langchain", {})
and llm_model not in running_models):
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
prev_model = st.session_state.get("prev_llm_model")
r = api.change_llm_model(prev_model, llm_model)
if msg := check_error_msg(r):
st.error(msg)
elif msg := check_success_msg(r):
st.success(msg)
st.session_state["prev_llm_model"] = llm_model

index_prompt = {
"LLM 对话": "llm_chat",
"自定义Agent问答": "agent_chat",
"搜索引擎问答": "search_engine_chat",
"知识库问答": "knowledge_base_chat",
"文件对话": "knowledge_base_chat",
}
prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys())
prompt_template_name = prompt_templates_kb_list[0]
if "prompt_template_select" not in st.session_state:
st.session_state.prompt_template_select = prompt_templates_kb_list[0]

def prompt_change():
text = f"已切换为 {prompt_template_name} 模板。"
st.toast(text)
# 传入后端的内容
model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}

prompt_template_select = st.selectbox(
"请选择Prompt模板:",
prompt_templates_kb_list,
index=0,
on_change=prompt_change,
key="prompt_template_select",
)
prompt_template_name = st.session_state.prompt_template_select
temperature = st.slider("Temperature:", 0.0, 2.0, TEMPERATURE, 0.05)
history_len = st.number_input("历史对话轮数:", 0, 20, HISTORY_LEN)
for key in LLM_MODEL_CONFIG:
if key == 'llm_model':
continue
if LLM_MODEL_CONFIG[key]:
first_key = next(iter(LLM_MODEL_CONFIG[key]))
model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]

def on_kb_change():
st.toast(f"已加载知识库: {st.session_state.selected_kb}")
if llm_model is not None:
model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]

if dialogue_mode == "知识库问答":
with st.expander("知识库配置", True):
kb_list = api.list_knowledge_bases()
index = 0
if DEFAULT_KNOWLEDGE_BASE in kb_list:
index = kb_list.index(DEFAULT_KNOWLEDGE_BASE)
selected_kb = st.selectbox(
"请选择知识库:",
kb_list,
index=index,
on_change=on_kb_change,
key="selected_kb",
)
kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)

## Bge 模型会超过1
score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
elif dialogue_mode == "文件对话":
with st.expander("文件对话配置", True):
files = st.file_uploader("上传知识文件:",
[i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True,
)
kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)

## Bge 模型会超过1
score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
if st.button("开始上传", disabled=len(files) == 0):
st.session_state["file_chat_id"] = upload_temp_docs(files, api)
elif dialogue_mode == "搜索引擎问答":
search_engine_list = api.list_search_engines()
if DEFAULT_SEARCH_ENGINE in search_engine_list:
index = search_engine_list.index(DEFAULT_SEARCH_ENGINE)
else:
index = search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0
with st.expander("搜索引擎配置", True):
search_engine = st.selectbox(
label="请选择搜索引擎",
options=search_engine_list,
index=index,
)
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)
print(model_config)
files = st.file_uploader("上传附件",
type=[i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True)

# if dialogue_mode == "文件对话":
# with st.expander("文件对话配置", True):
# files = st.file_uploader("上传知识文件:",
# [i for ls in LOADER_DICT.values() for i in ls],
# accept_multiple_files=True,
# )
# kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
# score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
# if st.button("开始上传", disabled=len(files) == 0):
# st.session_state["file_chat_id"] = upload_temp_docs(files, api)
# Display chat messages from history on app rerun
chat_box.output_messages()

chat_box.output_messages()
chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "

def on_feedback(
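Note: after the two loops above, model_config carries exactly one entry per category: the first configured model of every category other than 'llm_model', plus the model picked in the sidebar under 'llm_model'. A sketch of the resulting structure (category, model names, and parameter values are placeholders, not the project's real config):

# --- illustrative sketch, not part of this commit ---
model_config = {
    "some_other_category": {                       # first entry of that category in LLM_MODEL_CONFIG
        "first-configured-model": {"temperature": 0.01, "history_len": 100},
    },
    "llm_model": {                                 # the model chosen in the selectbox
        "chatglm3-6b": {"temperature": 0.9, "history_len": 3},
    },
}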
@ -297,140 +243,76 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
if parse_command(text=prompt, modal=modal): # 用户输入自定义命令
st.rerun()
else:
history = get_messages_history(history_len)
history = get_messages_history(
model_config["llm_model"][next(iter(model_config["llm_model"]))]["history_len"])
chat_box.user_say(prompt)
if dialogue_mode == "LLM 对话":
chat_box.ai_say("正在思考...")
text = ""
message_id = ""
r = api.chat_chat(prompt,
history=history,
conversation_id=conversation_id,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature)
for t in r:
if error_msg := check_error_msg(t): # check whether error occured
st.error(error_msg)
break
text += t.get("text", "")
chat_box.update_msg(text)
message_id = t.get("message_id", "")

chat_box.ai_say("正在思考...")
text = ""
message_id = ""
element_index = 0
for d in api.chat_chat(query=prompt,
history=history,
model_config=model_config,
conversation_id=conversation_id,
tool_config=selected_tool_configs,
):
try:
d = json.loads(d)
except:
pass
message_id = d.get("message_id", "")
metadata = {
"message_id": message_id,
}
chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标
chat_box.show_feedback(**feedback_kwargs,
key=message_id,
on_submit=on_feedback,
kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})

elif dialogue_mode == "自定义Agent问答":
if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
chat_box.ai_say([
f"正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!</span>\n\n\n",
Markdown("...", in_expander=True, title="思考过程", state="complete"),

])
else:
chat_box.ai_say([
f"正在思考...",
Markdown("...", in_expander=True, title="思考过程", state="complete"),

])
text = ""
ans = ""
for d in api.agent_chat(prompt,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature,
):
try:
d = json.loads(d)
except:
pass
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
if chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, element_index=1)
if chunk := d.get("final_answer"):
ans += chunk
chat_box.update_msg(ans, element_index=0)
if chunk := d.get("tools"):
text += "\n\n".join(d.get("tools", []))
chat_box.update_msg(text, element_index=1)
chat_box.update_msg(ans, element_index=0, streaming=False)
chat_box.update_msg(text, element_index=1, streaming=False)
elif dialogue_mode == "知识库问答":
chat_box.ai_say([
f"正在查询知识库 `{selected_kb}` ...",
Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
])
text = ""
for d in api.knowledge_base_chat(prompt,
knowledge_base_name=selected_kb,
top_k=kb_top_k,
score_threshold=score_threshold,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
elif chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, element_index=0)
chat_box.update_msg(text, element_index=0, streaming=False)
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
elif dialogue_mode == "文件对话":
if st.session_state["file_chat_id"] is None:
st.error("请先上传文件再进行对话")
st.stop()
chat_box.ai_say([
f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
])
text = ""
for d in api.file_chat(prompt,
knowledge_id=st.session_state["file_chat_id"],
top_k=kb_top_k,
score_threshold=score_threshold,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
elif chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, element_index=0)
chat_box.update_msg(text, element_index=0, streaming=False)
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
elif dialogue_mode == "搜索引擎问答":
chat_box.ai_say([
f"正在执行 `{search_engine}` 搜索...",
Markdown("...", in_expander=True, title="网络搜索结果", state="complete"),
])
text = ""
for d in api.search_engine_chat(prompt,
search_engine_name=search_engine,
top_k=se_top_k,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature,
split_result=se_top_k > 1):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
elif chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, element_index=0)
chat_box.update_msg(text, element_index=0, streaming=False)
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
if error_msg := check_error_msg(d):
st.error(error_msg)
if chunk := d.get("agent_action"):
chat_box.insert_msg(Markdown("...", in_expander=True, title="Tools", state="complete"))
element_index = 1
formatted_data = {
"action": chunk["tool_name"],
"action_input": chunk["tool_input"]
}
formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
text += f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"
chat_box.update_msg(text, element_index=element_index, metadata=metadata)
if chunk := d.get("text"):
text += chunk
chat_box.update_msg(text, element_index=element_index, metadata=metadata)
if chunk := d.get("agent_finish"):
element_index = 0
text = chunk
chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
chat_box.show_feedback(**feedback_kwargs,
key=message_id,
on_submit=on_feedback,
kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})

# elif dialogue_mode == "文件对话":
# if st.session_state["file_chat_id"] is None:
# st.error("请先上传文件再进行对话")
# st.stop()
# chat_box.ai_say([
# f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
# Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
# ])
# text = ""
# for d in api.file_chat(prompt,
# knowledge_id=st.session_state["file_chat_id"],
# top_k=kb_top_k,
# score_threshold=score_threshold,
# history=history,
# model=llm_model,
# prompt_name=prompt_template_name,
# temperature=temperature):
# if error_msg := check_error_msg(d):
# st.error(error_msg)
# elif chunk := d.get("answer"):
# text += chunk
# chat_box.update_msg(text, element_index=0)
# chat_box.update_msg(text, element_index=0, streaming=False)
# chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
if st.session_state.get("need_rerun"):
st.session_state["need_rerun"] = False
st.rerun()
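Note: the new /chat/chat stream follows a small protocol: "agent_action" opens a tool call (with "tool_name" and "tool_input"), "text" carries incremental model output, and "agent_finish" carries the final answer. A minimal non-Streamlit consumer sketched against those field names (everything else, including the function itself, is illustrative):

# --- illustrative sketch, not part of this commit ---
import json

def consume_chat_stream(chunks):
    """Print a /chat/chat stream to the console; mirrors the webui handling above."""
    answer = ""
    for raw in chunks:                       # e.g. the generator from ApiRequest.chat_chat(...)
        d = raw if isinstance(raw, dict) else json.loads(raw)
        if action := d.get("agent_action"):  # a tool is being invoked
            print(f"[tool] {action['tool_name']} <- {action['tool_input']}")
        if piece := d.get("text"):           # incremental model output
            print(piece, end="", flush=True)
        if final := d.get("agent_finish"):   # final answer replaces the streamed text
            answer = final
    return answer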
@ -3,18 +3,15 @@

from typing import *
from pathlib import Path
# 此处导入的配置为发起请求(如WEBUI)机器上的配置,主要用于为前端设置默认值。分布式部署时可以与服务器上的不同
from configs import (
EMBEDDING_MODEL,
DEFAULT_VS_TYPE,
LLM_MODELS,
TEMPERATURE,
LLM_MODEL_CONFIG,
SCORE_THRESHOLD,
CHUNK_SIZE,
OVERLAP_SIZE,
ZH_TITLE_ENHANCE,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
HTTPX_DEFAULT_TIMEOUT,
logger, log_verbose,
)

@ -26,7 +23,6 @@ from io import BytesIO
from server.utils import set_httpx_config, api_address, get_httpx_client

from pprint import pprint
from langchain_core._api import deprecated

set_httpx_config()

@ -247,10 +243,6 @@ class ApiRequest:
response = self.post("/server/configs", **kwargs)
return self._get_response_value(response, as_json=True)

def list_search_engines(self, **kwargs) -> List:
response = self.post("/server/list_search_engines", **kwargs)
return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"])

def get_prompt_template(
self,
type: str = "llm_chat",

@ -272,10 +264,8 @@ class ApiRequest:
history_len: int = -1,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
model_config: Dict = None,
tool_config: Dict = None,
**kwargs,
):
'''

@ -287,10 +277,8 @@ class ApiRequest:
"history_len": history_len,
"history": history,
"stream": stream,
"model_name": model,
"temperature": temperature,
"max_tokens": max_tokens,
"prompt_name": prompt_name,
"model_config": model_config,
"tool_config": tool_config,
}

# print(f"received input message:")
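Note: on the client side, chat_chat now carries the per-category model settings and the selected tool configs instead of a single model name and temperature. A hedged usage sketch, assuming ApiRequest still accepts a base_url argument; the address, model name, and config values are placeholders:

# --- illustrative sketch, not part of this commit ---
from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")  # placeholder address; base_url is assumed

for chunk in api.chat_chat(
        query="你好",
        history=[],
        conversation_id=None,
        model_config={"llm_model": {"chatglm3-6b": {"temperature": 0.9}}},  # illustrative
        tool_config={},  # no tools enabled
):
    print(chunk)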
@ -299,78 +287,6 @@ class ApiRequest:
response = self.post("/chat/chat", json=data, stream=True, **kwargs)
return self._httpx_stream2generator(response, as_json=True)

@deprecated(
since="0.3.0",
message="自定义Agent问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
removal="0.3.0")
def agent_chat(
self,
query: str,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
对应api.py/chat/agent_chat 接口
'''
data = {
"query": query,
"history": history,
"stream": stream,
"model_name": model,
"temperature": temperature,
"max_tokens": max_tokens,
"prompt_name": prompt_name,
}

# print(f"received input message:")
# pprint(data)

response = self.post("/chat/agent_chat", json=data, stream=True)
return self._httpx_stream2generator(response, as_json=True)

def knowledge_base_chat(
self,
query: str,
knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
对应api.py/chat/knowledge_base_chat接口
'''
data = {
"query": query,
"knowledge_base_name": knowledge_base_name,
"top_k": top_k,
"score_threshold": score_threshold,
"history": history,
"stream": stream,
"model_name": model,
"temperature": temperature,
"max_tokens": max_tokens,
"prompt_name": prompt_name,
}

# print(f"received input message:")
# pprint(data)

response = self.post(
"/chat/knowledge_base_chat",
json=data,
stream=True,
)
return self._httpx_stream2generator(response, as_json=True)

def upload_temp_docs(
self,
files: List[Union[str, Path, bytes]],

@ -416,8 +332,8 @@ class ApiRequest:
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
model: str = None,
temperature: float = 0.9,
max_tokens: int = None,
prompt_name: str = "default",
):

@ -444,50 +360,6 @@ class ApiRequest:
)
return self._httpx_stream2generator(response, as_json=True)

@deprecated(
since="0.3.0",
message="搜索引擎问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
removal="0.3.0"
)
def search_engine_chat(
self,
query: str,
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
split_result: bool = False,
):
'''
对应api.py/chat/search_engine_chat接口
'''
data = {
"query": query,
"search_engine_name": search_engine_name,
"top_k": top_k,
"history": history,
"stream": stream,
"model_name": model,
"temperature": temperature,
"max_tokens": max_tokens,
"prompt_name": prompt_name,
"split_result": split_result,
}

# print(f"received input message:")
# pprint(data)

response = self.post(
"/chat/search_engine_chat",
json=data,
stream=True,
)
return self._httpx_stream2generator(response, as_json=True)

# 知识库相关操作

def list_knowledge_bases(

@ -769,7 +641,7 @@ class ApiRequest:
def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]:
'''
从服务器上获取当前运行的LLM模型。
当 local_first=True 时,优先返回运行中的本地模型,否则优先按LLM_MODELS配置顺序返回。
当 local_first=True 时,优先返回运行中的本地模型,否则优先按 LLM_MODEL_CONFIG['llm_model']配置顺序返回。
返回类型为(model_name, is_local_model)
'''

@ -779,7 +651,7 @@ class ApiRequest:
return "", False

model = ""
for m in LLM_MODELS:
for m in LLM_MODEL_CONFIG['llm_model']:
if m not in running_models:
continue
is_local = not running_models[m].get("online_api")

@ -789,7 +661,7 @@ class ApiRequest:
model = m
break

if not model: # LLM_MODELS中配置的模型都不在running_models里
if not model: # LLM_MODEL_CONFIG['llm_model']中配置的模型都不在running_models里
model = list(running_models)[0]
is_local = not running_models[model].get("online_api")
return model, is_local

@ -800,7 +672,7 @@ class ApiRequest:
return "", False

model = ""
for m in LLM_MODELS:
for m in LLM_MODEL_CONFIG['llm_model']:
if m not in running_models:
continue
is_local = not running_models[m].get("online_api")

@ -810,7 +682,7 @@ class ApiRequest:
model = m
break

if not model: # LLM_MODELS中配置的模型都不在running_models里
if not model: # LLM_MODEL_CONFIG['llm_model']中配置的模型都不在running_models里
model = list(running_models)[0]
is_local = not running_models[model].get("online_api")
return model, is_local

@ -852,15 +724,6 @@ class ApiRequest:
)
return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

def list_search_engines(self) -> List[str]:
'''
获取服务器支持的搜索引擎
'''
response = self.post(
"/server/list_search_engines",
)
return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

def stop_llm_model(
self,
model_name: str,