Update the ChatGLM3 agent implementation for LangChain 0.1.x

zR 2024-01-07 23:02:46 +08:00 committed by liunux4odoo
parent 5df19d907b
commit df1a508e10
7 changed files with 139 additions and 201 deletions

View File

@@ -1,13 +1,15 @@
 # API requirements
-# On Windows system, install the cuda version manually from https://pytorch.org/
+# Torch requirements, install the cuda version manually from https://pytorch.org/
 torch>=2.1.2
 torchvision>=0.16.2
 torchaudio>=2.1.2
+# Langchain 0.1.x requirements
 langchain>=0.1.0
 langchain_openai>=0.0.2
 langchain-community>=1.0.0
+langchainhub>=0.1.14
 pydantic==1.10.13
 fschat==0.2.35
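
For orientation, not part of the commit: under LangChain 0.1.x the chat-model classes live in `langchain_openai` and shared prompts are pulled via `langchainhub`. A minimal sketch exercising these pins, assuming an OpenAI-compatible key in the environment:

```python
# Sanity check of the 0.1.x package split pinned above (illustrative only).
from langchain import hub                # pulling prompts requires langchainhub
from langchain_openai import ChatOpenAI  # chat models moved out of core langchain

llm = ChatOpenAI(model="gpt-3.5-turbo")  # assumes OPENAI_API_KEY is set
prompt = hub.pull("hwchase17/structured-chat-agent")
print(prompt.input_variables)            # e.g. ['agent_scratchpad', 'input', ...]
```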

View File

@@ -1,2 +1,2 @@
-from .glm3_agent import initialize_glm3_agent
+from .glm3_agent import create_structured_glm3_chat_agent
 from .qwen_agent import initialize_qwen_agent

View File

@@ -0,0 +1,34 @@
+from typing import List, Sequence
+
+from langchain import hub
+from langchain.agents import AgentExecutor, create_structured_chat_agent
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.messages import SystemMessage
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.tools import BaseTool
+
+from server.agent.agent_factory import create_structured_glm3_chat_agent
+
+
+def agents_registry(
+        llm: BaseLanguageModel,
+        tools: Sequence[BaseTool] = [],
+        callbacks: List[BaseCallbackHandler] = [],
+        prompt: str = None,
+        verbose: bool = False):
+    if prompt is not None:
+        prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)])
+    else:
+        prompt = hub.pull("hwchase17/structured-chat-agent")  # default prompt
+
+    # Write any optimized method here.
+    if "glm3" in llm.model_name.lower():
+        # An optimized method of langchain Agent that uses the glm3 series model
+        agent = create_structured_glm3_chat_agent(llm=llm, tools=tools)
+    else:
+        agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
+
+    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, callbacks=callbacks)
+
+    return agent_executor
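
A hedged usage sketch of the registry above; the local OpenAI-compatible endpoint is an assumption, while `all_tools` comes from `server.agent.tools_factory.tools_registry` elsewhere in this repo:

```python
# Hypothetical caller of agents_registry (not part of this diff): any model whose
# name contains "glm3" is routed to the ChatGLM3-specific agent; everything else
# falls back to the hub's structured-chat prompt.
from langchain_openai import ChatOpenAI
from server.agent.tools_factory.tools_registry import all_tools

llm = ChatOpenAI(model_name="chatglm3-6b",              # hits the glm3 branch
                 base_url="http://127.0.0.1:20000/v1",  # assumed local endpoint
                 api_key="EMPTY")
executor = agents_registry(llm=llm, tools=all_tools, verbose=True)
print(executor.invoke({"input": "What is the weather in Beijing today?"}))
```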

View File

@@ -2,41 +2,45 @@
 This file is a modified version, for ChatGLM3-6B, of the original glm3_agent.py file from the langchain repo.
 """
-from __future__ import annotations
 import json
 import logging
-from typing import Any, List, Sequence, Tuple, Optional, Union
-from pydantic.schema import model_schema
-from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
-from langchain.memory import ConversationBufferWindowMemory
-from langchain.agents.agent import Agent
-from langchain.chains.llm import LLMChain
-from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
+from typing import Sequence, Optional, Union
+
+import langchain_core.prompts
+import langchain_core.messages
+from langchain_core.runnables import Runnable, RunnablePassthrough
 from langchain.agents.agent import AgentOutputParser
+from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
+from langchain.prompts.chat import ChatPromptTemplate
 from langchain.output_parsers import OutputFixingParser
-from langchain.pydantic_v1 import Field
-from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate
-from langchain.agents.agent import AgentExecutor
-from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema import AgentAction, AgentFinish, OutputParserException
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
+from langchain.pydantic_v1 import Field
+from pydantic import typing
+from pydantic.schema import model_schema

 logger = logging.getLogger(__name__)

-class StructuredChatOutputParserWithRetries(AgentOutputParser):
-    """Output parser with retries for the structured chat agent."""
+SYSTEM_PROMPT = "Answer the following questions as best as you can. You have access to the following tools:\n{tools}"
+HUMAN_MESSAGE = "Let's start! Human:{input}\n\n{agent_scratchpad}"
+
+
+class StructuredGLM3ChatOutputParser(AgentOutputParser):
+    """
+    Output parser with retries for the structured chat agent.
+    """
     base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
-    """The base parser to use."""
     output_fixing_parser: Optional[OutputFixingParser] = None
-    """The output fixing parser to use."""

     def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
+        print(text)
         special_tokens = ["Action:", "<|observation|>"]
         first_index = min([text.find(token) if token in text else len(text) for token in special_tokens])
         text = text[:first_index]
         if "tool_call" in text:
             action_end = text.find("```")
             action = text[:action_end].strip()

@@ -74,156 +78,64 @@ Action:
     @property
     def _type(self) -> str:
-        return "structured_chat_ChatGLM3_6b_with_retries"
+        return "StructuredGLM3ChatOutputParser"


-class StructuredGLM3ChatAgent(Agent):
-    """Structured Chat Agent."""
-
-    output_parser: AgentOutputParser = Field(
-        default_factory=StructuredChatOutputParserWithRetries
-    )
-    """Output parser for the agent."""
-
-    @property
-    def observation_prefix(self) -> str:
-        """Prefix to append the ChatGLM3-6B observation with."""
-        return "Observation:"
-
-    @property
-    def llm_prefix(self) -> str:
-        """Prefix to append the llm call with."""
-        return "Thought:"
-
-    def _construct_scratchpad(
-            self, intermediate_steps: List[Tuple[AgentAction, str]]
-    ) -> str:
-        agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
-        if not isinstance(agent_scratchpad, str):
-            raise ValueError("agent_scratchpad should be of type string.")
-        if agent_scratchpad:
-            return (
-                f"This was your previous work "
-                f"(but I haven't seen any of it! I only see what "
-                f"you return as final answer):\n{agent_scratchpad}"
-            )
-        else:
-            return agent_scratchpad
-
-    @classmethod
-    def _get_default_output_parser(
-            cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
-    ) -> AgentOutputParser:
-        return StructuredChatOutputParserWithRetries(llm=llm)
-
-    @property
-    def _stop(self) -> List[str]:
-        return ["<|observation|>"]
-
-    @classmethod
-    def create_prompt(
-            cls,
-            tools: Sequence[BaseTool],
-            prompt: str = None,
-            input_variables: Optional[List[str]] = None,
-            memory_prompts: Optional[List[BasePromptTemplate]] = None,
-    ) -> BasePromptTemplate:
-        tools_json = []
-        tool_names = []
-        for tool in tools:
-            tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
-            simplified_config_langchain = {
-                "name": tool.name,
-                "description": tool.description,
-                "parameters": tool_schema.get("properties", {})
-            }
-            tools_json.append(simplified_config_langchain)
-            tool_names.append(tool.name)
-        formatted_tools = "\n".join([
-            f"{tool['name']}: {tool['description']}, args: {tool['parameters']}"
-            for tool in tools_json
-        ])
-        formatted_tools = formatted_tools.replace("'", "\\'").replace("{", "{{").replace("}", "}}")
-        template = prompt.format(tool_names=tool_names,
-                                 tools=formatted_tools,
-                                 input="{input}",
-                                 agent_scratchpad="{agent_scratchpad}")
-        if input_variables is None:
-            input_variables = ["input", "agent_scratchpad"]
-        _memory_prompts = memory_prompts or []
-        messages = [
-            SystemMessagePromptTemplate.from_template(template),
-            *_memory_prompts,
-        ]
-        return ChatPromptTemplate(input_variables=input_variables, messages=messages)
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            llm: BaseLanguageModel,
-            tools: Sequence[BaseTool],
-            prompt: str = None,
-            callbacks: List[BaseCallbackHandler] = [],
-            output_parser: Optional[AgentOutputParser] = None,
-            input_variables: Optional[List[str]] = None,
-            memory_prompts: Optional[List[BasePromptTemplate]] = None,
-            **kwargs: Any,
-    ) -> Agent:
-        """Construct an agent from an LLM and tools."""
-        cls._validate_tools(tools)
-        prompt = cls.create_prompt(
-            tools,
-            prompt=prompt,
-            input_variables=input_variables,
-            memory_prompts=memory_prompts,
-        )
-        llm_chain = LLMChain(
-            llm=llm,
-            prompt=prompt,
-            callbacks=callbacks,
-            verbose=True
-        )
-        tool_names = [tool.name for tool in tools]
-        _output_parser = output_parser or cls._get_default_output_parser(llm=llm)
-        return cls(
-            llm_chain=llm_chain,
-            allowed_tools=tool_names,
-            output_parser=_output_parser,
-            **kwargs,
-        )
-
-    @property
-    def _agent_type(self) -> str:
-        raise ValueError
-
-
-def initialize_glm3_agent(
-        tools: Sequence[BaseTool],
-        llm: BaseLanguageModel,
-        prompt: str = None,
-        callbacks: List[BaseCallbackHandler] = [],
-        memory: Optional[ConversationBufferWindowMemory] = None,
-        agent_kwargs: Optional[dict] = None,
-        *,
-        tags: Optional[Sequence[str]] = None,
-        **kwargs: Any,
-) -> AgentExecutor:
-    tags_ = list(tags) if tags else []
-    agent_kwargs = agent_kwargs or {}
-    llm.callbacks = callbacks
-    agent_obj = StructuredGLM3ChatAgent.from_llm_and_tools(
-        llm=llm,
-        tools=tools,
-        prompt=prompt,
-        **agent_kwargs
-    )
-    return AgentExecutor.from_agent_and_tools(
-        agent=agent_obj,
-        tools=tools,
-        callbacks=callbacks,
-        memory=memory,
-        tags=tags_,
-        **kwargs,
-    )
+def create_structured_glm3_chat_agent(
+        llm: BaseLanguageModel, tools: Sequence[BaseTool]
+) -> Runnable:
+    tools_json = []
+    for tool in tools:
+        tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
+        description = tool.description.split(" - ")[
+            1].strip() if tool.description and " - " in tool.description else tool.description
+        parameters = {k: {sub_k: sub_v for sub_k, sub_v in v.items() if sub_k != 'title'} for k, v in
+                      tool_schema.get("properties", {}).items()}
+        simplified_config_langchain = {
+            "name": tool.name,
+            "description": description,
+            "parameters": parameters
+        }
+        tools_json.append(simplified_config_langchain)
+    tools = "\n".join([str(tool) for tool in tools_json])
+
+    prompt = ChatPromptTemplate(
+        input_variables=["input", "agent_scratchpad"],
+        input_types={'chat_history': typing.List[typing.Union[
+            langchain_core.messages.ai.AIMessage,
+            langchain_core.messages.human.HumanMessage,
+            langchain_core.messages.chat.ChatMessage,
+            langchain_core.messages.system.SystemMessage,
+            langchain_core.messages.function.FunctionMessage,
+            langchain_core.messages.tool.ToolMessage]]
+        },
+        messages=[
+            langchain_core.prompts.SystemMessagePromptTemplate(
+                prompt=langchain_core.prompts.PromptTemplate(
+                    input_variables=['tools'],
+                    template=SYSTEM_PROMPT)
+            ),
+            langchain_core.prompts.MessagesPlaceholder(
+                variable_name='chat_history',
+                optional=True
+            ),
+            langchain_core.prompts.HumanMessagePromptTemplate(
+                prompt=langchain_core.prompts.PromptTemplate(
+                    input_variables=['agent_scratchpad', 'input'],
+                    template=HUMAN_MESSAGE
+                )
+            )
+        ]
+    ).partial(tools=tools)
+
+    llm_with_stop = llm.bind(stop=["<|observation|>"])
+    agent = (
+        RunnablePassthrough.assign(
+            agent_scratchpad=lambda x: x["intermediate_steps"],
+        )
+        | prompt
+        | llm_with_stop
+        | StructuredGLM3ChatOutputParser()
+    )
+    return agent
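
For illustration, one way the new LCEL-style agent might be wired end to end; the `weather_check` tool and the endpoint below are invented for this sketch:

```python
# Hypothetical end-to-end wiring of create_structured_glm3_chat_agent
# (a sketch, not part of this diff).
from langchain.agents import AgentExecutor
from langchain.tools import tool
from langchain_openai import ChatOpenAI


@tool
def weather_check(city: str) -> str:
    """weather_check - Look up the current weather for a city."""
    return f"{city}: sunny, 1 degree"  # stubbed result for the sketch


llm = ChatOpenAI(model_name="chatglm3-6b",
                 base_url="http://127.0.0.1:20000/v1",  # assumed endpoint
                 api_key="EMPTY")
agent = create_structured_glm3_chat_agent(llm=llm, tools=[weather_check])
executor = AgentExecutor(agent=agent, tools=[weather_check], verbose=True)
print(executor.invoke({"input": "How warm is it in Beijing right now?"}))
```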

View File

@@ -95,5 +95,5 @@ def search_internet(query: str):
     return search_engine(query=query, config=tool_config)


 class SearchInternetInput(BaseModel):
-    query: str = Field(description="Query for Internet search")
+    query: str = Field(description="query for Internet search")

View File

@@ -102,7 +102,8 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
             metadata: Optional[Dict[str, Any]] = None,
             **kwargs: Any,
     ) -> None:
-        print("tool_start")
+        pass

     async def on_tool_end(
             self,

View File

@@ -1,19 +1,16 @@
 import asyncio
 import json
-from typing import List, Union, AsyncIterable, Dict
+from typing import AsyncIterable, List, Union, Dict
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from langchain.agents import initialize_agent, AgentType, create_structured_chat_agent, AgentExecutor
-from langchain_core.messages import HumanMessage, AIMessage
 from langchain_core.output_parsers import StrOutputParser
+from langchain_core.messages import AIMessage, HumanMessage
 from langchain.chains import LLMChain
 from langchain.prompts.chat import ChatPromptTemplate
 from langchain.prompts import PromptTemplate
-from langchain_core.runnables import RunnableBranch
-from server.agent.agent_factory import initialize_glm3_agent, initialize_qwen_agent
+from server.agent.agent_factory.agents_registry import agents_registry
 from server.agent.tools_factory.tools_registry import all_tools
 from server.agent.container import container
@@ -21,7 +18,7 @@ from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
 from server.chat.utils import History
 from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
 from server.db.repository import add_message_to_db
-from server.callback_handler.agent_callback_handler import AgentStatus, AgentExecutorAsyncIteratorCallbackHandler
+from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler


 def create_models_from_config(configs, callbacks, stream):
@@ -46,7 +43,6 @@ def create_models_from_config(configs, callbacks, stream):
     return models, prompts


-# Write the construction logic here
 def create_models_chains(history, history_len, prompts, models, tools, callbacks, conversation_id, metadata):
     memory = None
     chat_prompt = None
@@ -78,38 +74,21 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
             | models["preprocess_model"]
             | StrOutputParser()
         )
-    if "action_model" in models and len(tools) > 0:
-        if "chatglm3" in models["action_model"].model_name.lower():
-            agent_executor = initialize_glm3_agent(
-                llm=models["action_model"],
-                tools=tools,
-                prompt=prompts["action_model"],
-                callbacks=callbacks,
-                verbose=True,
-            )
-        elif "qwen" in models["action_model"].model_name.lower():
-            agent_executor = initialize_qwen_agent(
-                llm=models["action_model"],
-                tools=tools,
-                prompt=prompts["action_model"],
-                callbacks=callbacks,
-                verbose=True,
-            )
-        else:
-            agent_executor = initialize_agent(
-                llm=models["action_model"],
-                tools=tools,
-                callbacks=callbacks,
-                agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
-                verbose=True,
-            )
+    if "action_model" in models and tools is not None:
+        agent_executor = agents_registry(
+            llm=models["action_model"],
+            callbacks=callbacks,
+            tools=tools,
+            prompt=None,
+            verbose=True
+        )
         # branch = RunnableBranch(
         #     (lambda x: "1" in x["topic"].lower(), agent_executor),
         #     chain
         # )
         # full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
-        full_chain = ({"input": lambda x: x["input"], } | agent_executor)
+        full_chain = ({"input": lambda x: x["input"]} | agent_executor)
     else:
        chain.llm.callbacks = callbacks
        full_chain = ({"input": lambda x: x["input"]} | chain)
@@ -155,7 +134,17 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         history=history,
         history_len=history_len,
         metadata=metadata)
-    task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}), callback.done))
+    task = asyncio.create_task(wrap_done(
+        full_chain.ainvoke(
+            {
+                "input": query,
+                "chat_history": [
+                    HumanMessage(content="今天北京的温度是多少度"),
+                    AIMessage(content="今天北京的温度是1度"),
+                ],
+            }
+        ), callback.done))
     async for chunk in callback.aiter():
         data = json.loads(chunk)