按照 langchain 0.1 重写 qwen agent

This commit is contained in:
liunux4odoo 2024-01-10 13:25:39 +08:00
parent df1a508e10
commit 7257521e10
5 changed files with 78 additions and 122 deletions

View File

@ -1,2 +1,2 @@
from .glm3_agent import create_structured_glm3_chat_agent from .glm3_agent import create_structured_glm3_chat_agent
from .qwen_agent import initialize_qwen_agent from .qwen_agent import create_structured_qwen_chat_agent

View File

@ -8,7 +8,8 @@ from langchain_core.messages import SystemMessage
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from server.agent.agent_factory import create_structured_glm3_chat_agent from server.agent.agent_factory import (create_structured_glm3_chat_agent,
create_structured_qwen_chat_agent)
def agents_registry( def agents_registry(
@ -22,10 +23,14 @@ def agents_registry(
else: else:
prompt = hub.pull("hwchase17/structured-chat-agent") # default prompt prompt = hub.pull("hwchase17/structured-chat-agent") # default prompt
# llm.callbacks = callbacks
# Write any optimized method here. # Write any optimized method here.
if "glm3" in llm.model_name.lower(): if "glm3" in llm.model_name.lower():
# An optimized method of langchain Agent that uses the glm3 series model # An optimized method of langchain Agent that uses the glm3 series model
agent = create_structured_glm3_chat_agent(llm=llm, tools=tools) agent = create_structured_glm3_chat_agent(llm=llm, tools=tools)
elif "qwen" in llm.model_name.lower():
agent = create_structured_qwen_chat_agent(llm=llm, tools=tools)
else: else:
agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt) agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)

View File

@ -1,26 +1,24 @@
from __future__ import annotations from __future__ import annotations
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from typing import Any, List, Sequence, Tuple, Optional, Union
import re
from langchain.agents import Tool
from langchain.agents.agent import LLMSingleActionAgent, AgentOutputParser
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate, BaseChatPromptTemplate
import json import json
import logging import logging
from langchain.pydantic_v1 import Field from operator import itemgetter
import re
from typing import List, Sequence, Union
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.callbacks.base import BaseCallbackHandler
from langchain.prompts.chat import BaseChatPromptTemplate
from langchain.schema import (AgentAction, AgentFinish, OutputParserException, from langchain.schema import (AgentAction, AgentFinish, OutputParserException,
HumanMessage, SystemMessage, AIMessage) HumanMessage, SystemMessage, AIMessage)
from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain.tools.render import format_tool_to_openai_function from langchain.tools.render import format_tool_to_openai_function
from server.utils import get_prompt_template from server.utils import get_prompt_template
HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -28,7 +26,7 @@ class QwenChatAgentPromptTemplate(BaseChatPromptTemplate):
# The template to use # The template to use
template: str template: str
# The list of tools available # The list of tools available
tools: List[Tool] tools: List[BaseTool]
def format_messages(self, **kwargs) -> str: def format_messages(self, **kwargs) -> str:
# Get the intermediate steps (AgentAction, Observation tuples) # Get the intermediate steps (AgentAction, Observation tuples)
@ -52,8 +50,8 @@ class QwenChatAgentPromptTemplate(BaseChatPromptTemplate):
return [HumanMessage(content=formatted)] return [HumanMessage(content=formatted)]
class QwenChatAgentOutputParser(StructuredChatOutputParser): class QwenChatAgentOutputParserCustom(StructuredChatOutputParser):
"""Output parser with retries for the structured chat agent.""" """Output parser with retries for the structured chat agent with custom qwen prompt."""
def parse(self, text: str) -> Union[AgentAction, AgentFinish]: def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
if s := re.findall(r"\nAction:\s*(.+)\nAction\sInput:\s*(.+)", text, flags=re.DOTALL): if s := re.findall(r"\nAction:\s*(.+)\nAction\sInput:\s*(.+)", text, flags=re.DOTALL):
@ -67,111 +65,51 @@ class QwenChatAgentOutputParser(StructuredChatOutputParser):
@property @property
def _type(self) -> str: def _type(self) -> str:
return "structured_chat_qwen_with_retries" return "StructuredQWenChatOutputParserCustom"
class QwenChatAgent(LLMSingleActionAgent): class QwenChatAgentOutputParserLC(StructuredChatOutputParser):
"""Structured Chat Agent.""" """Output parser with retries for the structured chat agent with standard lc prompt."""
output_parser: AgentOutputParser = Field( def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
default_factory=QwenChatAgentOutputParser if s := re.findall(r"\nAction:\s*```(.+)```", text, flags=re.DOTALL):
) action = json.loads(s[0])
"""Output parser for the agent.""" tool = action.get("action")
if tool == "Final Answer":
return AgentFinish({"output": action.get("action_input", "")}, log=text)
else:
return AgentAction(tool=tool, tool_input=action.get("action_input", {}), log=text)
else:
raise OutputParserException(f"Could not parse LLM output: {text}")
@property @property
def observation_prefix(self) -> str: def _type(self) -> str:
"""Prefix to append the qwen observation with.""" return "StructuredQWenChatOutputParserLC"
return "\nObservation:"
@property
def llm_prefix(self) -> str:
"""Prefix to append the llm call with."""
return "\nThought:"
@classmethod def create_structured_qwen_chat_agent(
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
pass
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
prompt: str = None,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[QwenChatAgentPromptTemplate]] = None,
) -> QwenChatAgentPromptTemplate:
template = get_prompt_template("action_model", "qwen")
return QwenChatAgentPromptTemplate(input_variables=["input", "intermediate_steps"],
template=template,
tools=tools)
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel, llm: BaseLanguageModel,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
prompt: str = None, use_custom_prompt: bool = True,
callbacks: List[BaseCallbackHandler] = [], ) -> Runnable:
output_parser: Optional[AgentOutputParser] = None, if use_custom_prompt:
human_message_template: str = HUMAN_MESSAGE_TEMPLATE, prompt = "qwen"
input_variables: Optional[List[str]] = None, output_parser = QwenChatAgentOutputParserCustom()
memory_prompts: Optional[List[BaseChatPromptTemplate]] = None, else:
**kwargs: Any, prompt = "structured-chat-agent"
) -> QwenChatAgent: output_parser = QwenChatAgentOutputParserLC()
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
prompt = cls.create_prompt(
tools,
prompt=prompt,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callbacks=callbacks,
)
tool_names = [tool.name for tool in tools]
output_parser = output_parser or QwenChatAgentOutputParser()
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=output_parser,
stop=["\nObservation:"],
**kwargs,
)
@property template = get_prompt_template("action_model", prompt)
def _agent_type(self) -> str: prompt = QwenChatAgentPromptTemplate(input_variables=["input", "intermediate_steps"],
return "qwen_chat_agent" template=template,
tools=tools)
agent = (
def initialize_qwen_agent( RunnablePassthrough.assign(
tools: Sequence[BaseTool], agent_scratchpad=itemgetter("intermediate_steps")
llm: BaseLanguageModel, )
prompt: str = None, | prompt
callbacks: List[BaseCallbackHandler] = [], | llm.bind(stop="\nObservation:")
memory: Optional[ConversationBufferWindowMemory] = None, | output_parser
agent_kwargs: Optional[dict] = None,
*,
tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> AgentExecutor:
tags_ = list(tags) if tags else []
agent_kwargs = agent_kwargs or {}
llm.callbacks=callbacks
agent_obj = QwenChatAgent.from_llm_and_tools(
llm=llm,
tools=tools,
prompt=prompt,
**agent_kwargs,
)
return AgentExecutor.from_agent_and_tools(
agent=agent_obj,
tools=tools,
callbacks=callbacks,
memory=memory,
tags=tags_,
intermediate_steps=[],
**kwargs,
) )
return agent

View File

@ -21,6 +21,11 @@ from server.db.repository import add_message_to_db
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
# from langchain.globals import set_debug, set_verbose
# set_debug(True)
# set_verbose(True)
def create_models_from_config(configs, callbacks, stream): def create_models_from_config(configs, callbacks, stream):
if configs is None: if configs is None:
configs = {} configs = {}

View File

@ -3,31 +3,39 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent)) sys.path.append(str(Path(__file__).parent.parent))
import asyncio import asyncio
from langchain.agents import AgentExecutor
from langchain_openai.chat_models import ChatOpenAI
# from langchain.chat_models.openai import ChatOpenAI
from server.utils import get_ChatOpenAI from server.utils import get_ChatOpenAI
from server.agent.tools_factory.tools_registry import all_tools from server.agent.tools_factory.tools_registry import all_tools
from server.agent.agent_factory.qwen_agent import initialize_qwen_agent from server.agent.agent_factory.qwen_agent import create_structured_qwen_chat_agent
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
from langchain import globals from langchain import globals
# globals.set_debug(True) globals.set_debug(True)
# globals.set_verbose(True) # globals.set_verbose(True)
async def main(): async def main():
callback = AgentExecutorAsyncIteratorCallbackHandler() callback = AgentExecutorAsyncIteratorCallbackHandler()
tools = [t.copy(update={"callbacks": [callback]}) for t in all_tools] tools = [t.copy(update={"callbacks": [callback]}) for t in all_tools]
qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback]) # qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
executor = initialize_qwen_agent(tools=tools, qwen_model = ChatOpenAI(base_url="http://127.0.0.1:9997/v1",
llm=qwen_model, api_key="empty",
callbacks=[callback], streaming=False,
) temperature=0.01,
model="qwen",
callbacks=[callback],
)
agent = create_structured_qwen_chat_agent(tools=tools, llm=qwen_model)
executor = AgentExecutor(agent=agent, tools=tools, verbose=True, callbacks=[callback])
# ret = executor.invoke("苏州今天冷吗") # ret = executor.invoke("苏州今天冷吗")
ret = asyncio.create_task(executor.ainvoke("苏州今天冷吗")) ret = asyncio.create_task(executor.ainvoke({"input": "苏州今天冷吗"}))
async for chunk in callback.aiter(): async for chunk in callback.aiter():
print(chunk) print(chunk)
# ret = executor.invoke("从知识库samples中查询chatchat项目简介") # ret = executor.invoke("从知识库samples中查询chatchat项目简介")
# ret = executor.invoke("chatchat项目主要issue有哪些") # ret = executor.invoke("chatchat项目主要issue有哪些")
print(ret) await ret
asyncio.run(main()) asyncio.run(main())