Update Agent tool returns

zR 2023-12-23 20:55:44 +08:00 committed by liunux4odoo
parent e2a46a1d0f
commit 36c90e2e2b
8 changed files with 97 additions and 69 deletions

View File

@@ -1,13 +1,11 @@
 PROMPT_TEMPLATES = {
     "preprocess_model": {
         "default":
-            '你只要回复0 和 1 ,代表不需要使用工具。以下几种问题不需要使用工具:'
-            '1. 介绍一下你自己, 回复0\n'
-            '2. 讲一个故事, 回复0\n'
-            '3. 给我开一个玩笑, 回复0\n'
-            '4. 我当前运行的文件夹是, 回复1\n'
-            '5. where is this cat, 回复1\n'
-            '6. 介绍一下像极了我大学, 回复1\n'
+            '你只要回复0 和 1 ,代表不需要使用工具。以下几种问题需要使用工具:'
+            '1. 需要联网查询的内容\n'
+            '2. 需要计算的内容\n'
+            '3. 需要查询实时性的内容\n'
+            '如果我的输入满足这几种情况返回1。其他输入请你回复0你只要返回一个数字\n'
             '这是我的问题:'
     },
     "llm_model": {

View File

@@ -1,7 +1,12 @@
 # API requirements
-langchain>=0.0.350
-langchain-experimental>=0.0.42
+# On Windows system, install the cuda version manually from https://pytorch.org/
+torch>=2.1.2
+torchvision>=0.16.2
+torchaudio>=2.1.2
+langchain>=0.0.352
+langchain-experimental>=0.0.47
 pydantic==1.10.13
 fschat==0.2.35
 openai==1.9.0

View File

@@ -8,7 +8,6 @@ import logging
 from typing import Any, List, Sequence, Tuple, Optional, Union
 from pydantic.schema import model_schema
 from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
 from langchain.memory import ConversationBufferWindowMemory
 from langchain.agents.agent import Agent
@@ -24,7 +23,6 @@ from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
 from langchain_core.callbacks import Callbacks

 HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
 logger = logging.getLogger(__name__)
@@ -177,6 +175,7 @@ class StructuredGLM3ChatAgent(Agent):
         **kwargs: Any,
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
         prompt = cls.create_prompt(
             tools,
@@ -217,7 +216,7 @@ def initialize_glm3_agent(
 ) -> AgentExecutor:
     tags_ = list(tags) if tags else []
     agent_kwargs = agent_kwargs or {}
-    llm.callbacks=callbacks
+    llm.callbacks = callbacks
     agent_obj = StructuredGLM3ChatAgent.from_llm_and_tools(
         llm=llm,
         tools=tools,

View File

@@ -154,15 +154,11 @@ def initialize_qwen_agent(
     memory: Optional[ConversationBufferWindowMemory] = None,
     agent_kwargs: Optional[dict] = None,
     *,
-    return_direct: Optional[bool] = None,
     tags: Optional[Sequence[str]] = None,
     **kwargs: Any,
 ) -> AgentExecutor:
     tags_ = list(tags) if tags else []
     agent_kwargs = agent_kwargs or {}
-    if isinstance(return_direct, bool):  # can make all tools return directly
-        tools = [t.copy(update={"return_direct": return_direct}) for t in tools]
     llm.callbacks=callbacks
     agent_obj = QwenChatAgent.from_llm_and_tools(
         llm=llm,

View File

@@ -24,13 +24,13 @@ all_tools = [
         name="shell",
         description="Use Shell to execute Linux commands",
         args_schema=ShellInput,
+        # return_direct=True,  # whether to return the tool output directly, without further LLM processing
     ),
     StructuredTool.from_function(
         func=wolfram,
         name="wolfram",
         description="Useful for when you need to calculate difficult formulas",
         args_schema=WolframInput,
     ),
     StructuredTool.from_function(
         func=search_youtube,
@@ -67,5 +67,6 @@ all_tools = [
         name="aqa_processor",
         description="use this tool to get answer for audio question",
         args_schema=AQAInput,
     )
 ]
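
A hedged aside on the commented-out parameter above: return_direct is a standard BaseTool field in LangChain, and when it is true the AgentExecutor returns the tool's raw output instead of feeding it back through the model. A minimal sketch with an illustrative tool:

from langchain.tools import StructuredTool

def echo(text: str) -> str:
    """Return the input unchanged."""
    return text

# With return_direct=True the agent stops after this tool runs and returns
# its output verbatim, skipping the final LLM pass over the observation.
echo_tool = StructuredTool.from_function(
    func=echo,
    name="echo",
    description="Echo the given text",
    return_direct=True,
)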

View File

@@ -19,7 +19,9 @@ class AgentStatus:
     llm_end: int = 3
     agent_action: int = 4
     agent_finish: int = 5
-    error: int = 6
+    tool_begin: int = 6
+    tool_end: int = 7
+    error: int = 8


 class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
@@ -31,20 +33,19 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
     async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
         data = {
-            "status" : AgentStatus.llm_start,
-            "text" : "",
+            "status": AgentStatus.llm_start,
+            "text": "",
         }
         self.done.clear()
         self.queue.put_nowait(dumps(data))

     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         special_tokens = ["Action", "<|observation|>"]
         for stoken in special_tokens:
             if stoken in token:
                 before_action = token.split(stoken)[0]
                 data = {
-                    "status" : AgentStatus.llm_new_token,
+                    "status": AgentStatus.llm_new_token,
                     "text": before_action + "\n",
                 }
                 self.queue.put_nowait(dumps(data))
@@ -53,15 +54,29 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         if token is not None and token != "" and self.out:
             data = {
-                "status" : AgentStatus.llm_new_token,
-                "text" : token,
+                "status": AgentStatus.llm_new_token,
+                "text": token,
             }
             self.queue.put_nowait(dumps(data))

-    async def on_chat_model_start(
+    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        data = {
+            "status": AgentStatus.llm_end,
+            "text": "",
+        }
+        self.queue.put_nowait(dumps(data))
+
+    async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
+        data = {
+            "status": AgentStatus.error,
+            "text": str(error),
+        }
+        self.queue.put_nowait(dumps(data))
+
+    async def on_tool_start(
         self,
         serialized: Dict[str, Any],
-        messages: List[List],
+        input_str: str,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -69,26 +84,40 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
-        data = {
-            "status" : AgentStatus.llm_start,
-            "text" : "",
-        }
-        self.done.clear()
-        self.queue.put_nowait(dumps(data))
+        print("tool_begin")

-    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+    async def on_tool_end(
+        self,
+        output: str,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when tool ends running."""
         data = {
-            "status" : AgentStatus.llm_end,
-            "text" : response.generations[0][0].message.content,
+            "status": AgentStatus.tool_end,
+            "tool_output": output,
         }
-        self.out = True
+        self.done.clear()
         self.queue.put_nowait(dumps(data))

-    async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
+    async def on_tool_error(
+        self,
+        error: BaseException,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when tool errors."""
         data = {
-            "status" : AgentStatus.error,
-            "text" : str(error),
+            "status": AgentStatus.tool_end,
+            "text": error,
         }
+        self.done.clear()
         self.queue.put_nowait(dumps(data))

     async def on_agent_action(
@@ -101,9 +130,9 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         **kwargs: Any,
     ) -> None:
         data = {
-            "status" : AgentStatus.agent_action,
-            "tool_name" : action.tool,
-            "tool_input" : action.tool_input,
+            "status": AgentStatus.agent_action,
+            "tool_name": action.tool,
+            "tool_input": action.tool_input,
             "text": action.log,
         }
         self.queue.put_nowait(dumps(data))
@@ -113,12 +142,16 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
         if "Thought:" in finish.return_values["output"]:
             finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "")
         data = {
-            "status" : AgentStatus.agent_finish,
-            "text" : finish.return_values["output"],
+            "status": AgentStatus.agent_finish,
+            "text": finish.return_values["output"],
         }
         self.done.set()
         self.queue.put_nowait(dumps(data))
+        self.out = True
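
To show what the new tool_begin/tool_end statuses buy a consumer: every callback above serializes a dict with a status code onto an asyncio queue, which the chat endpoint later drains through callback.aiter(). A minimal sketch of such a consumer, assuming only the AgentStatus codes and the JSON framing defined in this file (consume is an illustrative name):

import json

async def consume(callback) -> None:
    # Each chunk is a JSON object carrying a "status" code plus
    # status-specific fields, as emitted by the handler above.
    async for chunk in callback.aiter():
        data = json.loads(chunk)
        if data["status"] == AgentStatus.agent_action:
            print("calling tool:", data["tool_name"], data["tool_input"])
        elif data["status"] == AgentStatus.tool_end:
            print("observation:", data["tool_output"])
        elif data["status"] == AgentStatus.llm_new_token:
            print(data["text"], end="")
        elif data["status"] in (AgentStatus.agent_finish, AgentStatus.error):
            print(data["text"])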

View File

@@ -69,6 +69,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
     chain = LLMChain(
         prompt=chat_prompt,
         llm=models["llm_model"],
+        callbacks=callbacks,
         memory=memory
     )
     classifier_chain = (
@@ -76,8 +77,7 @@
         | models["preprocess_model"]
         | StrOutputParser()
     )
-    if "action_model" in models and len(tools) > 0:
+    if "action_model" in models and tools:
         if "chatglm3" in models["action_model"].model_name.lower():
             agent_executor = initialize_glm3_agent(
                 llm=models["action_model"],
@@ -151,7 +151,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
     # Select the tools enabled in the config
     tools = [tool for tool in all_tools if tool.name in tool_config]
+    tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
     # Build the full chain
     full_chain = create_models_chains(prompts=prompts,
                                       models=models,
@@ -161,12 +161,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                                       history=history,
                                       history_len=history_len,
                                       metadata=metadata)
-    # Execute Chain
-    task = asyncio.create_task(
-        wrap_done(full_chain.ainvoke({"input": query}), callback.done))
+    task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}), callback.done))

     async for chunk in callback.aiter():
         data = json.loads(chunk)
         data["message_id"] = message_id

View File

@@ -272,12 +272,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                 f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
                 unsafe_allow_html=True)

-            chat_box.ai_say(["正在思考...",
-                             Markdown(title="tool call", in_expander=True, expanded=True, state="running"),
-                             Markdown()])
+            chat_box.ai_say("正在思考...")
             text = ""
-            message_id = ""
-            tool_called = False
+            text_action = ""
+            element_index = 0

             for d in api.chat_chat(query=prompt,
                                    metadata=files_upload,
@@ -290,30 +288,33 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                     metadata = {
                         "message_id": message_id,
                     }
-                if not tool_called:
-                    element_index = 0
-                else:
-                    element_index = -1
+                # Avoid duplicate output after a tool call: split the LLM output into
+                # the parts before and after the tool call and handle them separately.
                 if d["status"] == AgentStatus.error:
                     st.error(d["text"])
                 elif d["status"] == AgentStatus.agent_action:
                     formatted_data = {
-                        "action": d["tool_name"],
-                        "action_input": d["tool_input"]
+                        "Function": d["tool_name"],
+                        "function_input": d["tool_input"]
                     }
+                    element_index += 1
                     formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
-                    chat_box.update_msg(Markdown(f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"), element_index=1)
-                    tool_called = True
-                    text = ""
+                    chat_box.insert_msg(
+                        Markdown(title="Function call", in_expander=True, expanded=True, state="running"))
+                    text = """\n```{}\n```\n""".format(formatted_json)
+                    chat_box.update_msg(Markdown(text), element_index=element_index)
+                elif d["status"] == AgentStatus.tool_end:
+                    text += """\n```\nObservation:\n{}\n```\n""".format(d["tool_output"])
+                    chat_box.update_msg(Markdown(text), element_index=element_index, expanded=False, state="complete")
                 elif d["status"] == AgentStatus.llm_new_token:
                     text += d["text"]
                     chat_box.update_msg(text, streaming=True, element_index=element_index, metadata=metadata)
                 elif d["status"] == AgentStatus.llm_end:
                     chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
                 elif d["status"] == AgentStatus.agent_finish:
-                    chat_box.update_msg(element_index=1, state="complete", expanded=False)
-                    chat_box.update_msg(Markdown(d["text"]), streaming=False, element_index=-1)
-                    # print(d["text"])
+                    element_index += 1
+                    chat_box.insert_msg(Markdown(d["text"], expanded=True))
+                    chat_box.update_msg(Markdown(d["text"]), element_index=element_index)
             if os.path.exists("tmp/image.jpg"):
                 with open("tmp/image.jpg", "rb") as image_file:
                     encoded_string = base64.b64encode(image_file.read()).decode()