mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
make qwen agent work with langchain>=0.1 (#3228)
commit 08c155949b (parent 87c912087c)
@@ -27,7 +27,7 @@ def agents_registry(
         agent = create_structured_glm3_chat_agent(llm=llm, tools=tools)
         # pass
     elif "qwen" in llm.model_name.lower():
-        agent = create_structured_qwen_chat_agent(llm=llm, tools=tools)
+        return create_structured_qwen_chat_agent(llm=llm, tools=tools, callbacks=callbacks)
     else:
         if prompt is not None:
             prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)])
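
For context, a rough sketch of how the dispatch reads after this change (the full agents_registry signature is not shown in this hunk, so the parameters below are assumed): the qwen branch now returns a finished AgentExecutor directly, because create_structured_qwen_chat_agent builds the executor itself (see qwen_agent.py below), while the other branches still assemble a bare agent for later wrapping.

    # Assumed sketch of agents_registry after this change; the surrounding
    # signature is not part of the hunk.
    def agents_registry(llm, tools=(), callbacks=None, prompt=None, verbose=False):
        if "glm3" in llm.model_name.lower():
            agent = create_structured_glm3_chat_agent(llm=llm, tools=tools)
        elif "qwen" in llm.model_name.lower():
            # qwen builds and returns its own AgentExecutor, callbacks included
            return create_structured_qwen_chat_agent(llm=llm, tools=tools,
                                                     callbacks=callbacks)
        else:
            ...  # generic structured-chat path continues below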
@@ -1,27 +1,49 @@
 from __future__ import annotations
 
+from functools import partial
 import json
 import logging
 from operator import itemgetter
 import re
-from typing import List, Sequence, Union
+from typing import List, Sequence, Union, Tuple, Any
 
+from langchain_core.callbacks import Callbacks
 from langchain_core.runnables import Runnable, RunnablePassthrough
+from langchain.agents.agent import RunnableAgent, AgentExecutor
 from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
-from langchain.callbacks.base import BaseCallbackHandler
 from langchain.prompts.chat import BaseChatPromptTemplate
 from langchain.schema import (AgentAction, AgentFinish, OutputParserException,
                               HumanMessage, SystemMessage, AIMessage)
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
-from langchain.tools.render import format_tool_to_openai_function
 
 from server.utils import get_prompt_template
 
 
 logger = logging.getLogger(__name__)
 
 
+# LangChain's AgentRunnable uses .stream to keep .stream_log working,
+# but the qwen model cannot do tool calls while streaming.
+# Patch it to make the qwen LCEL agent work.
+def _plan_without_stream(
+    self: RunnableAgent,
+    intermediate_steps: List[Tuple[AgentAction, str]],
+    callbacks: Callbacks = None,
+    **kwargs: Any,
+) -> Union[AgentAction, AgentFinish]:
+    inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
+    return self.runnable.invoke(inputs, config={"callbacks": callbacks})
+
+
+async def _aplan_without_stream(
+    self: RunnableAgent,
+    intermediate_steps: List[Tuple[AgentAction, str]],
+    callbacks: Callbacks = None,
+    **kwargs: Any,
+) -> Union[AgentAction, AgentFinish]:
+    inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
+    return await self.runnable.ainvoke(inputs, config={"callbacks": callbacks})
+
+
 class QwenChatAgentPromptTemplate(BaseChatPromptTemplate):
     # The template to use
     template: str
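
These two functions are later bound into the executor's agent (see the last hunk of this file). Assigning a functools.partial into an instance's __dict__ shadows the class's plan/aplan for that one object only, and bypasses attribute validation (RunnableAgent is a pydantic model, which is likely why the patch goes through __dict__ rather than plain setattr). A minimal, self-contained sketch of the same shadowing technique; the class and method names here are illustrative, not from the commit:

    # Minimal sketch of instance-level method shadowing via functools.partial.
    from functools import partial

    class Agent:
        def plan(self):
            return "streaming plan"

    def _plan_without_stream(self):
        return "non-streaming plan"

    agent = Agent()
    agent.__dict__["plan"] = partial(_plan_without_stream, agent)
    print(agent.plan())    # non-streaming plan: the instance attribute wins lookup
    print(Agent().plan())  # streaming plan: other instances are unaffected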
@@ -42,8 +64,16 @@ class QwenChatAgentPromptTemplate(BaseChatPromptTemplate):
         else:
             kwargs["agent_scratchpad"] = ""
         # Create a tools variable from the list of tools provided
-        # kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}. Parameters: {tool.args_schema.dict()}" for tool in self.tools])
-        kwargs["tools"] = "\n".join([str(format_tool_to_openai_function(tool)) for tool in self.tools])
+        tools = []
+        for t in self.tools:
+            desc = re.sub(r"\n+", " ", t.description)
+            text = (f"{t.name}: Call this tool to interact with the {t.name} API. What is the {t.name} API useful for?"
+                    f" {desc}"
+                    f" Parameters: {t.args}"
+                    )
+            tools.append(text)
+        kwargs["tools"] = "\n\n".join(tools)
+        # kwargs["tools"] = "\n".join([str(format_tool_to_openai_function(tool)) for tool in self.tools])
         # Create a list of tool names for the tools provided
         kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
         formatted = self.template.format(**kwargs)
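
The new rendering replaces the OpenAI function schema with Qwen's ReAct-style tool description (name, purpose, parameters). A sketch of what one rendered entry looks like; the weather tool is invented purely for illustration, and the printed schema is approximate:

    import re
    from langchain.tools import tool

    @tool
    def weather(city: str) -> str:
        """Look up the current weather for a city."""
        return f"sunny in {city}"

    desc = re.sub(r"\n+", " ", weather.description)
    text = (f"{weather.name}: Call this tool to interact with the {weather.name} API."
            f" What is the {weather.name} API useful for? {desc}"
            f" Parameters: {weather.args}")
    print(text)
    # weather: Call this tool to interact with the weather API. What is the
    # weather API useful for? Look up the current weather for a city.
    # Parameters: {'city': {'title': 'City', 'type': 'string'}}  (roughly)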
@@ -90,8 +120,9 @@ class QwenChatAgentOutputParserLC(StructuredChatOutputParser):
 def create_structured_qwen_chat_agent(
         llm: BaseLanguageModel,
         tools: Sequence[BaseTool],
+        callbacks: Sequence[Callbacks],
         use_custom_prompt: bool = True,
-) -> Runnable:
+) -> AgentExecutor:
     if use_custom_prompt:
         prompt = "qwen"
         output_parser = QwenChatAgentOutputParserCustom()
@@ -99,17 +130,22 @@ def create_structured_qwen_chat_agent(
         prompt = "structured-chat-agent"
         output_parser = QwenChatAgentOutputParserLC()
 
+    tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
     template = get_prompt_template("action_model", prompt)
     prompt = QwenChatAgentPromptTemplate(input_variables=["input", "intermediate_steps"],
                                          template=template,
                                          tools=tools)
 
     agent = (
         RunnablePassthrough.assign(
             agent_scratchpad=itemgetter("intermediate_steps")
         )
         | prompt
-        | llm.bind(stop="\nObservation:")
+        | llm.bind(stop=["<|endoftext|>", "<|im_start|>", "<|im_end|>", "\nObservation:"])
         | output_parser
     )
-    return agent
+    executor = AgentExecutor(agent=agent, tools=tools, callbacks=callbacks)
+    executor.agent.__dict__["plan"] = partial(_plan_without_stream, executor.agent)
+    executor.agent.__dict__["aplan"] = partial(_aplan_without_stream, executor.agent)
+
+    return executor
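
The widened stop list adds Qwen's ChatML control tokens (<|endoftext|>, <|im_start|>, <|im_end|>) alongside the ReAct "\nObservation:" sentinel, so generation halts before the model fabricates a tool observation. Since the factory now returns an AgentExecutor with the non-streaming plan/aplan already patched in, callers can invoke it directly; a rough usage sketch, where the "qwen" model name and 0.01 temperature are illustrative values:

    # Rough usage sketch of the new factory; values are illustrative.
    llm = get_ChatOpenAI("qwen", 0.01, streaming=False)
    executor = create_structured_qwen_chat_agent(llm=llm, tools=all_tools, callbacks=[])
    print(executor.invoke({"input": "What is the weather in Suzhou today?"}))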
@@ -40,7 +40,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         self.queue.put_nowait(dumps(data))
 
     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        special_tokens = ["Action", "<|observation|>"]
+        special_tokens = ["\nAction:", "\nObservation:", "<|observation|>"]
         for stoken in special_tokens:
             if stoken in token:
                 before_action = token.split(stoken)[0]
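
The anchored markers ("\nAction:" instead of the bare "Action") stop the handler from truncating ordinary prose that merely contains the word. The handler cuts a streamed chunk at the first special token so marker text never reaches the client; a simplified sketch of that truncation:

    # Simplified sketch of the truncation in on_llm_new_token.
    special_tokens = ["\nAction:", "\nObservation:", "<|observation|>"]
    token = "The answer is 42\nAction: weather"
    for stoken in special_tokens:
        if stoken in token:
            token = token.split(stoken)[0]  # keep only text before the marker
            break
    print(repr(token))  # 'The answer is 42'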
@@ -134,7 +134,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         """Run when tool errors."""
         data = {
             "status": AgentStatus.tool_end,
-            "text": error,
+            "text": str(error),
         }
         # self.done.clear()
         self.queue.put_nowait(dumps(data))
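
Here error arrives as an exception object: the stdlib json.dumps cannot serialize it (and a serializer wrapper would at best emit an opaque placeholder), so coercing with str() keeps the queued event payload clean. A minimal illustration with the stdlib:

    import json

    err = ValueError("tool failed")
    try:
        json.dumps({"text": err})
    except TypeError as e:
        print(e)  # Object of type ValueError is not JSON serializable
    print(json.dumps({"text": str(err)}))  # {"text": "tool failed"}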
@@ -6,11 +6,10 @@ import asyncio
 import json
 from pprint import pprint
 from langchain.agents import AgentExecutor
-from langchain_openai.chat_models import ChatOpenAI
-# from langchain.chat_models.openai import ChatOpenAI
 from server.agent.tools_factory.tools_registry import all_tools
 from server.agent.agent_factory.qwen_agent import create_structured_qwen_chat_agent
 from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
+from server.utils import get_ChatOpenAI
 from langchain import globals
 
 # globals.set_debug(True)
@@ -19,19 +18,11 @@ from langchain import globals
 
 async def test1():
     callback = AgentExecutorAsyncIteratorCallbackHandler()
-    tools = [t.copy(update={"callbacks": [callback]}) for t in all_tools]
-    # qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
-    qwen_model = ChatOpenAI(base_url="http://127.0.0.1:9997/v1",
-                            api_key="empty",
-                            streaming=False,
-                            temperature=0.01,
-                            model="qwen",
-                            callbacks=[callback],
-                            )
-    agent = create_structured_qwen_chat_agent(tools=tools, llm=qwen_model)
-    executor = AgentExecutor(agent=agent, tools=tools, verbose=True, callbacks=[callback])
-
-    # ret = executor.invoke("苏州今天冷吗")
+    qwen_model = get_ChatOpenAI("qwen", 0.01, streaming=False, callbacks=[callback])
+    executor = create_structured_qwen_chat_agent(llm=qwen_model,
+                                                 tools=all_tools,
+                                                 callbacks=[callback])
+    # ret = executor.invoke({"input": "苏州今天冷吗"})
     ret = asyncio.create_task(executor.ainvoke({"input": "苏州今天冷吗"}))
     async for chunk in callback.aiter():
         print(chunk)
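
The test drives the run and the event stream concurrently: ainvoke is scheduled with asyncio.create_task while callback.aiter() drains queued events until the run finishes. A generic sketch of the same producer/consumer pattern with plain asyncio primitives (stand-in code, not the project's handler):

    import asyncio

    async def produce(queue: asyncio.Queue):
        for i in range(3):
            await queue.put(f"event {i}")
        await queue.put(None)  # sentinel: producer finished

    async def main():
        queue: asyncio.Queue = asyncio.Queue()
        task = asyncio.create_task(produce(queue))  # run producer concurrently
        while (item := await queue.get()) is not None:
            print(item)  # consume events as they stream in
        await task

    asyncio.run(main())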
@@ -137,4 +128,4 @@ async def test_text2image():
     x = json.loads(x)
     pprint(x)
 
-asyncio.run(test_text2image())
+asyncio.run(test1())