mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-01 20:14:03 +08:00
make qwen agent work with langchain>=0.1 (#3228)
This commit is contained in:
parent
87c912087c
commit
08c155949b
@ -27,7 +27,7 @@ def agents_registry(
|
||||
agent = create_structured_glm3_chat_agent(llm=llm, tools=tools)
|
||||
# pass
|
||||
elif "qwen" in llm.model_name.lower():
|
||||
agent = create_structured_qwen_chat_agent(llm=llm, tools=tools)
|
||||
return create_structured_qwen_chat_agent(llm=llm, tools=tools, callbacks=callbacks)
|
||||
else:
|
||||
if prompt is not None:
|
||||
prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)])
|
||||
|
||||
@ -1,27 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
import json
|
||||
import logging
|
||||
from operator import itemgetter
|
||||
import re
|
||||
from typing import List, Sequence, Union
|
||||
from typing import List, Sequence, Union, Tuple, Any
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||
from langchain.agents.agent import RunnableAgent, AgentExecutor
|
||||
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.prompts.chat import BaseChatPromptTemplate
|
||||
from langchain.schema import (AgentAction, AgentFinish, OutputParserException,
|
||||
HumanMessage, SystemMessage, AIMessage)
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.render import format_tool_to_openai_function
|
||||
|
||||
from server.utils import get_prompt_template
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# langchain's RunnableAgent uses .stream to make sure .stream_log keeps working,
|
||||
# but the qwen model cannot make tool calls while streaming.
|
||||
# Patch it so that the qwen LCEL agent works.
|
||||
def _plan_without_stream(
    self: RunnableAgent,
    intermediate_steps: List[Tuple[AgentAction, str]],
    callbacks: Callbacks = None,
    **kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
    """Plan the next agent step with a single, non-streaming ``invoke`` call.

    Intended to be bound onto an agent object in place of its ``plan``
    method, so ``self`` is the agent whose ``runnable`` gets called.

    Args:
        self: Agent object exposing a ``runnable`` attribute.
        intermediate_steps: Steps taken so far; passed to the runnable under
            the ``"intermediate_steps"`` key.
        callbacks: Forwarded to the runnable as ``config={"callbacks": ...}``.
        **kwargs: Remaining user inputs, merged into the runnable's input.

    Returns:
        Whatever ``self.runnable.invoke`` returns (annotated as an
        ``AgentAction`` or ``AgentFinish``).
    """
    # Start from the caller's inputs, then add the steps under a fixed key.
    payload = dict(kwargs)
    payload["intermediate_steps"] = intermediate_steps
    run_config = {"callbacks": callbacks}
    return self.runnable.invoke(payload, config=run_config)
|
||||
|
||||
async def _aplan_without_stream(
    self: RunnableAgent,
    intermediate_steps: List[Tuple[AgentAction, str]],
    callbacks: Callbacks = None,
    **kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
    """Plan the next agent step asynchronously with a single ``ainvoke`` call.

    Intended to be bound onto an agent object in place of its ``aplan``
    method, so ``self`` is the agent whose ``runnable`` gets awaited.

    Args:
        self: Agent object exposing a ``runnable`` attribute.
        intermediate_steps: Steps taken so far; passed to the runnable under
            the ``"intermediate_steps"`` key.
        callbacks: Forwarded to the runnable as ``config={"callbacks": ...}``.
        **kwargs: Remaining user inputs, merged into the runnable's input.

    Returns:
        Whatever awaiting ``self.runnable.ainvoke`` yields (annotated as an
        ``AgentAction`` or ``AgentFinish``).
    """
    # Start from the caller's inputs, then add the steps under a fixed key.
    payload = dict(kwargs)
    payload["intermediate_steps"] = intermediate_steps
    run_config = {"callbacks": callbacks}
    outcome = await self.runnable.ainvoke(payload, config=run_config)
    return outcome
|
||||
|
||||
|
||||
class QwenChatAgentPromptTemplate(BaseChatPromptTemplate):
|
||||
# The template to use
|
||||
template: str
|
||||
@ -42,8 +64,16 @@ class QwenChatAgentPromptTemplate(BaseChatPromptTemplate):
|
||||
else:
|
||||
kwargs["agent_scratchpad"] = ""
|
||||
# Create a tools variable from the list of tools provided
|
||||
# kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}. Parameters: {tool.args_schema.dict()}" for tool in self.tools])
|
||||
kwargs["tools"] = "\n".join([str(format_tool_to_openai_function(tool)) for tool in self.tools])
|
||||
tools = []
|
||||
for t in self.tools:
|
||||
desc = re.sub(r"\n+", " ", t.description)
|
||||
text = (f"{t.name}: Call this tool to interact with the {t.name} API. What is the {t.name} API useful for?"
|
||||
f" {desc}"
|
||||
f" Parameters: {t.args}"
|
||||
)
|
||||
tools.append(text)
|
||||
kwargs["tools"] = "\n\n".join(tools)
|
||||
# kwargs["tools"] = "\n".join([str(format_tool_to_openai_function(tool)) for tool in self.tools])
|
||||
# Create a list of tool names for the tools provided
|
||||
kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
|
||||
formatted = self.template.format(**kwargs)
|
||||
@ -90,8 +120,9 @@ class QwenChatAgentOutputParserLC(StructuredChatOutputParser):
|
||||
def create_structured_qwen_chat_agent(
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
callbacks: Sequence[Callbacks],
|
||||
use_custom_prompt: bool = True,
|
||||
) -> Runnable:
|
||||
) -> AgentExecutor:
|
||||
if use_custom_prompt:
|
||||
prompt = "qwen"
|
||||
output_parser = QwenChatAgentOutputParserCustom()
|
||||
@ -99,17 +130,22 @@ def create_structured_qwen_chat_agent(
|
||||
prompt = "structured-chat-agent"
|
||||
output_parser = QwenChatAgentOutputParserLC()
|
||||
|
||||
tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
|
||||
|
||||
template = get_prompt_template("action_model", prompt)
|
||||
prompt = QwenChatAgentPromptTemplate(input_variables=["input", "intermediate_steps"],
|
||||
template=template,
|
||||
tools=tools)
|
||||
|
||||
agent = (
|
||||
RunnablePassthrough.assign(
|
||||
agent_scratchpad=itemgetter("intermediate_steps")
|
||||
)
|
||||
| prompt
|
||||
| llm.bind(stop="\nObservation:")
|
||||
| llm.bind(stop=["<|endoftext|>", "<|im_start|>", "<|im_end|>", "\nObservation:"])
|
||||
| output_parser
|
||||
)
|
||||
return agent
|
||||
executor = AgentExecutor(agent=agent, tools=tools, callbacks=callbacks)
|
||||
executor.agent.__dict__["plan"] = partial(_plan_without_stream, executor.agent)
|
||||
executor.agent.__dict__["aplan"] = partial(_aplan_without_stream, executor.agent)
|
||||
|
||||
return executor
|
||||
|
||||
@ -40,7 +40,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
self.queue.put_nowait(dumps(data))
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
special_tokens = ["Action", "<|observation|>"]
|
||||
special_tokens = ["\nAction:", "\nObservation:", "<|observation|>"]
|
||||
for stoken in special_tokens:
|
||||
if stoken in token:
|
||||
before_action = token.split(stoken)[0]
|
||||
@ -134,7 +134,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
"""Run when tool errors."""
|
||||
data = {
|
||||
"status": AgentStatus.tool_end,
|
||||
"text": error,
|
||||
"text": str(error),
|
||||
}
|
||||
# self.done.clear()
|
||||
self.queue.put_nowait(dumps(data))
|
||||
|
||||
@ -6,11 +6,10 @@ import asyncio
|
||||
import json
|
||||
from pprint import pprint
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain_openai.chat_models import ChatOpenAI
|
||||
# from langchain.chat_models.openai import ChatOpenAI
|
||||
from server.agent.tools_factory.tools_registry import all_tools
|
||||
from server.agent.agent_factory.qwen_agent import create_structured_qwen_chat_agent
|
||||
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
|
||||
from server.utils import get_ChatOpenAI
|
||||
from langchain import globals
|
||||
|
||||
# globals.set_debug(True)
|
||||
@ -19,19 +18,11 @@ from langchain import globals
|
||||
|
||||
async def test1():
|
||||
callback = AgentExecutorAsyncIteratorCallbackHandler()
|
||||
tools = [t.copy(update={"callbacks": [callback]}) for t in all_tools]
|
||||
# qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
|
||||
qwen_model = ChatOpenAI(base_url="http://127.0.0.1:9997/v1",
|
||||
api_key="empty",
|
||||
streaming=False,
|
||||
temperature=0.01,
|
||||
model="qwen",
|
||||
callbacks=[callback],
|
||||
)
|
||||
agent = create_structured_qwen_chat_agent(tools=tools, llm=qwen_model)
|
||||
executor = AgentExecutor(agent=agent, tools=tools, verbose=True, callbacks=[callback])
|
||||
|
||||
# ret = executor.invoke("苏州今天冷吗")
|
||||
qwen_model = get_ChatOpenAI("qwen", 0.01, streaming=False, callbacks=[callback])
|
||||
executor = create_structured_qwen_chat_agent(llm=qwen_model,
|
||||
tools=all_tools,
|
||||
callbacks=[callback])
|
||||
# ret = executor.invoke({"input": "苏州今天冷吗"})
|
||||
ret = asyncio.create_task(executor.ainvoke({"input": "苏州今天冷吗"}))
|
||||
async for chunk in callback.aiter():
|
||||
print(chunk)
|
||||
@ -137,4 +128,4 @@ async def test_text2image():
|
||||
x = json.loads(x)
|
||||
pprint(x)
|
||||
|
||||
asyncio.run(test_text2image())
|
||||
asyncio.run(test1())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user