Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
Synced 2026-02-06 06:49:48 +08:00

Beijing Hackathon update (#1785)

* Beijing Hackathon update
  Knowledge base support: the Zilliz vector database is now supported.
  Agent support: the following tool calls are supported:
  1. Internet search calls from the Agent
  2. Knowledge base calls from the Agent
  3. Travel assistant tool (not uploaded)
  Knowledge base updates:
  1. Each knowledge base can carry a short introduction, used by the Agent when choosing a base
  2. The UI displays the knowledge base introduction
  Prompt selection:
  1. The UI and the templates support switching between prompt templates

This commit is contained in:
parent 9ce328fea9
commit 69e5da4e7a
@@ -1,7 +1,7 @@
 import os


-# 默认向量库类型。可选:faiss, milvus, pg.
+# 默认向量库类型。可选:faiss, milvus(离线) & zilliz(在线), pg.
 DEFAULT_VS_TYPE = "faiss"

 # 缓存向量库数量(针对FAISS)
@@ -42,13 +42,17 @@ BING_SUBSCRIPTION_KEY = ""
 ZH_TITLE_ENHANCE = False


-# 通常情况下不需要更改以下内容
+# 每个知识库的初始化介绍,用于在初始化知识库时显示和Agent调用,没写则没有介绍,不会被Agent调用。
+KB_INFO = {
+    "知识库名称": "知识库介绍",
+    "samples": "关于本项目issue的解答",
+}
+
+# 通常情况下不需要更改以下内容
 # 知识库默认存储路径
 KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
 if not os.path.exists(KB_ROOT_PATH):
     os.mkdir(KB_ROOT_PATH)

 # 数据库默认存储路径。
 # 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。
 DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
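The KB_INFO introductions configured above are what the Agent uses to decide which knowledge base to query. A minimal sketch of that flow, reusing the get_kb_details() helper and the model_container global introduced elsewhere in this commit (the same lines appear in the agent_chat changes further down):

from server.agent import model_container
from server.knowledge_base.kb_service.base import get_kb_details

# Collect every knowledge base together with its kb_info introduction and expose
# the {kb_name: kb_info} mapping to the Agent tools through the global container;
# the knowledge-search chain later formats this mapping into the {database_names}
# slot of its prompt so the LLM can pick a suitable base.
kb_list = {x["kb_name"]: x for x in get_kb_details()}
model_container.DATABASE = {name: details["kb_info"] for name, details in kb_list.items()}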
@@ -65,6 +69,13 @@ kbs_config = {
         "password": "",
         "secure": False,
     },
+    "zilliz": {
+        "host": "in01-a7ce524e41e3935.ali-cn-hangzhou.vectordb.zilliz.com.cn",
+        "port": "19530",
+        "user": "",
+        "password": "",
+        "secure": True,
+    },
     "pg": {
         "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
     }
@@ -74,11 +85,11 @@ kbs_config = {
 text_splitter_dict = {
     "ChineseRecursiveTextSplitter": {
         "source": "huggingface",   ## 选择tiktoken则使用openai的方法
-        "tokenizer_name_or_path": "gpt2",
+        "tokenizer_name_or_path": "",
     },
     "SpacyTextSplitter": {
         "source": "huggingface",
-        "tokenizer_name_or_path": "",
+        "tokenizer_name_or_path": "gpt2",
     },
     "RecursiveCharacterTextSplitter": {
         "source": "tiktoken",
@@ -9,98 +9,106 @@
 # - context: 从检索结果拼接的知识文本
 # - question: 用户提出的问题

-PROMPT_TEMPLATES = {
-# LLM对话模板
-"llm_chat": "{{ input }}",
-
-# 基于本地知识问答的提示词模板
-"knowledge_base_chat":
-"""
-<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
-<已知信息>{{ context }}</已知信息>、
-<问题>{{ question }}</问题>""",
-
-# 基于agent的提示词模板
-"agent_chat":
-"""
-Answer the following questions as best you can. You have access to the following tools:
-
-{tools}
-Use the following format:
-
-Question: the input question you must answer
-Thought: you should always think about what to do
-Action: the action to take, should be one of [{tool_names}]
-Action Input: the input to the action
-Observation: the result of the action
-... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
-Thought: I now know the final answer
-Final Answer: the final answer to the original input question
-
-Begin!
-
-history:
-{history}
-
-Question: {input}
-Thought: {agent_scratchpad}
-"""
-}
-
-## GPT或Qwen 的Prompt
-# """
-# Answer the following questions as best you can. You have access to the following tools:
-#
-# {tools}
-#
-# Please note that the "知识库查询工具" is information about the "西交利物浦大学" ,and if a question is asked about it, you must answer with the knowledge base
-#
-# Use the following format:
-#
-# Question: the input question you must answer
-# Thought: you should always think about what to do
-# Action: the action to take, should be one of [{tool_names}]
-# Action Input: the input to the action
-# Observation: the result of the action
-# ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
-# Thought: I now know the final answer
-# Final Answer: the final answer to the original input question
-#
-# Begin!
-#
-# history:
-# {history}
-#
-# Question: {input}
-# Thought: {agent_scratchpad}
-# """
-
-## ChatGLM-Pro的Prompt
-# """
-# 请请严格按照提供的思维方式来思考。你的知识不一定正确,所以你一定要用提供的工具来思考,并给出用户答案。
-# 你有以下工具可以使用:
-# {tools}
-# ```
-# Question: 用户的提问或者观察到的信息,
-# Thought: 你应该思考该做什么,是根据工具的结果来回答问题,还是决定使用什么工具。
-# Action: 需要使用的工具,应该是在[{tool_names}]中的一个。
-# Action Input: 传入工具的内容
-# Observation: 工具给出的答案(不是你生成的)
-# ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
-# Thought: 通过工具给出的答案,你是否能回答Question。
-# Final Answer是你的答案
-#
-# 现在,我们开始!
-# 你和用户的历史记录:
-# History:
-# {history}
-#
-# 用户开始以提问:
-# Question: {input}
-# Thought: {agent_scratchpad}
-#
-# """
+# Agent对话支持的变量:
+# - tools: 可用的工具列表
+# - tool_names: 可用的工具名称列表
+# - history: 用户和Agent的对话历史
+# - input: 用户输入内容
+# - agent_scratchpad: Agent的思维记录
+
+PROMPT_TEMPLATES = {}
+
+PROMPT_TEMPLATES["llm_chat"] = {
+    "default": "{{ input }}",
+
+    "py":
+"""
+你是一个聪明的代码助手,请你给我写出简单的py代码。 \n
+{{ input }}
+"""
+,
+}
+
+PROMPT_TEMPLATES["knowledge_base_chat"] = {
+    "default":
+"""
+<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
+<已知信息>{{ context }}</已知信息>、
+<问题>{{ question }}</问题>
+""",
+    "text":
+"""
+<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>
+<已知信息>{{ context }}</已知信息>、
+<问题>{{ question }}</问题>
+""",
+}
+
+PROMPT_TEMPLATES["search_engine_chat"] = {
+    "default":
+"""
+<指令>这是我搜索到的互联网信息,请你根据这些信息进行提取并有条理,简洁地回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 </指令>
+<已知信息>{{ context }}</已知信息>、
+<问题>{{ question }}</问题>
+""",
+    "search":
+"""
+<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,答案请使用中文。 </指令>
+<已知信息>{{ context }}</已知信息>、
+<问题>{{ question }}</问题>
+""",
+}
+
+PROMPT_TEMPLATES["agent_chat"] = {
+    "default":
+"""
+Answer the following questions as best you can. If it is in order, you can use some tools appropriately. You have access to the following tools:
+
+{tools}
+
+Please note that the "知识库查询工具" is information about the "西交利物浦大学" ,and if a question is asked about it, you must answer with the knowledge base,
+Please note that the "天气查询工具" can only be used once since Question begin.
+
+Use the following format:
+Question: the input question you must answer
+Thought: you should always think about what to do and what tools to use.
+Action: the action to take, should be one of [{tool_names}]
+Action Input: the input to the action
+Observation: the result of the action
+... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
+Thought: I now know the final answer
+Final Answer: the final answer to the original input question
+
+Begin!
+history:
+{history}
+Question: {input}
+Thought: {agent_scratchpad}
+""",
+    "ChatGLM":
+"""
+请严格按照提供的思维方式来思考。你的知识不一定正确,所以你一定要用提供的工具来思考,并给出用户答案。
+你有以下工具可以使用:
+{tools}
+```
+Question: 用户的提问或者观察到的信息,
+Thought: 你应该思考该做什么,是根据工具的结果来回答问题,还是决定使用什么工具。
+Action: 需要使用的工具,应该是在[{tool_names}]中的一个。
+Action Input: 传入工具的内容
+Observation: 工具给出的答案(不是你生成的)
+... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
+Thought: 通过工具给出的答案,你是否能回答Question。
+Final Answer是你的答案
+
+现在,我们开始!
+你和用户的历史记录:
+History:
+{history}
+
+用户开始提问:
+Question: {input}
+Thought: {agent_scratchpad}
+""",
+}
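PROMPT_TEMPLATES is now a two-level mapping: the first key is the chat mode (llm_chat, knowledge_base_chat, search_engine_chat, agent_chat) and the second key is the template variant that the UI lets the user switch between. A minimal lookup sketch; load_template is an illustrative helper, and the project's own get_prompt_template (in server/utils.py, not shown in this diff) is only assumed to behave roughly like this:

from configs.prompt_config import PROMPT_TEMPLATES

def load_template(chat_type: str, name: str = "default") -> str:
    # Two-level lookup: chat mode first, then the variant chosen in the UI
    # (the prompt_name request parameter); fall back to "default".
    variants = PROMPT_TEMPLATES[chat_type]          # e.g. "agent_chat"
    return variants.get(name, variants["default"])  # e.g. "ChatGLM"

# Example: the template agent_chat would use when the UI selects the ChatGLM variant.
agent_template = load_template("agent_chat", "ChatGLM")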
Binary file not shown.
Binary file not shown.
@@ -1,5 +1,5 @@
-langchain==0.0.313
+langchain>=0.0.314
-langchain-experimental==0.0.30
+langchain-experimental>=0.0.30
 fschat[model_worker]==0.2.30
 openai
 sentence_transformers
server/agent/__init__.py (Normal file, 4 lines added)
@@ -0,0 +1,4 @@
+from .model_contain import *
+from .callbacks import *
+from .custom_template import *
+from .tools import *
@@ -29,6 +29,7 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         self.queue = asyncio.Queue()
         self.done = asyncio.Event()
         self.cur_tool = {}
+        self.out = True

     async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
                             parent_run_id: UUID | None = None, tags: List[str] | None = None,
@@ -57,6 +58,7 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):

     async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
                           tags: List[str] | None = None, **kwargs: Any) -> None:
+        self.out = True  ## 重置输出
         self.cur_tool.update(
             status=Status.tool_finish,
             output_str=output.replace("Answer:", ""),
@@ -72,7 +74,17 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         self.queue.put_nowait(dumps(self.cur_tool))

     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        if token:
+        if "Action" in token:  ## 减少重复输出
+            before_action = token.split("Action")[0]
+            self.cur_tool.update(
+                status=Status.running,
+                llm_token=before_action + "\n",
+            )
+            self.queue.put_nowait(dumps(self.cur_tool))
+
+            self.out = False
+
+        if token and self.out:
             self.cur_tool.update(
                 status=Status.running,
                 llm_token=token,
@@ -86,6 +98,14 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
             )
             self.queue.put_nowait(dumps(self.cur_tool))

+    async def on_chat_model_start(self, serialized: Dict[str, Any], **kwargs: Any,
+                                  ) -> None:
+        self.cur_tool.update(
+            status=Status.start,
+            llm_token="",
+        )
+        self.queue.put_nowait(dumps(self.cur_tool))
+
     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         self.cur_tool.update(
             status=Status.complete,
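The handler above serializes every tool/LLM event with dumps(self.cur_tool) and pushes it onto its queue; the agent endpoint drains that queue as a stream. A minimal consumption sketch, mirroring the loop that agent_chat uses further down in this commit:

import json
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status

async def drain(callback: CustomAsyncIteratorCallbackHandler):
    # Each chunk is one JSON string produced by dumps(self.cur_tool) above.
    async for chunk in callback.aiter():
        data = json.loads(chunk)
        if data["status"] == Status.tool_finish:
            print("tool:", data["tool_name"], "->", data["output_str"])
        elif data["status"] == Status.agent_finish:
            print("answer:", data["final_answer"])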
@@ -1,11 +1,9 @@
 from __future__ import annotations
 from langchain.agents import Tool, AgentOutputParser
 from langchain.prompts import StringPromptTemplate
-from typing import List, Union, Tuple, Dict
+from typing import List
 from langchain.schema import AgentAction, AgentFinish
-import re
+from server.agent import model_container
-from configs.model_config import LLM_MODEL, TEMPERATURE, HISTORY_LEN

 begin = False
 class CustomPromptTemplate(StringPromptTemplate):
     # The template to use
@@ -41,7 +39,7 @@ class CustomOutputParser(AgentOutputParser):
     def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
         # Check if agent should finish
         support_agent = ["gpt", "Qwen", "qwen-api", "baichuan-api"]
-        if not any(agent in LLM_MODEL for agent in support_agent) and self.begin:
+        if not any(agent in model_container.MODEL for agent in support_agent) and self.begin:
             self.begin = False
             stop_words = ["Observation:"]
             min_index = len(llm_output)
server/agent/model_contain.py (Normal file, 8 lines added)
@@ -0,0 +1,8 @@
+
+## 由于工具类无法传参,所以使用全局变量来传递模型和对应的知识库介绍
+class ModelContainer:
+    def __init__(self):
+        self.MODEL = None
+        self.DATABASE = None
+
+model_container = ModelContainer()
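The module above exists because the LangChain Tool functions cannot easily receive per-request arguments, so the request-scoped model and the knowledge-base introductions are parked in a module-level singleton. A sketch of the two sides of that contract, both taken from other files touched by this commit (model and kb_list refer to the variables built inside the agent_chat endpoint):

from server.agent import model_container

# Writer side (the agent_chat endpoint): stash the live ChatOpenAI instance and the
# {kb_name: kb_info} mapping before the agent executor starts.
model_container.MODEL = model
model_container.DATABASE = {name: details["kb_info"] for name, details in kb_list.items()}

# Reader side (e.g. the calculate tool): tools pull the shared model instead of
# constructing their own.
model = model_container.MODEL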
@@ -1,37 +0,0 @@
-## 单独运行的时候需要添加
-import sys
-import os
-import json
-
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
-from server.chat.knowledge_base_chat import knowledge_base_chat
-from configs import LLM_MODEL, TEMPERATURE, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD
-
-import asyncio
-
-
-async def search_knowledge_base_iter(query: str):
-    response = await knowledge_base_chat(query=query,
-                                         knowledge_base_name="tcqa",
-                                         model_name=LLM_MODEL,
-                                         temperature=TEMPERATURE,
-                                         history=[],
-                                         top_k = VECTOR_SEARCH_TOP_K,
-                                         prompt_name = "knowledge_base_chat",
-                                         score_threshold = SCORE_THRESHOLD,
-                                         stream=False)
-
-    contents = ""
-    async for data in response.body_iterator:  # 这里的data是一个json字符串
-        data = json.loads(data)
-        contents = data["answer"]
-        docs = data["docs"]
-    return contents
-
-def search_knowledge(query: str):
-    return asyncio.run(search_knowledge_base_iter(query))
-
-
-if __name__ == "__main__":
-    result = search_knowledge("大数据男女比例")
-    print("答案:", result)
server/agent/tools/__init__.py (Normal file, 11 lines added)
@@ -0,0 +1,11 @@
+## 导入所有的工具类
+from .search_knowledge_simple import knowledge_search_simple
+from .search_all_knowledge_once import knowledge_search_once
+from .search_all_knowledge_more import knowledge_search_more
+from .travel_assistant import travel_assistant
+from .calculate import calculate
+from .translator import translate
+from .weather import weathercheck
+from .shell import shell
+from .search_internet import search_internet
+
@@ -1,12 +1,12 @@
 ## 单独运行的时候需要添加
 import sys
 import os
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))

 from langchain.prompts import PromptTemplate
 from langchain.chains import LLMMathChain
-from server.utils import get_ChatOpenAI
+from server.agent import model_container
-from configs.model_config import LLM_MODEL, TEMPERATURE
 _PROMPT_TEMPLATE = """
 将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。
 问题: ${{包含数学问题的问题。}}
@@ -63,11 +63,7 @@ PROMPT = PromptTemplate(


 def calculate(query: str):
-    model = get_ChatOpenAI(
-        streaming=False,
-        model_name=LLM_MODEL,
-        temperature=TEMPERATURE,
-    )
+    model = model_container.MODEL
     llm_math = LLMMathChain.from_llm(model, verbose=True, prompt=PROMPT)
     ans = llm_math.run(query)
     return ans
server/agent/tools/search_all_knowledge_more.py (Normal file, 296 lines added)
@@ -0,0 +1,296 @@
+## 单独运行的时候需要添加
+import sys
+import os
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
+
+import json
+import re
+import warnings
+from typing import Dict
+from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun
+from langchain.chains.llm import LLMChain
+from langchain.pydantic_v1 import Extra, root_validator
+from langchain.schema import BasePromptTemplate
+from langchain.schema.language_model import BaseLanguageModel
+from typing import List, Any, Optional
+from langchain.prompts import PromptTemplate
+from server.chat.knowledge_base_chat import knowledge_base_chat
+from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD
+import asyncio
+from server.agent import model_container
+
+
+async def search_knowledge_base_iter(database: str, query: str) -> str:
+    response = await knowledge_base_chat(query=query,
+                                         knowledge_base_name=database,
+                                         model_name=model_container.MODEL.model_name,
+                                         temperature=0.01,
+                                         history=[],
+                                         top_k=VECTOR_SEARCH_TOP_K,
+                                         max_tokens=None,
+                                         prompt_name="default",
+                                         score_threshold=SCORE_THRESHOLD,
+                                         stream=False)
+
+    contents = ""
+    async for data in response.body_iterator:  # 这里的data是一个json字符串
+        data = json.loads(data)
+        contents += data["answer"]
+        docs = data["docs"]
+    return contents
+
+
+async def search_knowledge_multiple(queries) -> List[str]:
+    # queries 应该是一个包含多个 (database, query) 元组的列表
+    tasks = [search_knowledge_base_iter(database, query) for database, query in queries]
+    results = await asyncio.gather(*tasks)
+    # 结合每个查询结果,并在每个查询结果前添加一个自定义的消息
+    combined_results = []
+    for (database, _), result in zip(queries, results):
+        message = f"\n查询到 {database} 知识库的相关信息:\n{result}"
+        combined_results.append(message)
+
+    return combined_results
+
+
+def search_knowledge(queries) -> str:
+    responses = asyncio.run(search_knowledge_multiple(queries))
+    # 输出每个整合的查询结果
+    contents = ""
+    for response in responses:
+        contents += response + "\n\n"
+    return contents
+
+
+_PROMPT_TEMPLATE = """
+用户会提出一个需要你查询知识库的问题,你应该对问题进行理解和拆解,并在知识库中查询相关的内容。
+
+对于每个知识库,你输出的内容应该是一个一行的字符串,这行字符串包含知识库名称和查询内容,中间用逗号隔开,不要有多余的文字和符号。你可以同时查询多个知识库,下面这个例子就是同时查询两个知识库的内容。
+
+例子:
+
+robotic,机器人男女比例是多少
+bigdata,大数据的就业情况如何
+
+
+这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能,你应该参考他们的功能来帮助你思考
+
+{database_names}
+
+你的回答格式应该按照下面的内容,请注意```text 等标记都必须输出,这是我用来提取答案的标记。
+
+
+Question: ${{用户的问题}}
+
+```text
+${{知识库名称,查询问题,不要带有任何除了,之外的符号}}
+
+```output
+数据库查询的结果
+
+
+
+这是一个完整的问题拆分和提问的例子:
+
+
+问题: 分别对比机器人和大数据专业的就业情况并告诉我哪个专业的就业情况更好?
+
+```text
+robotic,机器人专业的就业情况
+bigdata,大数据专业的就业情况
+
+
+
+现在,我们开始作答
+问题: {question}
+"""
+
+PROMPT = PromptTemplate(
+    input_variables=["question", "database_names"],
+    template=_PROMPT_TEMPLATE,
+)
+
+
+class LLMKnowledgeChain(LLMChain):
+    llm_chain: LLMChain
+    llm: Optional[BaseLanguageModel] = None
+    """[Deprecated] LLM wrapper to use."""
+    prompt: BasePromptTemplate = PROMPT
+    """[Deprecated] Prompt to use to translate to python if necessary."""
+    database_names: Dict[str, str] = None
+    input_key: str = "question"  #: :meta private:
+    output_key: str = "answer"  #: :meta private:
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    @root_validator(pre=True)
+    def raise_deprecation(cls, values: Dict) -> Dict:
+        if "llm" in values:
+            warnings.warn(
+                "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
+                "Please instantiate with llm_chain argument or using the from_llm "
+                "class method."
+            )
+            if "llm_chain" not in values and values["llm"] is not None:
+                prompt = values.get("prompt", PROMPT)
+                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
+        return values
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Expect input key.
+
+        :meta private:
+        """
+        return [self.input_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Expect output key.
+
+        :meta private:
+        """
+        return [self.output_key]
+
+    def _evaluate_expression(self, queries) -> str:
+        try:
+            output = search_knowledge(queries)
+        except Exception as e:
+            output = "输入的信息有误或不存在知识库,错误信息如下:\n"
+            return output + str(e)
+        return output
+
+    def _process_llm_result(
+            self,
+            llm_output: str,
+            run_manager: CallbackManagerForChainRun
+    ) -> Dict[str, str]:
+
+        run_manager.on_text(llm_output, color="green", verbose=self.verbose)
+
+        llm_output = llm_output.strip()
+        # text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
+        text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
+        if text_match:
+            expression = text_match.group(1).strip()
+            cleaned_input_str = (expression.replace("\"", "").replace("“", "").
+                                 replace("”", "").replace("```", "").strip())
+            lines = cleaned_input_str.split("\n")
+            # 使用逗号分割每一行,然后形成一个(数据库,查询)元组的列表
+
+            queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
+            run_manager.on_text("知识库查询内容:\n\n" + str(queries) + " \n\n", color="blue", verbose=self.verbose)
+            output = self._evaluate_expression(queries)
+            run_manager.on_text("\nAnswer: ", verbose=self.verbose)
+            run_manager.on_text(output, color="yellow", verbose=self.verbose)
+            answer = "Answer: " + output
+        elif llm_output.startswith("Answer:"):
+            answer = llm_output
+        elif "Answer:" in llm_output:
+            answer = llm_output.split("Answer:")[-1]
+        else:
+            return {self.output_key: f"输入的格式不对:\n {llm_output}"}
+        return {self.output_key: answer}
+
+    async def _aprocess_llm_result(
+            self,
+            llm_output: str,
+            run_manager: AsyncCallbackManagerForChainRun,
+    ) -> Dict[str, str]:
+        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
+        llm_output = llm_output.strip()
+        text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
+        if text_match:
+
+            expression = text_match.group(1).strip()
+            cleaned_input_str = (
+                expression.replace("\"", "").replace("“", "").replace("”", "").replace("```", "").strip())
+            lines = cleaned_input_str.split("\n")
+            queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
+            await run_manager.on_text("知识库查询内容:\n\n" + str(queries) + " \n\n", color="blue",
+                                      verbose=self.verbose)
+
+            output = self._evaluate_expression(queries)
+            await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
+            await run_manager.on_text(output, color="yellow", verbose=self.verbose)
+            answer = "Answer: " + output
+        elif llm_output.startswith("Answer:"):
+            answer = llm_output
+        elif "Answer:" in llm_output:
+            answer = "Answer: " + llm_output.split("Answer:")[-1]
+        else:
+            raise ValueError(f"unknown format from LLM: {llm_output}")
+        return {self.output_key: answer}
+
+    def _call(
+            self,
+            inputs: Dict[str, str],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
+        _run_manager.on_text(inputs[self.input_key])
+        self.database_names = model_container.DATABASE
+        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
+        llm_output = self.llm_chain.predict(
+            database_names=data_formatted_str,
+            question=inputs[self.input_key],
+            stop=["```output"],
+            callbacks=_run_manager.get_child(),
+        )
+        return self._process_llm_result(llm_output, _run_manager)
+
+    async def _acall(
+            self,
+            inputs: Dict[str, str],
+            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
+        await _run_manager.on_text(inputs[self.input_key])
+        self.database_names = model_container.DATABASE
+        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
+        llm_output = await self.llm_chain.apredict(
+            database_names=data_formatted_str,
+            question=inputs[self.input_key],
+            stop=["```output"],
+            callbacks=_run_manager.get_child(),
+        )
+        return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)
+
+    @property
+    def _chain_type(self) -> str:
+        return "llm_knowledge_chain"
+
+    @classmethod
+    def from_llm(
+            cls,
+            llm: BaseLanguageModel,
+            prompt: BasePromptTemplate = PROMPT,
+            **kwargs: Any,
+    ):
+        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        return cls(llm_chain=llm_chain, **kwargs)
+
+
+def knowledge_search_more(query: str):
+    model = model_container.MODEL
+    llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT)
+    ans = llm_knowledge.run(query)
+    return ans
+
+
+if __name__ == "__main__":
+    result = knowledge_search_more("机器人和大数据在代码教学上有什么区别")
+    print(result)
+
+    # 这是一个正常的切割
+    # queries = [
+    #     ("bigdata", "大数据专业的男女比例"),
+    #     ("robotic", "机器人专业的优势")
+    # ]
+    # result = search_knowledge(queries)
+    # print(result)
server/agent/tools/search_all_knowledge_once.py (Normal file, 234 lines added)
@@ -0,0 +1,234 @@
+## 单独运行的时候需要添加
+# import sys
+# import os
+# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
+
+import re
+import warnings
+from typing import Dict
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForChainRun,
+    CallbackManagerForChainRun,
+)
+from langchain.chains.llm import LLMChain
+from langchain.pydantic_v1 import Extra, root_validator
+from langchain.schema import BasePromptTemplate
+from langchain.schema.language_model import BaseLanguageModel
+from typing import List, Any, Optional
+from langchain.prompts import PromptTemplate
+import sys
+import os
+import json
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+from server.chat.knowledge_base_chat import knowledge_base_chat
+from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD
+
+import asyncio
+from server.agent import model_container
+
+
+async def search_knowledge_base_iter(database: str, query: str):
+    response = await knowledge_base_chat(query=query,
+                                         knowledge_base_name=database,
+                                         model_name=model_container.MODEL.model_name,
+                                         temperature=0.01,
+                                         history=[],
+                                         top_k=VECTOR_SEARCH_TOP_K,
+                                         max_tokens=None,
+                                         prompt_name="knowledge_base_chat",
+                                         score_threshold=SCORE_THRESHOLD,
+                                         stream=False)
+
+    contents = ""
+    async for data in response.body_iterator:  # 这里的data是一个json字符串
+        data = json.loads(data)
+        contents += data["answer"]
+        docs = data["docs"]
+    return contents
+
+
+_PROMPT_TEMPLATE = """
+用户会提出一个需要你查询知识库的问题,你应该按照我提供的思想进行思考
+Question: ${{用户的问题}}
+这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能:
+
+{database_names}
+
+你的回答格式应该按照下面的内容,请注意,格式内的```text 等标记都必须输出,这是我用来提取答案的标记。
+```text
+${{知识库的名称}}
+```
+```output
+数据库查询的结果
+```
+答案: ${{答案}}
+
+现在,这是我的问题:
+问题: {question}
+
+"""
+PROMPT = PromptTemplate(
+    input_variables=["question", "database_names"],
+    template=_PROMPT_TEMPLATE,
+)
+
+
+class LLMKnowledgeChain(LLMChain):
+    llm_chain: LLMChain
+    llm: Optional[BaseLanguageModel] = None
+    """[Deprecated] LLM wrapper to use."""
+    prompt: BasePromptTemplate = PROMPT
+    """[Deprecated] Prompt to use to translate to python if necessary."""
+    database_names: Dict[str, str] = model_container.DATABASE
+    input_key: str = "question"  #: :meta private:
+    output_key: str = "answer"  #: :meta private:
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    @root_validator(pre=True)
+    def raise_deprecation(cls, values: Dict) -> Dict:
+        if "llm" in values:
+            warnings.warn(
+                "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
+                "Please instantiate with llm_chain argument or using the from_llm "
+                "class method."
+            )
+            if "llm_chain" not in values and values["llm"] is not None:
+                prompt = values.get("prompt", PROMPT)
+                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
+        return values
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Expect input key.
+
+        :meta private:
+        """
+        return [self.input_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Expect output key.
+
+        :meta private:
+        """
+        return [self.output_key]
+
+    def _evaluate_expression(self, dataset, query) -> str:
+        try:
+            output = asyncio.run(search_knowledge_base_iter(dataset, query))
+        except Exception as e:
+            output = "输入的信息有误或不存在知识库"
+            return output
+        return output
+
+    def _process_llm_result(
+            self,
+            llm_output: str,
+            llm_input: str,
+            run_manager: CallbackManagerForChainRun
+    ) -> Dict[str, str]:
+
+        run_manager.on_text(llm_output, color="green", verbose=self.verbose)
+
+        llm_output = llm_output.strip()
+        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
+        if text_match:
+            database = text_match.group(1).strip()
+            output = self._evaluate_expression(database, llm_input)
+            run_manager.on_text("\nAnswer: ", verbose=self.verbose)
+            run_manager.on_text(output, color="yellow", verbose=self.verbose)
+            answer = "Answer: " + output
+        elif llm_output.startswith("Answer:"):
+            answer = llm_output
+        elif "Answer:" in llm_output:
+            answer = "Answer: " + llm_output.split("Answer:")[-1]
+        else:
+            return {self.output_key: f"输入的格式不对: {llm_output}"}
+        return {self.output_key: answer}
+
+    async def _aprocess_llm_result(
+            self,
+            llm_output: str,
+            run_manager: AsyncCallbackManagerForChainRun,
+    ) -> Dict[str, str]:
+        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
+        llm_output = llm_output.strip()
+        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
+        if text_match:
+            expression = text_match.group(1)
+            output = self._evaluate_expression(expression)
+            await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
+            await run_manager.on_text(output, color="yellow", verbose=self.verbose)
+            answer = "Answer: " + output
+        elif llm_output.startswith("Answer:"):
+            answer = llm_output
+        elif "Answer:" in llm_output:
+            answer = "Answer: " + llm_output.split("Answer:")[-1]
+        else:
+            raise ValueError(f"unknown format from LLM: {llm_output}")
+        return {self.output_key: answer}
+
+    def _call(
+            self,
+            inputs: Dict[str, str],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
+        _run_manager.on_text(inputs[self.input_key])
+        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
+        llm_output = self.llm_chain.predict(
+            database_names=data_formatted_str,
+            question=inputs[self.input_key],
+            stop=["```output"],
+            callbacks=_run_manager.get_child(),
+        )
+        return self._process_llm_result(llm_output, inputs[self.input_key], _run_manager)
+
+    async def _acall(
+            self,
+            inputs: Dict[str, str],
+            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
+        await _run_manager.on_text(inputs[self.input_key])
+        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
+        llm_output = await self.llm_chain.apredict(
+            database_names=data_formatted_str,
+            question=inputs[self.input_key],
+            stop=["```output"],
+            callbacks=_run_manager.get_child(),
+        )
+        return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)
+
+    @property
+    def _chain_type(self) -> str:
+        return "llm_knowledge_chain"
+
+    @classmethod
+    def from_llm(
+            cls,
+            llm: BaseLanguageModel,
+            prompt: BasePromptTemplate = PROMPT,
+            **kwargs: Any,
+    ):
+        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        return cls(llm_chain=llm_chain, **kwargs)
+
+
+def knowledge_search_once(query: str):
+    model = model_container.MODEL
+    llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT)
+    ans = llm_knowledge.run(query)
+    return ans
+
+
+if __name__ == "__main__":
+    result = knowledge_search_once("大数据的男女比例")
+    print(result)
@@ -1,35 +1,39 @@
 ## 单独运行的时候需要添加
-import sys
+# import sys
-import os
+# import os
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))

 import json
 from server.chat import search_engine_chat
-from configs import LLM_MODEL, TEMPERATURE, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD
+from configs import VECTOR_SEARCH_TOP_K

 import asyncio
+from server.agent import model_container

 async def search_engine_iter(query: str):
     response = await search_engine_chat(query=query,
-                                        search_engine_name="bing",
+                                        search_engine_name="bing",  # 这里切换搜索引擎
-                                        model_name=LLM_MODEL,
+                                        model_name=model_container.MODEL.model_name,
-                                        temperature=TEMPERATURE,
+                                        temperature=0.01,  # Agent 搜索互联网的时候,温度设置为0.01
                                         history=[],
                                         top_k = VECTOR_SEARCH_TOP_K,
-                                        prompt_name = "knowledge_base_chat",
+                                        max_tokens= None,  # Agent 搜索互联网的时候,max_tokens设置为None
+                                        prompt_name = "default",
                                         stream=False)

     contents = ""

     async for data in response.body_iterator:  # 这里的data是一个json字符串
         data = json.loads(data)
         contents = data["answer"]
         docs = data["docs"]

     return contents

 def search_internet(query: str):
     return asyncio.run(search_engine_iter(query))


 if __name__ == "__main__":
-    result = search_internet("大数据男女比例")
+    result = search_internet("今天星期几")
     print("答案:", result)
server/agent/tools/search_knowledge_simple.py (Normal file, 38 lines added)
@@ -0,0 +1,38 @@
+## 最简单的版本,只支持固定的知识库
+
+# ## 单独运行的时候需要添加
+# import sys
+# import os
+# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
+
+from server.chat.knowledge_base_chat import knowledge_base_chat
+from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD
+import json
+import asyncio
+from server.agent import model_container
+
+async def search_knowledge_base_iter(database: str, query: str) -> str:
+    response = await knowledge_base_chat(query=query,
+                                         knowledge_base_name=database,
+                                         model_name=model_container.MODEL.model_name,
+                                         temperature=0.01,
+                                         history=[],
+                                         top_k=VECTOR_SEARCH_TOP_K,
+                                         prompt_name="knowledge_base_chat",
+                                         score_threshold=SCORE_THRESHOLD,
+                                         stream=False)
+
+    contents = ""
+    async for data in response.body_iterator:  # 这里的data是一个json字符串
+        data = json.loads(data)
+        contents = data["answer"]
+        docs = data["docs"]
+    return contents
+
+def knowledge_search_simple(query: str):
+    return asyncio.run(search_knowledge_base_iter(query))
+
+
+if __name__ == "__main__":
+    result = knowledge_search_simple("大数据男女比例")
+    print("答案:", result)
@@ -1,12 +1,11 @@
 ## 单独运行的时候需要添加
-import sys
+# import sys
-import os
+# import os
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))

 from langchain.prompts import PromptTemplate
 from langchain.chains import LLMChain
-from server.utils import get_ChatOpenAI
+from server.agent import model_container
-from langchain.chains.llm_math.prompt import PROMPT
-from configs.model_config import LLM_MODEL,TEMPERATURE

 _PROMPT_TEMPLATE = '''
 # 指令
@@ -30,11 +29,7 @@ PROMPT = PromptTemplate(


 def translate(query: str):
-    model = get_ChatOpenAI(
-        streaming=False,
-        model_name=LLM_MODEL,
-        temperature=TEMPERATURE,
-    )
+    model = model_container.MODEL
     llm_translate = LLMChain(llm=model, prompt=PROMPT)
     ans = llm_translate.run(query)
     return ans
@@ -1,6 +1,5 @@
 ## 使用和风天气API查询天气,这个模型仅仅对免费的API进行了适配
 ## 这个模型的提示词非常复杂,我们推荐使用GPT4模型进行运行

 from __future__ import annotations

 ## 单独运行的时候需要添加
@@ -8,10 +7,6 @@ import sys
 import os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

-
-from server.utils import get_ChatOpenAI
-
-
 import re
 import warnings
 from typing import Dict
@@ -27,10 +22,9 @@ from langchain.schema import BasePromptTemplate
 from langchain.schema.language_model import BaseLanguageModel
 import requests
 from typing import List, Any, Optional
-from configs.model_config import LLM_MODEL, TEMPERATURE
 from datetime import datetime
 from langchain.prompts import PromptTemplate
+from server.agent import model_container

 ## 使用和风天气API查询天气
 KEY = "ac880e5a877042809ac7ffdd19d95b0d"
@@ -237,9 +231,6 @@ class LLMWeatherChain(Chain):
             output = weather(expression)
         except Exception as e:
             output = "输入的信息有误,请再次尝试"
-            return {self.output_key: output}
-            raise ValueError(f"错误: {expression},输入的信息不对")

         return output

     def _process_llm_result(
@@ -262,7 +253,6 @@ class LLMWeatherChain(Chain):
             answer = "Answer: " + llm_output.split("Answer:")[-1]
         else:
             return {self.output_key: f"输入的格式不对: {llm_output},应该输入 (市 区)的组合"}
-            # raise ValueError(f"unknown format from LLM: {llm_output}")
         return {self.output_key: answer}

     async def _aprocess_llm_result(
@@ -273,6 +263,7 @@ class LLMWeatherChain(Chain):
         await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
         llm_output = llm_output.strip()
         text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
+
         if text_match:
             expression = text_match.group(1)
             output = self._evaluate_expression(expression)
@@ -332,14 +323,10 @@ class LLMWeatherChain(Chain):


 def weathercheck(query: str):
-    model = get_ChatOpenAI(
-        streaming=False,
-        model_name=LLM_MODEL,
-        temperature=TEMPERATURE,
-    )
+    model = model_container.MODEL
     llm_weather = LLMWeatherChain.from_llm(model, verbose=True, prompt=PROMPT)
     ans = llm_weather.run(query)
     return ans

 if __name__ == '__main__':
-    result = weathercheck("苏州工姑苏区今晚热不热?")
+    result = weathercheck("苏州姑苏区今晚热不热?")
@@ -1,16 +1,5 @@
-import sys
+from langchain.tools import Tool
-import os
+from server.agent.tools import *

-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
-
-from server.agent.math import calculate
-from server.agent.translator import translate
-from server.agent.weather import weathercheck
-from server.agent.shell import shell
-from langchain.agents import Tool
-from server.agent.search_knowledge import search_knowledge
-from server.agent.search_internet import search_internet

 tools = [
     Tool.from_function(
         func=calculate,
@@ -25,7 +14,7 @@ tools = [
     Tool.from_function(
         func=weathercheck,
         name="天气查询工具",
-        description="如果你无法访问互联网,并需要查询中国各地未来24小时的天气,你应该使用这个工具,每轮对话仅能使用一次",
+        description="无需访问互联网,使用这个工具查询中国各地未来24小时的天气",
     ),
     Tool.from_function(
         func=shell,
@@ -33,15 +22,15 @@ tools = [
         description="使用命令行工具输出",
     ),
     Tool.from_function(
-        func=search_knowledge,
+        func=knowledge_search_more,
         name="知识库查询工具",
-        description="访问知识库来获取答案",
+        description="优先访问知识库来获取答案",
     ),
     Tool.from_function(
         func=search_internet,
         name="互联网查询工具",
        description="如果你无法访问互联网,这个工具可以帮助你访问Bing互联网来解答问题",
     ),

 ]

 tool_names = [tool.name for tool in tools]
@@ -16,7 +16,7 @@ from server.chat import (chat, knowledge_base_chat, openai_chat,
 from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
                                               update_docs, download_doc, recreate_vector_store,
-                                              search_docs, DocumentWithScore)
+                                              search_docs, DocumentWithScore, update_info)
 from server.llm_api import (list_running_models, list_config_models,
                             change_llm_model, stop_llm_model,
                             get_model_config, list_search_engines)
@@ -115,6 +115,11 @@ def create_app():
              summary="删除知识库内指定文件"
              )(delete_docs)

+    app.post("/knowledge_base/update_info",
+             tags=["Knowledge Base Management"],
+             response_model=BaseResponse,
+             summary="更新知识库介绍"
+             )(update_info)
     app.post("/knowledge_base/update_docs",
              tags=["Knowledge Base Management"],
              response_model=BaseResponse,
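The new /knowledge_base/update_info route is what the UI calls to store the per-knowledge-base introduction that the Agent relies on. A hypothetical call against the project's default local API port; the JSON field names are an assumption, since update_info's signature is not part of this diff:

import requests

resp = requests.post(
    "http://127.0.0.1:7861/knowledge_base/update_info",   # default local API port
    json={"knowledge_base_name": "samples",                # assumed parameter name
          "kb_info": "关于本项目issue的解答"},              # assumed parameter name
)
print(resp.json())  # expected to be a BaseResponse-style payload, e.g. {"code": 200, ...}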
@@ -1,33 +1,34 @@
 from langchain.memory import ConversationBufferWindowMemory
-from server.agent.tools import tools, tool_names
+from server.agent.tools_select import tools, tool_names
-from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status, dumps
+from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
 from langchain.agents import AgentExecutor, LLMSingleActionAgent
 from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from configs.model_config import LLM_MODEL, TEMPERATURE, HISTORY_LEN
+from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN
 from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
 from langchain.chains import LLMChain
-from typing import AsyncIterable, Optional
+from typing import AsyncIterable, Optional, Dict
 import asyncio
 from typing import List
 from server.chat.utils import History
 import json
+from server.agent import model_container
+from server.knowledge_base.kb_service.base import get_kb_details

 async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
                      history: List[History] = Body([],
                                                    description="历史对话",
                                                    examples=[[
-                                                        {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                                        {"role": "user", "content": "请使用知识库工具查询今天北京天气"},
-                                                        {"role": "assistant", "content": "虎头虎脑"}]]
+                                                        {"role": "assistant", "content": "使用天气查询工具查询到今天北京多云,10-14摄氏度,东北风2级,易感冒"}]]
                                                    ),
                      stream: bool = Body(False, description="流式输出"),
                      model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                      temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-                     max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
+                     max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"),
-                     prompt_name: str = Body("agent_chat",
+                     # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
-                                             description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
+                     prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                      # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
                      ):
     history = [History.from_data(h) for h in history]
@@ -43,19 +44,26 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
         model_name=model_name,
         temperature=temperature,
         max_tokens=max_tokens,
+        callbacks=[callback],
     )

-    prompt_template = CustomPromptTemplate(
-        template=get_prompt_template(prompt_name),
+    ## 传入全局变量来实现agent调用
+    kb_list = {x["kb_name"]: x for x in get_kb_details()}
+    model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}
+    model_container.MODEL = model
+
+    prompt_template = get_prompt_template("agent_chat", prompt_name)
+    prompt_template_agent = CustomPromptTemplate(
+        template=prompt_template,
         tools=tools,
         input_variables=["input", "intermediate_steps", "history"]
     )
     output_parser = CustomOutputParser()
-    llm_chain = LLMChain(llm=model, prompt=prompt_template)
+    llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)
     agent = LLMSingleActionAgent(
         llm_chain=llm_chain,
         output_parser=output_parser,
         stop=["\nObservation:", "Observation:", "<|im_end|>"],  # Qwen模型中使用这个
         allowed_tools=tool_names,
     )
     # 把history转成agent的memory
@@ -73,15 +81,15 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
         verbose=True,
         memory=memory,
     )
-    input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
     while True:
         try:
             task = asyncio.create_task(wrap_done(
                 agent_executor.acall(query, callbacks=[callback], include_run_info=True),
                 callback.done))
             break
         except:
             pass

     if stream:
         async for chunk in callback.aiter():
             tools_use = []
@@ -89,46 +97,55 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
             data = json.loads(chunk)
             if data["status"] == Status.start or data["status"] == Status.complete:
                 continue
-            if data["status"] == Status.error:
+            elif data["status"] == Status.error:
+                tools_use.append("\n```\n")
                 tools_use.append("工具名称: " + data["tool_name"])
                 tools_use.append("工具状态: " + "调用失败")
                 tools_use.append("错误信息: " + data["error"])
                 tools_use.append("重新开始尝试")
                 tools_use.append("\n```\n")
                 yield json.dumps({"tools": tools_use}, ensure_ascii=False)
-            if data["status"] == Status.agent_action:
-                yield json.dumps({"answer": "\n\n```\n\n"}, ensure_ascii=False)
-            if data["status"] == Status.tool_finish:
+            elif data["status"] == Status.tool_finish:
+                tools_use.append("\n```\n")
                 tools_use.append("工具名称: " + data["tool_name"])
                 tools_use.append("工具状态: " + "调用成功")
                 tools_use.append("工具输入: " + data["input_str"])
                 tools_use.append("工具输出: " + data["output_str"])
                 tools_use.append("\n```\n")
                 yield json.dumps({"tools": tools_use}, ensure_ascii=False)
-            if data["status"] == Status.agent_finish:
+            elif data["status"] == Status.agent_finish:
                 yield json.dumps({"final_answer": data["final_answer"]}, ensure_ascii=False)
|
yield json.dumps({"final_answer": data["final_answer"]}, ensure_ascii=False)
|
||||||
else:
|
else:
|
||||||
yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)
|
yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)
|
||||||
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
# agent必须要steram=True,这部分暂时没有完成
|
|
||||||
# result = []
|
|
||||||
# async for chunk in callback.aiter():
|
|
||||||
# data = json.loads(chunk)
|
|
||||||
# status = data["status"]
|
|
||||||
# if status == Status.start:
|
|
||||||
# result.append(chunk)
|
|
||||||
# elif status == Status.running:
|
|
||||||
# result[-1]["llm_token"] += chunk["llm_token"]
|
|
||||||
# elif status == Status.complete:
|
|
||||||
# result[-1]["status"] = Status.complete
|
|
||||||
# elif status == Status.agent_finish:
|
|
||||||
# result.append(chunk)
|
|
||||||
# elif status == Status.agent_finish:
|
|
||||||
# pass
|
|
||||||
# yield dumps(result)
|
|
||||||
|
|
||||||
|
else:
|
||||||
|
answer = ""
|
||||||
|
final_answer = ""
|
||||||
|
async for chunk in callback.aiter():
|
||||||
|
# Use server-sent-events to stream the response
|
||||||
|
data = json.loads(chunk)
|
||||||
|
if data["status"] == Status.start or data["status"] == Status.complete:
|
||||||
|
continue
|
||||||
|
if data["status"] == Status.error:
|
||||||
|
answer += "\n```\n"
|
||||||
|
answer += "工具名称: " + data["tool_name"] + "\n"
|
||||||
|
answer += "工具状态: " + "调用失败" + "\n"
|
||||||
|
answer += "错误信息: " + data["error"] + "\n"
|
||||||
|
answer += "\n```\n"
|
||||||
|
if data["status"] == Status.tool_finish:
|
||||||
|
answer += "\n```\n"
|
||||||
|
answer += "工具名称: " + data["tool_name"] + "\n"
|
||||||
|
answer += "工具状态: " + "调用成功" + "\n"
|
||||||
|
answer += "工具输入: " + data["input_str"] + "\n"
|
||||||
|
answer += "工具输出: " + data["output_str"] + "\n"
|
||||||
|
answer += "\n```\n"
|
||||||
|
if data["status"] == Status.agent_finish:
|
||||||
|
final_answer = data["final_answer"]
|
||||||
|
else:
|
||||||
|
answer += data["llm_token"]
|
||||||
|
|
||||||
|
yield json.dumps({"answer": answer, "final_answer": final_answer}, ensure_ascii=False)
|
||||||
await task
|
await task
|
||||||
|
|
||||||
return StreamingResponse(agent_chat_iterator(query=query,
|
return StreamingResponse(agent_chat_iterator(query=query,
|
||||||
|
|||||||
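The two module-level globals set above, model_container.DATABASE and model_container.MODEL, are how the agent-side tools see the knowledge-base intros and the running LLM. The real consumers live in server/agent and are not part of this diff; the sketch below only illustrates the pattern, and the helper name is an assumption:

from server.agent import model_container

def list_knowledge_bases_for_agent() -> str:
    # model_container.DATABASE maps kb_name -> kb_info, as filled in by agent_chat() above
    lines = ["Available knowledge bases:"]
    for kb_name, kb_info in model_container.DATABASE.items():
        lines.append(f"- {kb_name}: {kb_info}")
    return "\n".join(lines)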
@@ -22,9 +22,10 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                stream: bool = Body(False, description="流式输出"),
                model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-               max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
+               max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"),
+               # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
                # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
-               prompt_name: str = Body("llm_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
+               prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                ):
     history = [History.from_data(h) for h in history]

@@ -41,7 +42,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
             callbacks=[callback],
         )

-        prompt_template = get_prompt_template(prompt_name)
+        prompt_template = get_prompt_template("llm_chat", prompt_name)
         input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])
@@ -31,9 +31,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                               stream: bool = Body(False, description="流式输出"),
                               model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                               temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-                              max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
-                              prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
-                              request: Request = None,
+                              max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"),
+                              # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
+                              prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                               ):
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
     if kb is None:
@@ -57,7 +57,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
         context = "\n".join([doc.page_content for doc in docs])

-        prompt_template = get_prompt_template(prompt_name)
+        prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
         input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])
@@ -74,10 +74,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         for inum, doc in enumerate(docs):
             filename = os.path.split(doc.metadata["source"])[-1]
             parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename})
-            url = f"{request.base_url}knowledge_base/download_doc?" + parameters
+            url = f"/knowledge_base/download_doc?" + parameters
             text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
             source_documents.append(text)

         if stream:
             async for token in callback.aiter():
                 # Use server-sent-events to stream the response
@@ -72,8 +72,9 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                              stream: bool = Body(False, description="流式输出"),
                              model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-                             max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
-                             prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
+                             max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"),
+                             # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
+                             prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                              ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -101,7 +102,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
         docs = await lookup_search_engine(query, search_engine_name, top_k)
         context = "\n".join([doc.page_content for doc in docs])

-        prompt_template = get_prompt_template(prompt_name)
+        prompt_template = get_prompt_template("search_engine_chat", prompt_name)
         input_msg = History(role="user", content=prompt_template).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])
@@ -10,10 +10,11 @@ class KnowledgeBaseModel(Base):
     __tablename__ = 'knowledge_base'
     id = Column(Integer, primary_key=True, autoincrement=True, comment='知识库ID')
     kb_name = Column(String(50), comment='知识库名称')
+    kb_info = Column(String(200), comment='知识库简介(用于Agent)')
     vs_type = Column(String(50), comment='向量库类型')
     embed_model = Column(String(50), comment='嵌入模型名称')
     file_count = Column(Integer, default=0, comment='文件数量')
     create_time = Column(DateTime, default=func.now(), comment='创建时间')

     def __repr__(self):
-        return f"<KnowledgeBase(id='{self.id}', kb_name='{self.kb_name}', vs_type='{self.vs_type}', embed_model='{self.embed_model}', file_count='{self.file_count}', create_time='{self.create_time}')>"
+        return f"<KnowledgeBase(id='{self.id}', kb_name='{self.kb_name}',kb_intro='{self.kb_info} vs_type='{self.vs_type}', embed_model='{self.embed_model}', file_count='{self.file_count}', create_time='{self.create_time}')>"
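An info.db created before this commit will not gain the new column automatically, since SQLAlchemy only creates tables and does not alter them. Either delete the database and let it be rebuilt, or add the column by hand; a sketch for the default sqlite layout (the path shown is the project default and may differ in your setup):

import sqlite3

# VARCHAR(200) mirrors the String(200) column declared above.
conn = sqlite3.connect("knowledge_base/info.db")
conn.execute("ALTER TABLE knowledge_base ADD COLUMN kb_info VARCHAR(200)")
conn.commit()
conn.close()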
@@ -3,13 +3,14 @@ from server.db.session import with_session


 @with_session
-def add_kb_to_db(session, kb_name, vs_type, embed_model):
+def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model):
     # 创建知识库实例
     kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
     if not kb:
-        kb = KnowledgeBaseModel(kb_name=kb_name, vs_type=vs_type, embed_model=embed_model)
+        kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model)
         session.add(kb)
     else:  # update kb with new vs_type and embed_model
+        kb.kb_info = kb_info
         kb.vs_type = vs_type
         kb.embed_model = embed_model
     return True
@@ -53,6 +54,7 @@ def get_kb_detail(session, kb_name: str) -> dict:
     if kb:
         return {
             "kb_name": kb.kb_name,
+            "kb_info": kb.kb_info,
             "vs_type": kb.vs_type,
             "embed_model": kb.embed_model,
             "file_count": kb.file_count,
@@ -203,6 +203,20 @@ def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
     return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files})


+def update_info(knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
+                kb_info:str = Body(..., description="知识库介绍", examples=["这是一个知识库"]),
+                ):
+    if not validate_kb_name(knowledge_base_name):
+        return BaseResponse(code=403, msg="Don't attack me")
+
+    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
+    if kb is None:
+        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
+    kb.update_info(kb_info)
+
+    return BaseResponse(code=200, msg=f"知识库介绍修改完成", data={"kb_info": kb_info})
+
+
 def update_docs(
         knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
         file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]),
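Assuming the new route is registered as a POST endpoint in api.py (the test added later in this commit calls it exactly that way), it can be exercised directly; host and port depend on the deployment:

import requests

resp = requests.post(
    "http://127.0.0.1:7861/knowledge_base/update_info",  # default API port assumed, adjust as needed
    json={"knowledge_base_name": "samples", "kb_info": "这是一个知识库"},
)
print(resp.json())  # expected: {"code": 200, "msg": "知识库介绍修改完成", "data": {"kb_info": "这是一个知识库"}}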
@@ -19,7 +19,7 @@ from server.db.repository.knowledge_file_repository import (
 )

 from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                     EMBEDDING_MODEL)
+                     EMBEDDING_MODEL, KB_INFO)
 from server.knowledge_base.utils import (
     get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
     list_kbs_from_folder, list_files_from_folder,
@@ -42,11 +42,11 @@ class KBService(ABC):
                  embed_model: str = EMBEDDING_MODEL,
                  ):
         self.kb_name = knowledge_base_name
+        self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
         self.embed_model = embed_model
         self.kb_path = get_kb_path(self.kb_name)
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()

     def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
         return load_embeddings(self.embed_model, embed_device)

@@ -63,7 +63,7 @@ class KBService(ABC):
         if not os.path.exists(self.doc_path):
             os.makedirs(self.doc_path)
         self.do_create_kb()
-        status = add_kb_to_db(self.kb_name, self.vs_type(), self.embed_model)
+        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
         return status

     def clear_vs(self):
@@ -116,6 +116,14 @@ class KBService(ABC):
             os.remove(kb_file.filepath)
         return status

+    def update_info(self, kb_info: str):
+        """
+        更新知识库介绍
+        """
+        self.kb_info = kb_info
+        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
+        return status
+
     def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
         """
         使用content中的文件更新向量库
@@ -127,7 +135,7 @@ class KBService(ABC):

     def exist_doc(self, file_name: str):
         return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name,
                                                filename=file_name))

     def list_files(self):
         return list_files_from_db(self.kb_name)
@@ -271,6 +279,7 @@ def get_kb_details() -> List[Dict]:
         result[kb] = {
             "kb_name": kb,
             "vs_type": "",
+            "kb_info": "",
             "embed_model": "",
             "file_count": 0,
             "create_time": None,
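With these changes every KBService subclass picks up kb_info from the KB_INFO dict in the config at construction time, falling back to 关于{kb_name}的知识库, and update_info() writes a new intro back through add_kb_to_db. A minimal sketch using the existing FAISS service; the class and constructor names follow the project's existing conventions:

from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService

kb = FaissKBService("samples")
print(kb.kb_info)                # value from KB_INFO if configured, otherwise "关于samples的知识库"
kb.update_info("这是一个知识库")    # persists the new intro via add_kb_to_db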
server/knowledge_base/kb_service/zilliz_kb_service.py (new file, 98 lines)
@@ -0,0 +1,98 @@
+from typing import List, Dict, Optional
+from langchain.embeddings.base import Embeddings
+from langchain.schema import Document
+from langchain.vectorstores import Zilliz
+from configs import kbs_config
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
+    score_threshold_process
+from server.knowledge_base.utils import KnowledgeFile
+
+
+class ZillizKBService(KBService):
+    zilliz: Zilliz
+
+    @staticmethod
+    def get_collection(zilliz_name):
+        from pymilvus import Collection
+        return Collection(zilliz_name)
+
+    # def save_vector_store(self):
+    #     if self.zilliz.col:
+    #         self.zilliz.col.flush()
+
+    def get_doc_by_id(self, id: str) -> Optional[Document]:
+        if self.zilliz.col:
+            data_list = self.zilliz.col.query(expr=f'pk == {id}', output_fields=["*"])
+            if len(data_list) > 0:
+                data = data_list[0]
+                text = data.pop("text")
+                return Document(page_content=text, metadata=data)
+
+    @staticmethod
+    def search(zilliz_name, content, limit=3):
+        search_params = {
+            "metric_type": "IP",
+            "params": {},
+        }
+        c = ZillizKBService.get_collection(zilliz_name)
+        return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"])
+
+    def do_create_kb(self):
+        pass
+
+    def vs_type(self) -> str:
+        return SupportedVSType.ZILLIZ
+
+    def _load_zilliz(self, embeddings: Embeddings = None):
+        if embeddings is None:
+            embeddings = self._load_embeddings()
+        zilliz_args = kbs_config.get("zilliz")
+        self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(embeddings),
+                             collection_name=self.kb_name, connection_args=zilliz_args)
+
+    def do_init(self):
+        self._load_zilliz()
+
+    def do_drop_kb(self):
+        if self.zilliz.col:
+            self.zilliz.col.release()
+            self.zilliz.col.drop()
+
+    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
+        self._load_zilliz(embeddings=EmbeddingsFunAdapter(embeddings))
+        return score_threshold_process(score_threshold, top_k, self.zilliz.similarity_search_with_score(query, top_k))
+
+    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
+        for doc in docs:
+            for k, v in doc.metadata.items():
+                doc.metadata[k] = str(v)
+            for field in self.zilliz.fields:
+                doc.metadata.setdefault(field, "")
+            doc.metadata.pop(self.zilliz._text_field, None)
+            doc.metadata.pop(self.zilliz._vector_field, None)
+
+        ids = self.zilliz.add_documents(docs)
+        doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
+        return doc_infos
+
+    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
+        if self.zilliz.col:
+            filepath = kb_file.filepath.replace('\\', '\\\\')
+            delete_list = [item.get("pk") for item in
+                           self.zilliz.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
+            self.zilliz.col.delete(expr=f'pk in {delete_list}')
+
+    def do_clear_vs(self):
+        if self.zilliz.col:
+            self.do_drop_kb()
+            self.do_init()
+
+
+if __name__ == '__main__':
+
+    from server.db.base import Base, engine
+
+    Base.metadata.create_all(bind=engine)
+    zillizService = ZillizKBService("test")
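To actually use the new backend, fill in the zilliz entry of kbs_config (host, port, user, password, with secure=True for a cloud instance) and select zilliz as the vector store type. Driving the service class directly might look like the sketch below; create_kb and search_docs come from the KBService base class, so treat the exact calls as assumptions:

from server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService

kb = ZillizKBService("test")
kb.create_kb()                 # registers the KB in the DB and initializes the collection
docs = kb.search_docs("你好")   # routed through do_search() to Zilliz.similarity_search_with_score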
@@ -37,9 +37,10 @@ def folder2db(
     kb_names: List[str],
     mode: Literal["recreate_vs", "update_in_db", "increament"],
     vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
+    kb_info: dict[str, Any] = {},
     embed_model: str = EMBEDDING_MODEL,
     chunk_size: int = CHUNK_SIZE,
-    chunk_overlap: int = CHUNK_SIZE,
+    chunk_overlap: int = OVERLAP_SIZE,
     zh_title_enhance: bool = ZH_TITLE_ENHANCE,
 ):
     '''
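Besides fixing the chunk_overlap default, the function now accepts a kb_info mapping. A rebuild of a local knowledge base under the new signature could then be invoked as below; argument names follow the signature above, and how kb_info is consumed depends on the rest of migrate.py:

folder2db(
    kb_names=["samples"],
    mode="recreate_vs",
    vs_type="faiss",
    kb_info={"samples": "这是一个知识库"},
)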
@@ -32,7 +32,6 @@ def list_config_models() -> BaseResponse:
     从本地获取configs中配置的模型列表
     '''
     configs = list_config_llm_models()
-
     # 删除ONLINE_MODEL配置中的敏感信息
     for config in configs["online"].values():
         del_keys = set(["worker_class"])
@@ -389,15 +389,16 @@ def webui_address() -> str:
     return f"http://{host}:{port}"


-def get_prompt_template(name: str) -> Optional[str]:
+def get_prompt_template(type:str,name: str) -> Optional[str]:
     '''
     从prompt_config中加载模板内容
+    type: "llm_chat","agent_chat","knowledge_base_chat","search_engine_chat"的其中一种,如果有新功能,应该进行加入。
     '''

     from configs import prompt_config
     import importlib
     importlib.reload(prompt_config)  # TODO: 检查configs/prompt_config.py文件有修改再重新加载
-    return prompt_config.PROMPT_TEMPLATES.get(name)
+    return prompt_config.PROMPT_TEMPLATES[type].get(name)


 def set_httpx_config(
@@ -409,6 +410,7 @@ def set_httpx_config(
     将本项目相关服务加入无代理列表,避免fastchat的服务器请求错误。(windows下无效)
     对于chatgpt等在线API,如要使用代理需要手动配置。搜索引擎的代理如何处置还需考虑。
     '''
+
     import httpx
     import os
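Every caller now has to name the feature first and the template second; the lookup is a plain two-level dictionary access on PROMPT_TEMPLATES:

from server.utils import get_prompt_template

template = get_prompt_template("knowledge_base_chat", "default")
# equivalent to prompt_config.PROMPT_TEMPLATES["knowledge_base_chat"].get("default")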
@@ -137,6 +137,14 @@ def test_search_docs(api="/knowledge_base/search_docs"):
     assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K


+def test_update_info(api="/knowledge_base/update_info"):
+    url = api_base_url + api
+    print("\n更新知识库介绍")
+    r = requests.post(url, json={"knowledge_base_name": "samples", "kb_info": "你好"})
+    data = r.json()
+    pprint(data)
+    assert data["code"] == 200
+
+
 def test_update_docs(api="/knowledge_base/update_docs"):
     url = api_base_url + api
@@ -3,10 +3,9 @@ from webui_pages.utils import *
 from streamlit_chatbox import *
 from datetime import datetime
 import os
-from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN
+from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES
 from typing import List, Dict

 chat_box = ChatBox(
     assistant_avatar=os.path.join(
         "img",
@@ -47,11 +46,11 @@ def get_default_llm_model(api: ApiRequest) -> (str, bool):
     if LLM_MODEL in running_models:
         return LLM_MODEL, True

     local_models = [k for k, v in running_models.items() if not v.get("online_api")]
     if local_models:
         return local_models[0], True

     return running_models[0], False

@@ -94,15 +93,14 @@ def dialogue_page(api: ApiRequest):
         running_models = list(api.list_running_models())
         available_models = []
         config_models = api.list_config_models()
         worker_models = list(config_models.get("worker", {}))  # 仅列出在FSCHAT_MODEL_WORKERS中配置的模型
         for m in worker_models:
             if m not in running_models and m != "default":
                 available_models.append(m)
         for k, v in config_models.get("online", {}).items():  # 列出ONLINE_MODELS中直接访问的模型(如GPT)
             if not v.get("provider") and k not in running_models:
                 print(k, v)
                 available_models.append(k)
         llm_models = running_models + available_models
         index = llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0]))
         llm_model = st.selectbox("选择LLM模型:",
@@ -124,11 +122,33 @@ def dialogue_page(api: ApiRequest):
                     st.success(msg)
                 st.session_state["prev_llm_model"] = llm_model

+        index_prompt = {
+            "LLM 对话": "llm_chat",
+            "自定义Agent问答": "agent_chat",
+            "搜索引擎问答": "search_engine_chat",
+            "知识库问答": "knowledge_base_chat",
+        }
+        prompt_templates_kb_list = list(PROMPT_TEMPLATES[index_prompt[dialogue_mode]].keys())
+        prompt_template_name = prompt_templates_kb_list[0]
+        if "prompt_template_select" not in st.session_state:
+            st.session_state.prompt_template_select = prompt_templates_kb_list[0]
+
+        def prompt_change():
+            text = f"已切换为 {prompt_template_name} 模板。"
+            st.toast(text)
+
+        prompt_template_select = st.selectbox(
+            "请选择Prompt模板:",
+            prompt_templates_kb_list,
+            index=0,
+            on_change=prompt_change,
+            key="prompt_template_select",
+        )
+        prompt_template_name = st.session_state.prompt_template_select
+
         temperature = st.slider("Temperature:", 0.0, 1.0, TEMPERATURE, 0.05)

         history_len = st.number_input("历史对话轮数:", 0, 20, HISTORY_LEN)
-        LLM_MODEL_WEBUI = llm_model
-        TEMPERATURE_WEBUI = temperature

         def on_kb_change():
             st.toast(f"已加载知识库: {st.session_state.selected_kb}")
@@ -144,8 +164,7 @@ def dialogue_page(api: ApiRequest):
             )
             kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
             score_threshold = st.slider("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
-            # chunk_content = st.checkbox("关联上下文", False, disabled=True)
-            # chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
         elif dialogue_mode == "搜索引擎问答":
             search_engine_list = api.list_search_engines()
             with st.expander("搜索引擎配置", True):
@@ -168,7 +187,11 @@ def dialogue_page(api: ApiRequest):
         if dialogue_mode == "LLM 对话":
             chat_box.ai_say("正在思考...")
             text = ""
-            r = api.chat_chat(prompt, history=history, model=llm_model, temperature=temperature)
+            r = api.chat_chat(prompt,
+                              history=history,
+                              model=llm_model,
+                              prompt_name=prompt_template_name,
+                              temperature=temperature)
             for t in r:
                 if error_msg := check_error_msg(t):  # check whether error occured
                     st.error(error_msg)
@@ -178,37 +201,38 @@ def dialogue_page(api: ApiRequest):
             chat_box.update_msg(text, streaming=False)  # 更新最终的字符串,去除光标

         elif dialogue_mode == "自定义Agent问答":
             chat_box.ai_say([
                 f"正在思考...",
                 Markdown("...", in_expander=True, title="思考过程", state="complete"),
             ])
             text = ""
             ans = ""
             support_agent = ["gpt", "Qwen", "qwen-api", "baichuan-api"]  # 目前支持agent的模型
             if not any(agent in llm_model for agent in support_agent):
                 ans += "正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,无法正常使用Agent功能!</span>\n\n\n<span style='color:red'>请更换 GPT4或Qwen-14B等支持Agent的模型获得更好的体验! </span> \n\n\n"
                 chat_box.update_msg(ans, element_index=0, streaming=False)

             for d in api.agent_chat(prompt,
                                     history=history,
                                     model=llm_model,
-                                    temperature=temperature):
+                                    prompt_name=prompt_template_name,
+                                    temperature=temperature,
+                                    ):
                 try:
                     d = json.loads(d)
                 except:
                     pass
                 if error_msg := check_error_msg(d):  # check whether error occured
                     st.error(error_msg)
-                elif chunk := d.get("final_answer"):
-                    ans += chunk
-                    chat_box.update_msg(ans, element_index=0)
-                elif chunk := d.get("answer"):
+                if chunk := d.get("answer"):
                     text += chunk
                     chat_box.update_msg(text, element_index=1)
-                elif chunk := d.get("tools"):
+                if chunk := d.get("final_answer"):
+                    ans += chunk
+                    chat_box.update_msg(ans, element_index=0)
+                if chunk := d.get("tools"):
                     text += "\n\n".join(d.get("tools", []))
                     chat_box.update_msg(text, element_index=1)
             chat_box.update_msg(ans, element_index=0, streaming=False)
@@ -225,6 +249,7 @@ def dialogue_page(api: ApiRequest):
                                           score_threshold=score_threshold,
                                           history=history,
                                           model=llm_model,
+                                          prompt_name=prompt_template_name,
                                           temperature=temperature):
                 if error_msg := check_error_msg(d):  # check whether error occured
                     st.error(error_msg)
@@ -244,6 +269,7 @@ def dialogue_page(api: ApiRequest):
                                              top_k=se_top_k,
                                              history=history,
                                              model=llm_model,
+                                             prompt_name=prompt_template_name,
                                              temperature=temperature):
                 if error_msg := check_error_msg(d):  # check whether error occured
                     st.error(error_msg)
@@ -63,6 +63,9 @@ def knowledge_base_page(api: ApiRequest):
     else:
         selected_kb_index = 0

+    if "selected_kb_info" not in st.session_state:
+        st.session_state["selected_kb_info"] = ""
+
     def format_selected_kb(kb_name: str) -> str:
         if kb := kb_list.get(kb_name):
             return f"{kb_name} ({kb['vs_type']} @ {kb['embed_model']})"
@@ -84,6 +87,11 @@ def knowledge_base_page(api: ApiRequest):
             placeholder="新知识库名称,不支持中文命名",
             key="kb_name",
         )
+        kb_info = st.text_input(
+            "知识库简介",
+            placeholder="知识库简介,方便Agent查找",
+            key="kb_info",
+        )

         cols = st.columns(2)

@@ -123,18 +131,23 @@ def knowledge_base_page(api: ApiRequest):
             )
             st.toast(ret.get("msg", " "))
             st.session_state["selected_kb_name"] = kb_name
+            st.session_state["selected_kb_info"] = kb_info
             st.experimental_rerun()

     elif selected_kb:
         kb = selected_kb
+        st.session_state["selected_kb_info"] = kb_list[kb]['kb_info']

         # 上传文件
         files = st.file_uploader("上传知识文件:",
                                  [i for ls in LOADER_DICT.values() for i in ls],
                                  accept_multiple_files=True,
                                  )
+        kb_info = st.text_area("请输入知识库介绍:", value=st.session_state["selected_kb_info"], max_chars=None, key=None,
+                               help=None, on_change=None, args=None, kwargs=None)
+
+        if kb_info != st.session_state["selected_kb_info"]:
+            st.session_state["selected_kb_info"] = kb_info
+            api.update_kb_info(kb, kb_info)

         # with st.sidebar:
         with st.expander(
@@ -279,7 +279,7 @@ class ApiRequest:
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
         max_tokens: int = 1024,
-        prompt_name: str = "llm_chat",
+        prompt_name: str = "default",
         **kwargs,
     ):
         '''
@@ -309,6 +309,7 @@ class ApiRequest:
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
         max_tokens: int = 1024,
+        prompt_name: str = "default",
     ):
         '''
         对应api.py/chat/agent_chat 接口
@@ -320,6 +321,7 @@ class ApiRequest:
             "model_name": model,
             "temperature": temperature,
             "max_tokens": max_tokens,
+            "prompt_name": prompt_name,
         }

         print(f"received input message:")
@@ -339,7 +341,7 @@ class ApiRequest:
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
         max_tokens: int = 1024,
-        prompt_name: str = "knowledge_base_chat",
+        prompt_name: str = "default",
     ):
         '''
         对应api.py/chat/knowledge_base_chat接口
@@ -377,7 +379,7 @@ class ApiRequest:
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
         max_tokens: int = 1024,
-        prompt_name: str = "knowledge_base_chat",
+        prompt_name: str = "default",
     ):
         '''
         对应api.py/chat/search_engine_chat接口
@@ -558,6 +560,22 @@ class ApiRequest:
         )
         return self._get_response_value(response, as_json=True)

+    def update_kb_info(self,knowledge_base_name,kb_info):
+        '''
+        对应api.py/knowledge_base/update_info接口
+        '''
+        data = {
+            "knowledge_base_name": knowledge_base_name,
+            "kb_info": kb_info,
+        }
+
+        response = self.post(
+            "/knowledge_base/update_info",
+            json=data,
+        )
+        return self._get_response_value(response, as_json=True)
+
     def update_kb_docs(
         self,
         knowledge_base_name: str,
@@ -652,7 +670,7 @@ class ApiRequest:
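On the client side the new prompt_name parameters default to "default", so existing calls keep working; a script can drive the new pieces roughly like this (ApiRequest is constructed with its default base_url here, adjust for your deployment):

from webui_pages.utils import ApiRequest

api = ApiRequest()
api.update_kb_info("samples", "这是一个知识库")
for t in api.chat_chat("你好", prompt_name="default"):
    print(t, end="")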
     def get_model_config(
         self,
-        model_name: str,
+        model_name: str = None,
     ) -> Dict:
         '''
         获取服务器上模型配置
@@ -662,6 +680,7 @@ class ApiRequest:
         }
         response = self.post(
             "/llm_model/get_model_config",
+            json=data,
         )
         return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))