按照 langchain 0.1 重写 qwen agent

This commit is contained in:
liunux4odoo 2024-01-10 13:25:39 +08:00
parent df1a508e10
commit 7257521e10
5 changed files with 78 additions and 122 deletions

View File

@ -1,2 +1,2 @@
from .glm3_agent import create_structured_glm3_chat_agent
from .qwen_agent import initialize_qwen_agent
from .qwen_agent import create_structured_qwen_chat_agent

View File

@ -8,7 +8,8 @@ from langchain_core.messages import SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import BaseTool
from server.agent.agent_factory import create_structured_glm3_chat_agent
from server.agent.agent_factory import (create_structured_glm3_chat_agent,
create_structured_qwen_chat_agent)
def agents_registry(
@ -22,10 +23,14 @@ def agents_registry(
else:
prompt = hub.pull("hwchase17/structured-chat-agent") # default prompt
# llm.callbacks = callbacks
# Write any optimized method here.
if "glm3" in llm.model_name.lower():
# An optimized method of langchain Agent that uses the glm3 series model
agent = create_structured_glm3_chat_agent(llm=llm, tools=tools)
elif "qwen" in llm.model_name.lower():
agent = create_structured_qwen_chat_agent(llm=llm, tools=tools)
else:
agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)

View File

@ -1,26 +1,24 @@
from __future__ import annotations
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from typing import Any, List, Sequence, Tuple, Optional, Union
import re
from langchain.agents import Tool
from langchain.agents.agent import LLMSingleActionAgent, AgentOutputParser
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate, BaseChatPromptTemplate
import json
import logging
from langchain.pydantic_v1 import Field
from operator import itemgetter
import re
from typing import List, Sequence, Union
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.callbacks.base import BaseCallbackHandler
from langchain.prompts.chat import BaseChatPromptTemplate
from langchain.schema import (AgentAction, AgentFinish, OutputParserException,
HumanMessage, SystemMessage, AIMessage)
from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool
from langchain.tools.render import format_tool_to_openai_function
from server.utils import get_prompt_template
HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
logger = logging.getLogger(__name__)
@ -28,7 +26,7 @@ class QwenChatAgentPromptTemplate(BaseChatPromptTemplate):
# The template to use
template: str
# The list of tools available
tools: List[Tool]
tools: List[BaseTool]
def format_messages(self, **kwargs) -> str:
# Get the intermediate steps (AgentAction, Observation tuples)
@ -52,8 +50,8 @@ class QwenChatAgentPromptTemplate(BaseChatPromptTemplate):
return [HumanMessage(content=formatted)]
class QwenChatAgentOutputParser(StructuredChatOutputParser):
"""Output parser with retries for the structured chat agent."""
class QwenChatAgentOutputParserCustom(StructuredChatOutputParser):
"""Output parser with retries for the structured chat agent with custom qwen prompt."""
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
if s := re.findall(r"\nAction:\s*(.+)\nAction\sInput:\s*(.+)", text, flags=re.DOTALL):
@ -67,111 +65,51 @@ class QwenChatAgentOutputParser(StructuredChatOutputParser):
@property
def _type(self) -> str:
return "structured_chat_qwen_with_retries"
return "StructuredQWenChatOutputParserCustom"
class QwenChatAgent(LLMSingleActionAgent):
"""Structured Chat Agent."""
class QwenChatAgentOutputParserLC(StructuredChatOutputParser):
"""Output parser with retries for the structured chat agent with standard lc prompt."""
output_parser: AgentOutputParser = Field(
default_factory=QwenChatAgentOutputParser
)
"""Output parser for the agent."""
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
if s := re.findall(r"\nAction:\s*```(.+)```", text, flags=re.DOTALL):
action = json.loads(s[0])
tool = action.get("action")
if tool == "Final Answer":
return AgentFinish({"output": action.get("action_input", "")}, log=text)
else:
return AgentAction(tool=tool, tool_input=action.get("action_input", {}), log=text)
else:
raise OutputParserException(f"Could not parse LLM output: {text}")
@property
def observation_prefix(self) -> str:
"""Prefix to append the qwen observation with."""
return "\nObservation:"
def _type(self) -> str:
return "StructuredQWenChatOutputParserLC"
@property
def llm_prefix(self) -> str:
"""Prefix to append the llm call with."""
return "\nThought:"
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
pass
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
prompt: str = None,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[QwenChatAgentPromptTemplate]] = None,
) -> QwenChatAgentPromptTemplate:
template = get_prompt_template("action_model", "qwen")
return QwenChatAgentPromptTemplate(input_variables=["input", "intermediate_steps"],
template=template,
tools=tools)
@classmethod
def from_llm_and_tools(
cls,
def create_structured_qwen_chat_agent(
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
prompt: str = None,
callbacks: List[BaseCallbackHandler] = [],
output_parser: Optional[AgentOutputParser] = None,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BaseChatPromptTemplate]] = None,
**kwargs: Any,
) -> QwenChatAgent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
prompt = cls.create_prompt(
tools,
prompt=prompt,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callbacks=callbacks,
)
tool_names = [tool.name for tool in tools]
output_parser = output_parser or QwenChatAgentOutputParser()
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=output_parser,
stop=["\nObservation:"],
**kwargs,
)
use_custom_prompt: bool = True,
) -> Runnable:
if use_custom_prompt:
prompt = "qwen"
output_parser = QwenChatAgentOutputParserCustom()
else:
prompt = "structured-chat-agent"
output_parser = QwenChatAgentOutputParserLC()
@property
def _agent_type(self) -> str:
return "qwen_chat_agent"
template = get_prompt_template("action_model", prompt)
prompt = QwenChatAgentPromptTemplate(input_variables=["input", "intermediate_steps"],
template=template,
tools=tools)
def initialize_qwen_agent(
tools: Sequence[BaseTool],
llm: BaseLanguageModel,
prompt: str = None,
callbacks: List[BaseCallbackHandler] = [],
memory: Optional[ConversationBufferWindowMemory] = None,
agent_kwargs: Optional[dict] = None,
*,
tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> AgentExecutor:
tags_ = list(tags) if tags else []
agent_kwargs = agent_kwargs or {}
llm.callbacks=callbacks
agent_obj = QwenChatAgent.from_llm_and_tools(
llm=llm,
tools=tools,
prompt=prompt,
**agent_kwargs,
)
return AgentExecutor.from_agent_and_tools(
agent=agent_obj,
tools=tools,
callbacks=callbacks,
memory=memory,
tags=tags_,
intermediate_steps=[],
**kwargs,
agent = (
RunnablePassthrough.assign(
agent_scratchpad=itemgetter("intermediate_steps")
)
| prompt
| llm.bind(stop="\nObservation:")
| output_parser
)
return agent

View File

@ -21,6 +21,11 @@ from server.db.repository import add_message_to_db
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
# from langchain.globals import set_debug, set_verbose
# set_debug(True)
# set_verbose(True)
def create_models_from_config(configs, callbacks, stream):
if configs is None:
configs = {}

View File

@ -3,31 +3,39 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import asyncio
from langchain.agents import AgentExecutor
from langchain_openai.chat_models import ChatOpenAI
# from langchain.chat_models.openai import ChatOpenAI
from server.utils import get_ChatOpenAI
from server.agent.tools_factory.tools_registry import all_tools
from server.agent.agent_factory.qwen_agent import initialize_qwen_agent
from server.agent.agent_factory.qwen_agent import create_structured_qwen_chat_agent
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
from langchain import globals
# globals.set_debug(True)
globals.set_debug(True)
# globals.set_verbose(True)
async def main():
callback = AgentExecutorAsyncIteratorCallbackHandler()
tools = [t.copy(update={"callbacks": [callback]}) for t in all_tools]
qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
executor = initialize_qwen_agent(tools=tools,
llm=qwen_model,
callbacks=[callback],
)
# qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
qwen_model = ChatOpenAI(base_url="http://127.0.0.1:9997/v1",
api_key="empty",
streaming=False,
temperature=0.01,
model="qwen",
callbacks=[callback],
)
agent = create_structured_qwen_chat_agent(tools=tools, llm=qwen_model)
executor = AgentExecutor(agent=agent, tools=tools, verbose=True, callbacks=[callback])
# ret = executor.invoke("苏州今天冷吗")
ret = asyncio.create_task(executor.ainvoke("苏州今天冷吗"))
ret = asyncio.create_task(executor.ainvoke({"input": "苏州今天冷吗"}))
async for chunk in callback.aiter():
print(chunk)
# ret = executor.invoke("从知识库samples中查询chatchat项目简介")
# ret = executor.invoke("chatchat项目主要issue有哪些")
print(ret)
await ret
asyncio.run(main())