"""
This file is a modified version, adapted for ChatGLM3-6B, of the original glm3_agent.py file from the LangChain repo.
"""
import json
import logging
from typing import Sequence, Optional, Union
import langchain_core.prompts
import langchain_core.messages
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain.agents.agent import AgentOutputParser
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.prompts.chat import ChatPromptTemplate
from langchain.output_parsers import OutputFixingParser
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool
from server.pydantic_v1 import Field, typing, model_schema
logger = logging.getLogger(__name__)
SYSTEM_PROMPT = "Answer the following questions as best as you can. You have access to the following tools:\n{tools}"
HUMAN_MESSAGE = "Let's start! Human:{input}\n\n{agent_scratchpad}"
class StructuredGLM3ChatOutputParser(AgentOutputParser):
    """Output parser for ChatGLM3-6B structured-chat agents.

    ChatGLM3 emits tool invocations as a pseudo Python call wrapped in a
    fenced block (e.g. ``tool_name`` followed by ```` ```python
    tool_call(arg='value') ``` ````).  This parser rewrites that text into
    the JSON ``Action:`` format understood by the standard structured-chat
    parser and then delegates to it.
    """
    # Fallback parser that understands the standard structured-chat JSON format.
    base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
    # Optional LLM-backed parser that can repair malformed output before parsing.
    output_fixing_parser: Optional[OutputFixingParser] = None

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Parse raw LLM output into an ``AgentAction`` or ``AgentFinish``.

        Raises:
            OutputParserException: if the reformatted text still cannot be
                parsed by the underlying parser.
        """
        # Fix: use module logger instead of a stray debug print().
        logger.debug("Raw model output: %s", text)
        # Truncate at the first special token so trailing chatter is ignored.
        special_tokens = ["Action:", "<|observation|>"]
        first_index = min(
            text.find(token) if token in text else len(text)
            for token in special_tokens
        )
        text = text[:first_index]
        if "tool_call" in text:
            # Tool invocation: recover the tool name (text before the fence)
            # and the keyword arguments from the pseudo-call.
            action_end = text.find("```")
            action = text[:action_end].strip()
            params_str_start = text.find("(") + 1
            params_str_end = text.rfind(")")
            params_str = text[params_str_start:params_str_end]
            # Fix: split on the first '=' only so values containing '=' are
            # kept intact.  NOTE(review): still naive — commas inside quoted
            # values will mis-split; acceptable for simple scalar args only.
            params_pairs = [
                param.split("=", 1) for param in params_str.split(",") if "=" in param
            ]
            params = {
                pair[0].strip(): pair[1].strip().strip("'\"") for pair in params_pairs
            }
            action_json = {
                "action": action,
                "action_input": params,
            }
        else:
            # No tool call: treat the remaining text as the final answer.
            action_json = {
                "action": "Final Answer",
                "action_input": text,
            }
        action_str = f"""
Action:
```
{json.dumps(action_json, ensure_ascii=False)}
```"""
        try:
            if self.output_fixing_parser is not None:
                parsed_obj: Union[
                    AgentAction, AgentFinish
                ] = self.output_fixing_parser.parse(action_str)
            else:
                parsed_obj = self.base_parser.parse(action_str)
            return parsed_obj
        except Exception as e:
            raise OutputParserException(f"Could not parse LLM output: {text}") from e

    @property
    def _type(self) -> str:
        return "StructuredGLM3ChatOutputParser"
def create_structured_glm3_chat_agent(
    llm: BaseLanguageModel, tools: Sequence[BaseTool]
) -> Runnable:
    """Build a structured-chat agent runnable for ChatGLM3-6B.

    Args:
        llm: Language model driving the agent; bound to stop on the
            ``<|observation|>`` token that ChatGLM3 emits after a tool call.
        tools: Tools exposed to the agent; each tool's name, description and
            argument schema are rendered into the system prompt.

    Returns:
        A Runnable pipeline: scratchpad assignment | prompt | llm | parser.
    """
    tools_json = []
    for tool in tools:
        tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
        # Descriptions of the form "name - description" keep only the text
        # after the first " - "; otherwise the description is used verbatim.
        description = (
            tool.description.split(" - ")[1].strip()
            if tool.description and " - " in tool.description
            else tool.description
        )
        # Drop pydantic's auto-generated "title" keys from each property.
        parameters = {
            k: {sub_k: sub_v for sub_k, sub_v in v.items() if sub_k != "title"}
            for k, v in tool_schema.get("properties", {}).items()
        }
        tools_json.append(
            {
                "name": tool.name,
                "description": description,
                "parameters": parameters,
            }
        )
    # Fix: render into a new name instead of shadowing the `tools` parameter.
    tools_str = "\n".join(str(spec) for spec in tools_json)
    prompt = ChatPromptTemplate(
        input_variables=["input", "agent_scratchpad"],
        input_types={'chat_history': typing.List[typing.Union[
            langchain_core.messages.ai.AIMessage,
            langchain_core.messages.human.HumanMessage,
            langchain_core.messages.chat.ChatMessage,
            langchain_core.messages.system.SystemMessage,
            langchain_core.messages.function.FunctionMessage,
            langchain_core.messages.tool.ToolMessage]]
        },
        messages=[
            langchain_core.prompts.SystemMessagePromptTemplate(
                prompt=langchain_core.prompts.PromptTemplate(
                    input_variables=['tools'],
                    template=SYSTEM_PROMPT)
            ),
            langchain_core.prompts.MessagesPlaceholder(
                variable_name='chat_history',
                optional=True
            ),
            langchain_core.prompts.HumanMessagePromptTemplate(
                prompt=langchain_core.prompts.PromptTemplate(
                    input_variables=['agent_scratchpad', 'input'],
                    template=HUMAN_MESSAGE
                )
            )
        ]
    ).partial(tools=tools_str)
    # Stop generation at ChatGLM3's observation marker.
    llm_with_stop = llm.bind(stop=["<|observation|>"])
    agent = (
        RunnablePassthrough.assign(
            agent_scratchpad=lambda x: x["intermediate_steps"],
        )
        | prompt
        | llm_with_stop
        | StructuredGLM3ChatOutputParser()
    )
    return agent