mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-06 23:15:53 +08:00

Commit: 更新Agent工具返回 (Update the Agent tool-call return handling)

This commit is contained in:
parent e2a46a1d0f
commit 36c90e2e2b
@@ -1,13 +1,11 @@
 PROMPT_TEMPLATES = {
     "preprocess_model": {
         "default":
-            '你只要回复0 和 1 ,代表不需要使用工具。以下几种问题不需要使用工具:'
-            '1. 介绍一下你自己, 回复0\n'
-            '2. 讲一个故事, 回复0\n'
-            '3. 给我开一个玩笑, 回复0\n'
-            '4. 我当前运行的文件夹是, 回复1\n'
-            '5. where is this cat, 回复1\n'
-            '6. 介绍一下像极了我大学, 回复1\n'
+            '你只要回复0 和 1 ,代表不需要使用工具。以下几种问题需要使用工具:'
+            '1. 需要联网查询的内容\n'
+            '2. 需要计算的内容\n'
+            '3. 需要查询实时性的内容\n'
+            '如果我的输入满足这几种情况,返回1。其他输入,请你回复0,你只要返回一个数字\n'
             '这是我的问题:'
     },
     "llm_model": {
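The rewritten "preprocess_model" prompt turns tool routing into a single-digit classification: reply "1" when tools are needed, "0" otherwise. A minimal usage sketch, assuming a LangChain-style model handle (the names llm and needs_tools are illustrative, not part of this commit):

# Hypothetical gate built on the new classifier prompt; the model is
# instructed to answer with a single digit, "1" meaning run the tool path.
PREPROCESS_PROMPT = (
    '你只要回复0 和 1 ,代表不需要使用工具。以下几种问题需要使用工具:'
    '1. 需要联网查询的内容\n'
    '2. 需要计算的内容\n'
    '3. 需要查询实时性的内容\n'
    '如果我的输入满足这几种情况,返回1。其他输入,请你回复0,你只要返回一个数字\n'
    '这是我的问题:'
)

def needs_tools(llm, query: str) -> bool:
    # Any LangChain Runnable LLM exposes invoke(); reply is a short string.
    reply = llm.invoke(PREPROCESS_PROMPT + query)
    return "1" in str(reply)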
@@ -1,7 +1,12 @@
 # API requirements
 
-langchain>=0.0.350
-langchain-experimental>=0.0.42
+# On Windows system, install the cuda version manually from https://pytorch.org/
+torch>=2.1.2
+torchvision>=0.16.2
+torchaudio>=2.1.2
+
+langchain>=0.0.352
+langchain-experimental>=0.0.47
 pydantic==1.10.13
 fschat==0.2.35
 openai==1.9.0
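The torch wheels pulled from PyPI on Windows are CPU-only, hence the new comment pointing at https://pytorch.org/ for the CUDA build. A quick post-install sanity check; the cu121 wheel index in the comment is one common choice for torch 2.1.2 and an assumption here, not something this commit specifies:

# Verify the pinned stack resolved to a CUDA-enabled torch, e.g. after:
#   pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 \
#       --index-url https://download.pytorch.org/whl/cu121
import torch

print(torch.__version__)          # expect 2.1.2 or newer per requirements
print(torch.cuda.is_available())  # True only with a CUDA build and driver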
@@ -8,7 +8,6 @@ import logging
 from typing import Any, List, Sequence, Tuple, Optional, Union
 from pydantic.schema import model_schema
 
-
 from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
 from langchain.memory import ConversationBufferWindowMemory
 from langchain.agents.agent import Agent
@@ -24,7 +23,6 @@ from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
 from langchain_core.callbacks import Callbacks
 
-
 HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
 logger = logging.getLogger(__name__)
 
@@ -177,6 +175,7 @@ class StructuredGLM3ChatAgent(Agent):
         **kwargs: Any,
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
+
         cls._validate_tools(tools)
         prompt = cls.create_prompt(
             tools,
@@ -217,7 +216,7 @@ def initialize_glm3_agent(
 ) -> AgentExecutor:
     tags_ = list(tags) if tags else []
     agent_kwargs = agent_kwargs or {}
-    llm.callbacks=callbacks
+    llm.callbacks = callbacks
     agent_obj = StructuredGLM3ChatAgent.from_llm_and_tools(
         llm=llm,
         tools=tools,
@@ -154,15 +154,11 @@ def initialize_qwen_agent(
     memory: Optional[ConversationBufferWindowMemory] = None,
     agent_kwargs: Optional[dict] = None,
     *,
-    return_direct: Optional[bool] = None,
     tags: Optional[Sequence[str]] = None,
     **kwargs: Any,
 ) -> AgentExecutor:
     tags_ = list(tags) if tags else []
     agent_kwargs = agent_kwargs or {}
-
-    if isinstance(return_direct, bool): # can make all tools return directly
-        tools = [t.copy(update={"return_direct": return_direct}) for t in tools]
     llm.callbacks=callbacks
     agent_obj = QwenChatAgent.from_llm_and_tools(
         llm=llm,
@@ -24,13 +24,13 @@ all_tools = [
         name="shell",
         description="Use Shell to execute Linux commands",
         args_schema=ShellInput,
-        # return_direct=True, #是否直接返回,不做大模型处理
     ),
     StructuredTool.from_function(
         func=wolfram,
         name="wolfram",
         description="Useful for when you need to calculate difficult formulas",
         args_schema=WolframInput,
+
     ),
     StructuredTool.from_function(
         func=search_youtube,
@@ -67,5 +67,6 @@ all_tools = [
         name="aqa_processor",
         description="use this tool to get answer for audio question",
         args_schema=AQAInput,
+
     )
 ]
@@ -19,7 +19,9 @@ class AgentStatus:
     llm_end: int = 3
     agent_action: int = 4
     agent_finish: int = 5
-    error: int = 6
+    tool_begin: int = 6
+    tool_end: int = 7
+    error: int = 8
 
 
 class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
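Renumbering error from 6 to 8 to make room for tool_begin/tool_end means every consumer of the event stream has to switch on the new codes. A small decoding sketch keyed to the values in this hunk; llm_start and llm_new_token sit outside the hunk and are assumed unchanged:

# Map one JSON chunk from the callback queue to a readable status name.
import json

STATUS_NAMES = {
    3: "llm_end",
    4: "agent_action",
    5: "agent_finish",
    6: "tool_begin",   # new in this commit
    7: "tool_end",     # new in this commit
    8: "error",        # moved from 6 to 8
}

def describe(chunk: str) -> str:
    return STATUS_NAMES.get(json.loads(chunk)["status"], "other")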
@@ -31,20 +33,19 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
 
     async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
         data = {
-            "status" : AgentStatus.llm_start,
-            "text" : "",
+            "status": AgentStatus.llm_start,
+            "text": "",
         }
         self.done.clear()
         self.queue.put_nowait(dumps(data))
 
-
     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         special_tokens = ["Action", "<|observation|>"]
         for stoken in special_tokens:
             if stoken in token:
                 before_action = token.split(stoken)[0]
                 data = {
-                    "status" : AgentStatus.llm_new_token,
+                    "status": AgentStatus.llm_new_token,
                     "text": before_action + "\n",
                 }
                 self.queue.put_nowait(dumps(data))
@@ -53,15 +54,29 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
 
         if token is not None and token != "" and self.out:
             data = {
-                "status" : AgentStatus.llm_new_token,
-                "text" : token,
+                "status": AgentStatus.llm_new_token,
+                "text": token,
             }
             self.queue.put_nowait(dumps(data))
 
-    async def on_chat_model_start(
+    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        data = {
+            "status": AgentStatus.llm_end,
+            "text": "",
+        }
+        self.queue.put_nowait(dumps(data))
+
+    async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
+        data = {
+            "status": AgentStatus.error,
+            "text": str(error),
+        }
+        self.queue.put_nowait(dumps(data))
+
+    async def on_tool_start(
         self,
         serialized: Dict[str, Any],
-        messages: List[List],
+        input_str: str,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -69,26 +84,40 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
+        print("tool_begin")
+
+    async def on_tool_end(
+        self,
+        output: str,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when tool ends running."""
         data = {
-            "status" : AgentStatus.llm_start,
-            "text" : "",
+            "status": AgentStatus.tool_end,
+            "tool_output": output,
         }
         self.done.clear()
         self.queue.put_nowait(dumps(data))
 
-    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+    async def on_tool_error(
+        self,
+        error: BaseException,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when tool errors."""
         data = {
-            "status" : AgentStatus.llm_end,
-            "text" : response.generations[0][0].message.content,
+            "status": AgentStatus.tool_end,
+            "text": error,
         }
-        self.out = True
-        self.queue.put_nowait(dumps(data))
-
-    async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
-        data = {
-            "status" : AgentStatus.error,
-            "text" : str(error),
-        }
+        self.done.clear()
         self.queue.put_nowait(dumps(data))
 
     async def on_agent_action(
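With on_tool_end now queueing a tool_end payload that carries the raw output, a consumer can render an "Observation" block per tool call. An illustrative event sequence for a single shell-tool invocation; the tool name, input, and output values are made up for this sketch, only the payload shapes follow the handler above:

# Codes per the AgentStatus hunk: agent_action=4, tool_end=7, agent_finish=5.
events = [
    {"status": 4, "tool_name": "shell", "tool_input": {"query": "pwd"}, "text": "..."},
    {"status": 7, "tool_output": "/home/chatchat"},
    {"status": 5, "text": "Your current working directory is /home/chatchat."},
]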
@@ -101,9 +130,9 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         **kwargs: Any,
     ) -> None:
         data = {
-            "status" : AgentStatus.agent_action,
-            "tool_name" : action.tool,
-            "tool_input" : action.tool_input,
+            "status": AgentStatus.agent_action,
+            "tool_name": action.tool,
+            "tool_input": action.tool_input,
             "text": action.log,
         }
         self.queue.put_nowait(dumps(data))
@@ -113,12 +142,16 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
+
         if "Thought:" in finish.return_values["output"]:
             finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "")
+
         data = {
-            "status" : AgentStatus.agent_finish,
-            "text" : finish.return_values["output"],
+            "status": AgentStatus.agent_finish,
+            "text": finish.return_values["output"],
         }
+
         self.done.set()
         self.queue.put_nowait(dumps(data))
 
+        self.out = True
@@ -69,6 +69,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
     chain = LLMChain(
         prompt=chat_prompt,
         llm=models["llm_model"],
+        callbacks=callbacks,
         memory=memory
     )
     classifier_chain = (
@@ -76,8 +77,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
         | models["preprocess_model"]
         | StrOutputParser()
     )
-
-    if "action_model" in models and tools:
+    if "action_model" in models and len(tools) > 0:
         if "chatglm3" in models["action_model"].model_name.lower():
             agent_executor = initialize_glm3_agent(
                 llm=models["action_model"],
@@ -151,7 +151,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
 
     # 从配置中选择工具
     tools = [tool for tool in all_tools if tool.name in tool_config]
-
+    tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
    # 构建完整的Chain
     full_chain = create_models_chains(prompts=prompts,
                                       models=models,
@@ -161,12 +161,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                                       history=history,
                                       history_len=history_len,
                                       metadata=metadata)
-
-    # Execute Chain
-
-    task = asyncio.create_task(
-        wrap_done(full_chain.ainvoke({"input": query}), callback.done))
-
+    task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}), callback.done))
     async for chunk in callback.aiter():
         data = json.loads(chunk)
         data["message_id"] = message_id
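Copying each tool with the request's callbacks (the t.copy(update={"callbacks": callbacks}) line above) is what routes on_tool_start/on_tool_end into the same queue the endpoint drains. A condensed sketch of that producer/consumer shape, assuming the names from this diff (full_chain, callback.aiter()) keep their existing behavior; stream_chat itself is an illustrative name:

import asyncio
import json

async def stream_chat(full_chain, callback, query: str, message_id: str):
    # Producer: run the agent chain in the background while we consume events.
    task = asyncio.create_task(full_chain.ainvoke({"input": query}))
    # Consumer: each chunk is one JSON event with "status" plus "text"/"tool_output".
    async for chunk in callback.aiter():
        data = json.loads(chunk)
        data["message_id"] = message_id
        yield json.dumps(data, ensure_ascii=False)
    await task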
@@ -272,12 +272,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                 f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
                 unsafe_allow_html=True)
 
-        chat_box.ai_say(["正在思考...",
-                         Markdown(title="tool call", in_expander=True, expanded=True,state="running"),
-                         Markdown()])
+        chat_box.ai_say("正在思考...")
         text = ""
-        message_id = ""
-        tool_called = False
+        text_action = ""
+        element_index = 0
 
         for d in api.chat_chat(query=prompt,
                                metadata=files_upload,
@@ -290,30 +288,33 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             metadata = {
                 "message_id": message_id,
             }
-            if not tool_called:  # 避免工具调用之后重复输出,将LLM输出分为工具调用前后分别处理
-                element_index = 0
-            else:
-                element_index = -1
             if d["status"] == AgentStatus.error:
                 st.error(d["text"])
             elif d["status"] == AgentStatus.agent_action:
                 formatted_data = {
-                    "action": d["tool_name"],
-                    "action_input": d["tool_input"]
+                    "Function": d["tool_name"],
+                    "function_input": d["tool_input"]
                 }
+                element_index += 1
                 formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
-                chat_box.update_msg(Markdown(f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"), element_index=1)
-                tool_called = True
-                text = ""
+                chat_box.insert_msg(
+                    Markdown(title="Function call", in_expander=True, expanded=True, state="running"))
+                text = """\n```{}\n```\n""".format(formatted_json)
+                chat_box.update_msg(Markdown(text), element_index=element_index)
+            elif d["status"] == AgentStatus.tool_end:
+                text += """\n```\nObservation:\n{}\n```\n""".format(d["tool_output"])
+                chat_box.update_msg(Markdown(text), element_index=element_index, expanded=False, state="complete")
             elif d["status"] == AgentStatus.llm_new_token:
                 text += d["text"]
                 chat_box.update_msg(text, streaming=True, element_index=element_index, metadata=metadata)
             elif d["status"] == AgentStatus.llm_end:
                 chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
             elif d["status"] == AgentStatus.agent_finish:
-                chat_box.update_msg(element_index=1, state="complete", expanded=False)
-                chat_box.update_msg(Markdown(d["text"]), streaming=False, element_index=-1)
+                element_index += 1
 
+                # print(d["text"])
+                chat_box.insert_msg(Markdown(d["text"], expanded=True))
+                chat_box.update_msg(Markdown(d["text"]), element_index=element_index)
             if os.path.exists("tmp/image.jpg"):
                 with open("tmp/image.jpg", "rb") as image_file:
                     encoded_string = base64.b64encode(image_file.read()).decode()