Update Agent tool returns

zR 2023-12-23 20:55:44 +08:00 committed by liunux4odoo
parent e2a46a1d0f
commit 36c90e2e2b
8 changed files with 97 additions and 69 deletions

View File

@@ -1,13 +1,11 @@
PROMPT_TEMPLATES = {
"preprocess_model": {
"default":
'你只要回复0 和 1 ,代表不需要使用工具。以下几种问题不需要使用工具:'
'1. 介绍一下你自己, 回复0\n'
'2. 讲一个故事, 回复0\n'
'3. 给我开一个玩笑, 回复0\n'
'4. 我当前运行的文件夹是, 回复1\n'
'5. where is this cat, 回复1\n'
'6. 介绍一下像极了我大学, 回复1\n'
'你只要回复0 和 1 ,代表是否需要使用工具。以下几种问题需要使用工具:'
'1. 需要联网查询的内容\n'
'2. 需要计算的内容\n'
'3. 需要查询实时性的内容\n'
'如果我的输入满足这几种情况,返回1。其他输入,请你回复0。你只要返回一个数字\n'
'这是我的问题:'
},
"llm_model": {

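The template above is consumed as a plain 0/1 classifier: the model sees the instructions plus the user input and must answer with a single digit. A minimal sketch of that wiring, assuming preprocess_model is an already-configured LangChain chat model (the variable names here are illustrative, not the project's actual module layout); it mirrors the classifier_chain built later in this commit:

from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser

# Prepend the classifier instructions to the user input and parse the
# raw completion; "1" means the agent path with tools should be taken.
classifier_prompt = ChatPromptTemplate.from_template(
    PROMPT_TEMPLATES["preprocess_model"]["default"] + "{input}"
)
classifier_chain = classifier_prompt | preprocess_model | StrOutputParser()
need_tools = classifier_chain.invoke({"input": "今天北京的天气怎么样?"}).strip() == "1"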
View File

@@ -1,7 +1,12 @@
# API requirements
langchain>=0.0.350
langchain-experimental>=0.0.42
# On Windows system, install the cuda version manually from https://pytorch.org/
torch>=2.1.2
torchvision>=0.16.2
torchaudio>=2.1.2
langchain>=0.0.352
langchain-experimental>=0.0.47
pydantic==1.10.13
fschat==0.2.35
openai==1.9.0

View File

@@ -8,7 +8,6 @@ import logging
from typing import Any, List, Sequence, Tuple, Optional, Union
from pydantic.schema import model_schema
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents.agent import Agent
@@ -24,7 +23,6 @@ from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool
from langchain_core.callbacks import Callbacks
HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
logger = logging.getLogger(__name__)
@@ -177,6 +175,7 @@ class StructuredGLM3ChatAgent(Agent):
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
prompt = cls.create_prompt(
tools,
@@ -217,7 +216,7 @@ def initialize_glm3_agent(
) -> AgentExecutor:
tags_ = list(tags) if tags else []
agent_kwargs = agent_kwargs or {}
llm.callbacks=callbacks
llm.callbacks = callbacks
agent_obj = StructuredGLM3ChatAgent.from_llm_and_tools(
llm=llm,
tools=tools,

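initialize_glm3_agent now attaches the callbacks directly to the LLM (llm.callbacks = callbacks) before building the agent. A usage sketch, with the argument names inferred from the signature shown in this hunk rather than taken from a documented API:

# Assumed call shape: llm is a ChatGLM3-backed LangChain model, tools is a
# List[BaseTool], and callbacks are forwarded onto the LLM itself.
agent_executor = initialize_glm3_agent(
    llm=llm,
    tools=tools,
    callbacks=callbacks,
    tags=["glm3-agent"],
)
result = agent_executor.invoke({"input": "北京今天多少度?"})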
View File

@@ -154,15 +154,11 @@ def initialize_qwen_agent(
memory: Optional[ConversationBufferWindowMemory] = None,
agent_kwargs: Optional[dict] = None,
*,
return_direct: Optional[bool] = None,
tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> AgentExecutor:
tags_ = list(tags) if tags else []
agent_kwargs = agent_kwargs or {}
if isinstance(return_direct, bool): # can make all tools return directly
tools = [t.copy(update={"return_direct": return_direct}) for t in tools]
llm.callbacks=callbacks
agent_obj = QwenChatAgent.from_llm_and_tools(
llm=llm,

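The blanket return_direct switch is dropped from initialize_qwen_agent in this commit. If a single tool should still bypass the final LLM pass, the same pydantic copy idiom the removed code used still works per tool, as in this sketch (the "shell" filter is illustrative):

# Sketch: make only the shell tool return its output directly,
# instead of flipping return_direct for every tool at once.
tools = [
    t.copy(update={"return_direct": True}) if t.name == "shell" else t
    for t in tools
]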
View File

@@ -24,13 +24,13 @@ all_tools = [
name="shell",
description="Use Shell to execute Linux commands",
args_schema=ShellInput,
# return_direct=True,  # whether to return the tool output directly, without further LLM processing
),
StructuredTool.from_function(
func=wolfram,
name="wolfram",
description="Useful for when you need to calculate difficult formulas",
args_schema=WolframInput,
),
StructuredTool.from_function(
func=search_youtube,
@@ -67,5 +67,6 @@ all_tools = [
name="aqa_processor",
description="use this tool to get answer for audio question",
args_schema=AQAInput,
)
]

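Every entry in all_tools follows the same StructuredTool.from_function pattern: a plain function, a name, a description the agent matches on, and a pydantic args_schema. A self-contained sketch of registering one more tool in this style (the calculator below is hypothetical, not part of this commit):

from pydantic import BaseModel, Field
from langchain.tools import StructuredTool

class CalculatorInput(BaseModel):
    expression: str = Field(description="Arithmetic expression to evaluate")

def calculate(expression: str) -> str:
    # Demo only: eval of untrusted input is unsafe in real deployments.
    return str(eval(expression))

all_tools.append(
    StructuredTool.from_function(
        func=calculate,
        name="calculator",
        description="Useful for evaluating arithmetic expressions",
        args_schema=CalculatorInput,
        # return_direct=True,  # return the result directly, skipping the LLM
    )
)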
View File

@@ -19,7 +19,9 @@ class AgentStatus:
llm_end: int = 3
agent_action: int = 4
agent_finish: int = 5
error: int = 6
tool_begin: int = 6
tool_end: int = 7
error: int = 8
class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
@@ -31,20 +33,19 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
data = {
"status" : AgentStatus.llm_start,
"text" : "",
"status": AgentStatus.llm_start,
"text": "",
}
self.done.clear()
self.queue.put_nowait(dumps(data))
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
special_tokens = ["Action", "<|observation|>"]
for stoken in special_tokens:
if stoken in token:
before_action = token.split(stoken)[0]
data = {
"status" : AgentStatus.llm_new_token,
"status": AgentStatus.llm_new_token,
"text": before_action + "\n",
}
self.queue.put_nowait(dumps(data))
@@ -53,15 +54,29 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
if token is not None and token != "" and self.out:
data = {
"status" : AgentStatus.llm_new_token,
"text" : token,
"status": AgentStatus.llm_new_token,
"text": token,
}
self.queue.put_nowait(dumps(data))
async def on_chat_model_start(
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
data = {
"status": AgentStatus.llm_end,
"text": "",
}
self.queue.put_nowait(dumps(data))
async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
data = {
"status": AgentStatus.error,
"text": str(error),
}
self.queue.put_nowait(dumps(data))
async def on_tool_start(
self,
serialized: Dict[str, Any],
messages: List[List],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@@ -69,26 +84,40 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
print("tool_begin")
async def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when tool ends running."""
data = {
"status" : AgentStatus.llm_start,
"text" : "",
"status": AgentStatus.tool_end,
"tool_output": output,
}
self.done.clear()
self.queue.put_nowait(dumps(data))
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
async def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when tool errors."""
data = {
"status" : AgentStatus.llm_end,
"text" : response.generations[0][0].message.content,
}
self.out = True
self.queue.put_nowait(dumps(data))
async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
data = {
"status" : AgentStatus.error,
"text" : str(error),
"status": AgentStatus.tool_end,
"text": error,
}
self.done.clear()
self.queue.put_nowait(dumps(data))
async def on_agent_action(
@@ -101,9 +130,9 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
**kwargs: Any,
) -> None:
data = {
"status" : AgentStatus.agent_action,
"tool_name" : action.tool,
"tool_input" : action.tool_input,
"status": AgentStatus.agent_action,
"tool_name": action.tool,
"tool_input": action.tool_input,
"text": action.log,
}
self.queue.put_nowait(dumps(data))
@@ -113,12 +142,16 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
if "Thought:" in finish.return_values["output"]:
finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "")
data = {
"status" : AgentStatus.agent_finish,
"text" : finish.return_values["output"],
"status": AgentStatus.agent_finish,
"text": finish.return_values["output"],
}
self.done.set()
self.queue.put_nowait(dumps(data))
self.out = True

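Each handler above serializes a small status dict onto the async queue; whoever holds the callback drains it with aiter() and dispatches on the AgentStatus codes. A minimal consumer sketch, using only the keys emitted by the handlers in this file:

import json

async def consume(callback):
    # Drain the handler's queue and dispatch on the status codes
    # (llm_new_token / agent_action / tool_end / agent_finish / error).
    async for chunk in callback.aiter():
        d = json.loads(chunk)
        if d["status"] == AgentStatus.llm_new_token:
            print(d["text"], end="", flush=True)
        elif d["status"] == AgentStatus.agent_action:
            print(f"\n-> {d['tool_name']}({d['tool_input']})")
        elif d["status"] == AgentStatus.tool_end:
            print(f"\n<- {d['tool_output']}")
        elif d["status"] == AgentStatus.agent_finish:
            print("\n" + d["text"])
        elif d["status"] == AgentStatus.error:
            raise RuntimeError(d["text"])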
View File

@@ -69,6 +69,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
chain = LLMChain(
prompt=chat_prompt,
llm=models["llm_model"],
callbacks=callbacks,
memory=memory
)
classifier_chain = (
@@ -76,8 +77,7 @@
| models["preprocess_model"]
| StrOutputParser()
)
if "action_model" in models and tools:
if "action_model" in models and len(tools) > 0:
if "chatglm3" in models["action_model"].model_name.lower():
agent_executor = initialize_glm3_agent(
llm=models["action_model"],
@@ -151,7 +151,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
# Select tools from the configuration
tools = [tool for tool in all_tools if tool.name in tool_config]
tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
# Build the complete Chain
full_chain = create_models_chains(prompts=prompts,
models=models,
@@ -161,12 +161,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
history=history,
history_len=history_len,
metadata=metadata)
# Execute Chain
task = asyncio.create_task(
wrap_done(full_chain.ainvoke({"input": query}), callback.done))
task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}), callback.done))
async for chunk in callback.aiter():
data = json.loads(chunk)
data["message_id"] = message_id

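wrap_done is the glue here: it awaits the chain and then sets callback.done so the aiter() loop terminates. The helper itself is defined elsewhere in the repo; a sketch of its shape, assuming roughly this contract:

import asyncio
import logging
from typing import Awaitable

async def wrap_done(fn: Awaitable, event: asyncio.Event):
    # Await the wrapped coroutine, then always signal completion so the
    # streaming consumer does not hang if the chain raises.
    try:
        await fn
    except Exception as e:
        logging.exception(e)
    finally:
        event.set()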
View File

@@ -272,12 +272,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
unsafe_allow_html=True)
chat_box.ai_say(["正在思考...",
Markdown(title="tool call", in_expander=True, expanded=True,state="running"),
Markdown()])
chat_box.ai_say("正在思考...")
text = ""
message_id = ""
tool_called = False
text_action = ""
element_index = 0
for d in api.chat_chat(query=prompt,
metadata=files_upload,
@@ -290,30 +288,33 @@
metadata = {
"message_id": message_id,
}
if not tool_called:  # handle LLM output before and after the tool call separately, to avoid repeating it once a tool has been called
element_index = 0
else:
element_index = -1
if d["status"] == AgentStatus.error:
st.error(d["text"])
elif d["status"] == AgentStatus.agent_action:
formatted_data = {
"action": d["tool_name"],
"action_input": d["tool_input"]
"Function": d["tool_name"],
"function_input": d["tool_input"]
}
element_index += 1
formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
chat_box.update_msg(Markdown(f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"), element_index=1)
tool_called = True
text = ""
chat_box.insert_msg(
Markdown(title="Function call", in_expander=True, expanded=True, state="running"))
text = """\n```{}\n```\n""".format(formatted_json)
chat_box.update_msg(Markdown(text), element_index=element_index)
elif d["status"] == AgentStatus.tool_end:
text += """\n```\nObservation:\n{}\n```\n""".format(d["tool_output"])
chat_box.update_msg(Markdown(text), element_index=element_index, expanded=False, state="complete")
elif d["status"] == AgentStatus.llm_new_token:
text += d["text"]
chat_box.update_msg(text, streaming=True, element_index=element_index, metadata=metadata)
elif d["status"] == AgentStatus.llm_end:
chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
elif d["status"] == AgentStatus.agent_finish:
chat_box.update_msg(element_index=1, state="complete", expanded=False)
chat_box.update_msg(Markdown(d["text"]), streaming=False, element_index=-1)
element_index += 1
# print(d["text"])
chat_box.insert_msg(Markdown(d["text"], expanded=True))
chat_box.update_msg(Markdown(d["text"]), element_index=element_index)
if os.path.exists("tmp/image.jpg"):
with open("tmp/image.jpg", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode()
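The element_index bookkeeping above follows the streamlit-chatbox element model used throughout this file: ai_say starts a message, insert_msg appends a new element to it, and update_msg rewrites the element at a given index, with -1 addressing the newest one. A toy sketch of the pattern, using only calls that already appear in this diff:

chat_box.ai_say("正在思考...")  # element 0: the streamed answer text
chat_box.insert_msg(
    Markdown("", title="Function call", in_expander=True, state="running"))
chat_box.update_msg(Markdown("Observation: ..."),  # element 1: tool panel
                    element_index=-1, expanded=False, state="complete")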