Update the ChatGLM3 agent implementation for the langchain 0.1.x Agent style

This commit is contained in:
zR 2024-01-07 23:02:46 +08:00 committed by liunux4odoo
parent 5df19d907b
commit df1a508e10
7 changed files with 139 additions and 201 deletions

View File

@ -1,13 +1,15 @@
# API requirements
# On Windows system, install the cuda version manually from https://pytorch.org/
# Torch requirements, install the cuda version manually from https://pytorch.org/
torch>=2.1.2
torchvision>=0.16.2
torchaudio>=2.1.2
# Langchain 0.1.x requirements
langchain>=0.1.0
langchain_openai>=0.0.2
langchain-community>=0.0.9
langchainhub>=0.1.14
pydantic==1.10.13
fschat==0.2.35
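
For context, these pins track the langchain 0.1.x package split: provider integrations such as ChatOpenAI moved into langchain_openai, community integrations into langchain-community, and langchainhub backs hub.pull for shared prompts. A minimal sketch of the import layout this commit assumes (illustrative only, not project code):

# Sketch of the langchain 0.1.x import split assumed by these requirements;
# the old monolithic paths (e.g. langchain.chat_models.ChatOpenAI) are deprecated.
from langchain import hub                        # backed by the langchainhub package
from langchain_openai import ChatOpenAI          # provider package split out in 0.1.x
from langchain_community.tools import ShellTool  # community integrations package

prompt = hub.pull("hwchase17/structured-chat-agent")  # the default agent prompt used below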

View File

@ -1,2 +1,2 @@
from .glm3_agent import initialize_glm3_agent
from .glm3_agent import create_structured_glm3_chat_agent
from .qwen_agent import initialize_qwen_agent

View File

@ -0,0 +1,34 @@
from typing import List, Sequence
from langchain import hub
from langchain.agents import AgentExecutor, create_structured_chat_agent
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import BaseTool
from server.agent.agent_factory import create_structured_glm3_chat_agent
def agents_registry(
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool] = [],
        callbacks: List[BaseCallbackHandler] = [],
        prompt: str = None,
        verbose: bool = False):
    if prompt is not None:
        prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)])
    else:
        prompt = hub.pull("hwchase17/structured-chat-agent")  # default prompt
    # Route to a model-specific optimized agent implementation here.
    if "glm3" in llm.model_name.lower():
        # An optimized langchain agent implementation for GLM3-series models
        agent = create_structured_glm3_chat_agent(llm=llm, tools=tools)
    else:
        agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, callbacks=callbacks)
    return agent_executor
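
A hedged usage sketch of the new factory; the toy tool and the local OpenAI-compatible endpoint are illustrative placeholders, not project configuration:

# Illustrative only: wiring agents_registry with a toy tool. The model name
# and endpoint are placeholders for a local OpenAI-compatible server.
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI

@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

llm = ChatOpenAI(model_name="chatglm3-6b",                    # "glm3" routes to the GLM3 agent
                 openai_api_base="http://127.0.0.1:8000/v1",  # placeholder endpoint
                 openai_api_key="EMPTY")
executor = agents_registry(llm=llm, tools=[add], verbose=True)
print(executor.invoke({"input": "What is 2 + 3?"})["output"])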

View File

@ -2,41 +2,45 @@
This file is a modified version of langchain's structured chat agent, adapted for ChatGLM3-6B.
"""
from __future__ import annotations
import json
import logging
from typing import Any, List, Sequence, Tuple, Optional, Union
from pydantic.schema import model_schema
from typing import Sequence, Optional, Union
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents.agent import Agent
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
import langchain_core.prompts
import langchain_core.messages
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain.agents.agent import AgentOutputParser
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.prompts.chat import ChatPromptTemplate
from langchain.output_parsers import OutputFixingParser
from langchain.pydantic_v1 import Field
from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate
from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool
from langchain.pydantic_v1 import Field
from pydantic import typing
from pydantic.schema import model_schema
logger = logging.getLogger(__name__)
SYSTEM_PROMPT = "Answer the following questions as best as you can. You have access to the following tools:\n{tools}"
HUMAN_MESSAGE = "Let's start! Human:{input}\n\n{agent_scratchpad}"
class StructuredChatOutputParserWithRetries(AgentOutputParser):
"""Output parser with retries for the structured chat agent."""
class StructuredGLM3ChatOutputParser(AgentOutputParser):
"""
Output parser with retries for the structured chat agent.
"""
base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
"""The base parser to use."""
output_fixing_parser: Optional[OutputFixingParser] = None
"""The output fixing parser to use."""
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
logger.debug(text)
special_tokens = ["Action:", "<|observation|>"]
first_index = min([text.find(token) if token in text else len(text) for token in special_tokens])
text = text[:first_index]
if "tool_call" in text:
action_end = text.find("```")
action = text[:action_end].strip()
@ -74,156 +78,64 @@ Action:
@property
def _type(self) -> str:
return "structured_chat_ChatGLM3_6b_with_retries"
return "StructuredGLM3ChatOutputParser"
class StructuredGLM3ChatAgent(Agent):
"""Structured Chat Agent."""
def create_structured_glm3_chat_agent(
llm: BaseLanguageModel, tools: Sequence[BaseTool]
) -> Runnable:
tools_json = []
for tool in tools:
tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
description = tool.description.split(" - ")[1].strip() if tool.description and " - " in tool.description else tool.description
parameters = {k: {sub_k: sub_v for sub_k, sub_v in v.items() if sub_k != 'title'} for k, v in tool_schema.get("properties", {}).items()}
simplified_config_langchain = {
"name": tool.name,
"description": description,
"parameters": parameters
}
tools_json.append(simplified_config_langchain)
tools = "\n".join([str(tool) for tool in tools_json])
output_parser: AgentOutputParser = Field(
default_factory=StructuredChatOutputParserWithRetries
)
"""Output parser for the agent."""
@property
def observation_prefix(self) -> str:
"""Prefix to append the ChatGLM3-6B observation with."""
return "Observation:"
@property
def llm_prefix(self) -> str:
"""Prefix to append the llm call with."""
return "Thought:"
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> str:
agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
if not isinstance(agent_scratchpad, str):
raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad:
return (
f"This was your previous work "
f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}"
prompt = ChatPromptTemplate(
input_variables=["input", "agent_scratchpad"],
input_types={'chat_history': typing.List[typing.Union[
langchain_core.messages.ai.AIMessage,
langchain_core.messages.human.HumanMessage,
langchain_core.messages.chat.ChatMessage,
langchain_core.messages.system.SystemMessage,
langchain_core.messages.function.FunctionMessage,
langchain_core.messages.tool.ToolMessage]]
},
messages=[
langchain_core.prompts.SystemMessagePromptTemplate(
prompt=langchain_core.prompts.PromptTemplate(
input_variables=['tools'],
template=SYSTEM_PROMPT)
),
langchain_core.prompts.MessagesPlaceholder(
variable_name='chat_history',
optional=True
),
langchain_core.prompts.HumanMessagePromptTemplate(
prompt=langchain_core.prompts.PromptTemplate(
input_variables=['agent_scratchpad', 'input'],
template=HUMAN_MESSAGE
)
)
else:
return agent_scratchpad
@classmethod
def _get_default_output_parser(
cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
) -> AgentOutputParser:
return StructuredChatOutputParserWithRetries(llm=llm)
@property
def _stop(self) -> List[str]:
return ["<|observation|>"]
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
prompt: str = None,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tools_json = []
tool_names = []
for tool in tools:
tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
simplified_config_langchain = {
"name": tool.name,
"description": tool.description,
"parameters": tool_schema.get("properties", {})
}
tools_json.append(simplified_config_langchain)
tool_names.append(tool.name)
formatted_tools = "\n".join([
f"{tool['name']}: {tool['description']}, args: {tool['parameters']}"
for tool in tools_json
])
formatted_tools = formatted_tools.replace("'", "\\'").replace("{", "{{").replace("}", "}}")
template = prompt.format(tool_names=tool_names,
tools=formatted_tools,
input="{input}",
agent_scratchpad="{agent_scratchpad}")
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
_memory_prompts = memory_prompts or []
messages = [
SystemMessagePromptTemplate.from_template(template),
*_memory_prompts,
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
prompt: str = None,
callbacks: List[BaseCallbackHandler] = [],
output_parser: Optional[AgentOutputParser] = None,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
).partial(tools=tools)
cls._validate_tools(tools)
prompt = cls.create_prompt(
tools,
prompt=prompt,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callbacks=callbacks,
verbose=True
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser or cls._get_default_output_parser(llm=llm)
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
**kwargs,
)
@property
def _agent_type(self) -> str:
raise ValueError
def initialize_glm3_agent(
tools: Sequence[BaseTool],
llm: BaseLanguageModel,
prompt: str = None,
callbacks: List[BaseCallbackHandler] = [],
memory: Optional[ConversationBufferWindowMemory] = None,
agent_kwargs: Optional[dict] = None,
*,
tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> AgentExecutor:
tags_ = list(tags) if tags else []
agent_kwargs = agent_kwargs or {}
llm.callbacks = callbacks
agent_obj = StructuredGLM3ChatAgent.from_llm_and_tools(
llm=llm,
tools=tools,
prompt=prompt,
**agent_kwargs
)
return AgentExecutor.from_agent_and_tools(
agent=agent_obj,
tools=tools,
callbacks=callbacks,
memory=memory,
tags=tags_,
**kwargs,
llm_with_stop = llm.bind(stop=["<|observation|>"])
agent = (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: x["intermediate_steps"],
)
| prompt
| llm_with_stop
| StructuredGLM3ChatOutputParser()
)
return agent
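
The removed StructuredGLM3ChatAgent class is thus replaced by a plain LCEL pipeline: RunnablePassthrough.assign exposes the intermediate steps to the prompt as agent_scratchpad, .partial(tools=tools) bakes the formatted tool list into the system message, bind(stop=["<|observation|>"]) stops generation before the model invents an observation, and the output parser maps the text to an AgentAction or AgentFinish. A sketch of one manual pass through the pipeline, assuming agent was built by create_structured_glm3_chat_agent against a reachable model:

# Sketch: a single step of the LCEL agent, as AgentExecutor would drive it.
step = agent.invoke({
    "input": "What is the weather in Beijing today?",
    "intermediate_steps": [],   # (AgentAction, observation) pairs from earlier turns
})
# `step` is an AgentAction (a tool call to run next) or an AgentFinish
# (a final answer); AgentExecutor loops on this until it sees AgentFinish.
print(type(step).__name__)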

View File

@ -95,5 +95,5 @@ def search_internet(query: str):
return search_engine(query=query, config=tool_config)
class SearchInternetInput(BaseModel):
query: str = Field(description="Query for Internet search")
query: str = Field(description="query for Internet search")

View File

@ -102,7 +102,8 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
print("tool_start")
pass
async def on_tool_end(
self,

View File

@ -1,19 +1,16 @@
import asyncio
import json
from typing import List, Union, AsyncIterable, Dict
from typing import AsyncIterable, List, Union, Dict
from fastapi import Body
from fastapi.responses import StreamingResponse
from langchain.agents import initialize_agent, AgentType, create_structured_chat_agent, AgentExecutor
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains import LLMChain
from langchain.prompts.chat import ChatPromptTemplate
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnableBranch
from server.agent.agent_factory import initialize_glm3_agent, initialize_qwen_agent
from server.agent.agent_factory.agents_registry import agents_registry
from server.agent.tools_factory.tools_registry import all_tools
from server.agent.container import container
@ -21,7 +18,7 @@ from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from server.chat.utils import History
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.agent_callback_handler import AgentStatus, AgentExecutorAsyncIteratorCallbackHandler
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
def create_models_from_config(configs, callbacks, stream):
@ -46,7 +43,6 @@ def create_models_from_config(configs, callbacks, stream):
return models, prompts
# Write the chain construction logic here
def create_models_chains(history, history_len, prompts, models, tools, callbacks, conversation_id, metadata):
memory = None
chat_prompt = None
@ -78,38 +74,21 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
| models["preprocess_model"]
| StrOutputParser()
)
if "action_model" in models and len(tools) > 0:
if "chatglm3" in models["action_model"].model_name.lower():
agent_executor = initialize_glm3_agent(
llm=models["action_model"],
tools=tools,
prompt=prompts["action_model"],
callbacks=callbacks,
verbose=True,
)
elif "qwen" in models["action_model"].model_name.lower():
agent_executor = initialize_qwen_agent(
llm=models["action_model"],
tools=tools,
prompt=prompts["action_model"],
callbacks=callbacks,
verbose=True,
)
else:
agent_executor = initialize_agent(
llm=models["action_model"],
tools=tools,
callbacks=callbacks,
agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
)
if "action_model" in models and tools is not None:
agent_executor = agents_registry(
llm=models["action_model"],
callbacks=callbacks,
tools=tools,
prompt=None,
verbose=True
)
# branch = RunnableBranch(
# (lambda x: "1" in x["topic"].lower(), agent_executor),
# chain
# )
# full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
full_chain = ({"input": lambda x: x["input"], } | agent_executor)
full_chain = ({"input": lambda x: x["input"]} | agent_executor)
else:
chain.llm.callbacks = callbacks
full_chain = ({"input": lambda x: x["input"]} | chain)
@ -155,7 +134,17 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
history=history,
history_len=history_len,
metadata=metadata)
task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}), callback.done))
task = asyncio.create_task(wrap_done(
full_chain.ainvoke(
{
"input": query,
"chat_history": [
HumanMessage(content="今天北京的温度是多少度"),
AIMessage(content="今天北京的温度是1度"),
],
}
), callback.done))
async for chunk in callback.aiter():
data = json.loads(chunk)
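
For reference, a hedged client-side sketch of consuming this streaming endpoint; the /chat/chat route, port, and payload shape are assumptions inferred from this handler rather than verified API:

# Illustrative client; the route, port, and payload fields are assumptions.
import json
import httpx

with httpx.stream("POST", "http://127.0.0.1:7861/chat/chat",
                  json={"query": "What is the temperature in Beijing today?"},
                  timeout=None) as resp:
    for line in resp.iter_lines():
        if line:
            print(json.loads(line))   # each chunk is a JSON event from the callback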