make qwen agent work with langchain>=0.1 (#3228)

commit 08c155949b (parent 87c912087c)
liunux4odoo authored on 2024-03-07 19:14:33 +08:00; committed by GitHub
4 changed files with 56 additions and 29 deletions

View File

@@ -27,7 +27,7 @@ def agents_registry(
        agent = create_structured_glm3_chat_agent(llm=llm, tools=tools)
        # pass
    elif "qwen" in llm.model_name.lower():
        agent = create_structured_qwen_chat_agent(llm=llm, tools=tools)
        return create_structured_qwen_chat_agent(llm=llm, tools=tools, callbacks=callbacks)
    else:
        if prompt is not None:
            prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)])
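
Note what this hunk changes: for qwen models, agents_registry now returns early with a fully wired AgentExecutor (built in qwen_agent.py below) and forwards callbacks, instead of assigning a bare Runnable and falling through to the shared executor wiring. A minimal sketch of a call site under that assumption; llm, tools, and handler are hypothetical, not from this diff:

# Hypothetical caller: for qwen models the registry result is already an AgentExecutor.
executor = agents_registry(llm=llm, tools=tools, callbacks=[handler])
result = executor.invoke({"input": "What is the weather like today?"})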

View File: server/agent/agent_factory/qwen_agent.py

@@ -1,27 +1,49 @@
from __future__ import annotations

from functools import partial
import json
import logging
from operator import itemgetter
import re
from typing import List, Sequence, Union
from typing import List, Sequence, Union, Tuple, Any

from langchain_core.callbacks import Callbacks
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain.agents.agent import RunnableAgent, AgentExecutor
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.callbacks.base import BaseCallbackHandler
from langchain.prompts.chat import BaseChatPromptTemplate
from langchain.schema import (AgentAction, AgentFinish, OutputParserException,
                              HumanMessage, SystemMessage, AIMessage)
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool
from langchain.tools.render import format_tool_to_openai_function

from server.utils import get_prompt_template

logger = logging.getLogger(__name__)
# langchain's RunnableAgent uses .stream to keep .stream_log working,
# but the qwen model cannot do tool calls while streaming.
# Patch it so the qwen LCEL agent works.
def _plan_without_stream(
        self: RunnableAgent,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
    inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
    return self.runnable.invoke(inputs, config={"callbacks": callbacks})


async def _aplan_without_stream(
        self: RunnableAgent,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
    inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
    return await self.runnable.ainvoke(inputs, config={"callbacks": callbacks})
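
As the comment above says, the stock plan/aplan consume self.runnable.stream(...) chunk by chunk, which is what breaks qwen's tool calling; these replacements make a single blocking invoke/ainvoke call instead, with the same signature. A hedged sketch of how the executor ends up calling the patched method (agent, handler, and the sample inputs are hypothetical):

# AgentExecutor drives the loop by calling plan() with the steps so far
# plus the user inputs as keyword arguments:
next_step = agent.plan(
    intermediate_steps=[],   # no (AgentAction, observation) pairs yet
    callbacks=[handler],
    input="What's the weather in Suzhou?",
)
# next_step is either an AgentAction (call a tool) or an AgentFinish (final answer).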

class QwenChatAgentPromptTemplate(BaseChatPromptTemplate):
    # The template to use
    template: str
@@ -42,8 +64,16 @@ class QwenChatAgentPromptTemplate(BaseChatPromptTemplate):
        else:
            kwargs["agent_scratchpad"] = ""
        # Create a tools variable from the list of tools provided
        # kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}. Parameters: {tool.args_schema.dict()}" for tool in self.tools])
        kwargs["tools"] = "\n".join([str(format_tool_to_openai_function(tool)) for tool in self.tools])
        tools = []
        for t in self.tools:
            desc = re.sub(r"\n+", " ", t.description)
            text = (f"{t.name}: Call this tool to interact with the {t.name} API. What is the {t.name} API useful for?"
                    f" {desc}"
                    f" Parameters: {t.args}"
                    )
            tools.append(text)
        kwargs["tools"] = "\n\n".join(tools)
        # kwargs["tools"] = "\n".join([str(format_tool_to_openai_function(tool)) for tool in self.tools])
        # Create a list of tool names for the tools provided
        kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
        formatted = self.template.format(**kwargs)
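
To make the new rendering concrete, here is roughly what the loop above produces for one tool; the tool name, description, and args are invented for illustration, and the fixed phrasing matches the tool-description format used in Qwen's ReAct prompting:

import re

# A made-up tool, standing in for a BaseTool's .name / .description / .args:
name = "weather_check"
description = "Query the current weather\nby city name."
args = {"city": {"title": "City", "type": "string"}}

desc = re.sub(r"\n+", " ", description)  # collapse newlines inside the description
text = (f"{name}: Call this tool to interact with the {name} API. What is the {name} API useful for?"
        f" {desc}"
        f" Parameters: {args}")
print(text)
# weather_check: Call this tool to interact with the weather_check API. What is the
# weather_check API useful for? Query the current weather by city name.
# Parameters: {'city': {'title': 'City', 'type': 'string'}}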
@@ -90,8 +120,9 @@ class QwenChatAgentOutputParserLC(StructuredChatOutputParser):
def create_structured_qwen_chat_agent(
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callbacks: Sequence[Callbacks],
        use_custom_prompt: bool = True,
) -> Runnable:
) -> AgentExecutor:
    if use_custom_prompt:
        prompt = "qwen"
        output_parser = QwenChatAgentOutputParserCustom()
@@ -99,17 +130,22 @@ def create_structured_qwen_chat_agent(
        prompt = "structured-chat-agent"
        output_parser = QwenChatAgentOutputParserLC()
    tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
    template = get_prompt_template("action_model", prompt)
    prompt = QwenChatAgentPromptTemplate(input_variables=["input", "intermediate_steps"],
                                         template=template,
                                         tools=tools)

    agent = (
        RunnablePassthrough.assign(
            agent_scratchpad=itemgetter("intermediate_steps")
        )
        | prompt
        | llm.bind(stop="\nObservation:")
        | llm.bind(stop=["<|endoftext|>", "<|im_start|>", "<|im_end|>", "\nObservation:"])
        | output_parser
    )
    return agent
    executor = AgentExecutor(agent=agent, tools=tools, callbacks=callbacks)
    executor.agent.__dict__["plan"] = partial(_plan_without_stream, executor.agent)
    executor.agent.__dict__["aplan"] = partial(_aplan_without_stream, executor.agent)
    return executor
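
Two details here are easy to miss. First, the stop list now includes qwen's ChatML special tokens alongside the ReAct "\nObservation:" marker. Second, the streaming plan/aplan methods are replaced per instance: assigning into executor.agent.__dict__ shadows the class attribute for this one agent (and presumably also sidesteps attribute validation on the underlying pydantic model), while functools.partial pre-binds self, since a plain function stored on an instance is not bound automatically. A toy sketch of the same override pattern, with hypothetical names:

from functools import partial

class Agent:
    def plan(self):
        return "streaming plan"  # class-level behaviour

def _plan_patched(self):
    return "non-streaming plan"  # drop-in replacement, same signature

a = Agent()
a.__dict__["plan"] = partial(_plan_patched, a)  # shadows plan() on this instance only
assert a.plan() == "non-streaming plan"
assert Agent().plan() == "streaming plan"  # other instances are untouched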

View File: server/callback_handler/agent_callback_handler.py

@@ -40,7 +40,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
        self.queue.put_nowait(dumps(data))

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        special_tokens = ["Action", "<|observation|>"]
        special_tokens = ["\nAction:", "\nObservation:", "<|observation|>"]
        for stoken in special_tokens:
            if stoken in token:
                before_action = token.split(stoken)[0]
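
The widened marker list is easier to see with a worked example: the old bare "Action" entry would truncate any token that merely contained the word, while "\nAction:" only fires on a real ReAct section header. A self-contained sketch of the splitting logic from this hunk (the sample token is made up):

token = "I will look that up.\nAction: weather_check"
special_tokens = ["\nAction:", "\nObservation:", "<|observation|>"]
for stoken in special_tokens:
    if stoken in token:
        before_action = token.split(stoken)[0]  # keep only the free-form text
        print(before_action)  # -> I will look that up.
        break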
@@ -134,7 +134,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
        """Run when tool errors."""
        data = {
            "status": AgentStatus.tool_end,
            "text": error,
            "text": str(error),
        }
        # self.done.clear()
        self.queue.put_nowait(dumps(data))
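
The switch to str(error) matters because the payload is serialized with dumps() before being queued, and exception objects are not JSON-serializable. A quick stdlib illustration of the failure mode, using json.dumps in place of langchain's dumps:

import json

err = ValueError("tool blew up")
try:
    json.dumps({"status": "tool_end", "text": err})  # "tool_end" stands in for AgentStatus.tool_end
except TypeError as e:
    print(e)  # Object of type ValueError is not JSON serializable
print(json.dumps({"status": "tool_end", "text": str(err)}))  # stringifying the error works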

View File

@@ -6,11 +6,10 @@ import asyncio
import json
from pprint import pprint

from langchain.agents import AgentExecutor
from langchain_openai.chat_models import ChatOpenAI
# from langchain.chat_models.openai import ChatOpenAI
from server.agent.tools_factory.tools_registry import all_tools
from server.agent.agent_factory.qwen_agent import create_structured_qwen_chat_agent
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
from server.utils import get_ChatOpenAI

from langchain import globals
# globals.set_debug(True)
@@ -19,19 +18,11 @@ from langchain import globals
async def test1():
    callback = AgentExecutorAsyncIteratorCallbackHandler()
    tools = [t.copy(update={"callbacks": [callback]}) for t in all_tools]
    # qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
    qwen_model = ChatOpenAI(base_url="http://127.0.0.1:9997/v1",
                            api_key="empty",
                            streaming=False,
                            temperature=0.01,
                            model="qwen",
                            callbacks=[callback],
                            )
    agent = create_structured_qwen_chat_agent(tools=tools, llm=qwen_model)
    executor = AgentExecutor(agent=agent, tools=tools, verbose=True, callbacks=[callback])
    # ret = executor.invoke("苏州今天冷吗")
    qwen_model = get_ChatOpenAI("qwen", 0.01, streaming=False, callbacks=[callback])
    executor = create_structured_qwen_chat_agent(llm=qwen_model,
                                                 tools=all_tools,
                                                 callbacks=[callback])
    # ret = executor.invoke({"input": "苏州今天冷吗"})
    ret = asyncio.create_task(executor.ainvoke({"input": "苏州今天冷吗"}))  # "Is it cold in Suzhou today?"
    async for chunk in callback.aiter():
        print(chunk)
@@ -137,4 +128,4 @@ async def test_text2image():
        x = json.loads(x)
        pprint(x)

asyncio.run(test_text2image())
asyncio.run(test1())