mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00

Update the ChatGLM3 agent to the langchain 0.1.x style

This commit is contained in:
parent 5df19d907b
commit df1a508e10
@@ -1,13 +1,15 @@
 # API requirements
 
-# On Windows system, install the cuda version manually from https://pytorch.org/
+# Torch requirements, install the cuda version manually from https://pytorch.org/
 torch>=2.1.2
 torchvision>=0.16.2
 torchaudio>=2.1.2
 
+# Langchain 0.1.x requirements
+langchain>=0.1.0
+langchain_openai>=0.0.2
+langchain-community>=1.0.0
+langchainhub>=0.1.14
 
 pydantic==1.10.13
 fschat==0.2.35
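These pins have to move together: glm3_agent.py below imports pydantic.schema.model_schema, which only exists on the pydantic 1.x line, hence the exact pin alongside fschat 0.2.35. A quick environment check against the list above can catch drift early; a minimal sketch, not part of the commit:

from importlib.metadata import version, PackageNotFoundError

# Versions taken from the requirements block above; pydantic/fschat are exact pins.
pins = {
    "torch": "2.1.2",
    "langchain": "0.1.0",
    "langchain-openai": "0.0.2",
    "langchain-community": "1.0.0",
    "langchainhub": "0.1.14",
    "pydantic": "1.10.13",
    "fschat": "0.2.35",
}

for pkg, wanted in pins.items():
    try:
        print(f"{pkg}: installed {version(pkg)} (wanted {wanted})")
    except PackageNotFoundError:
        print(f"{pkg}: NOT INSTALLED (wanted {wanted})")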
@@ -1,2 +1,2 @@
-from .glm3_agent import initialize_glm3_agent
+from .glm3_agent import create_structured_glm3_chat_agent
 from .qwen_agent import initialize_qwen_agent
34 server/agent/agent_factory/agents_registry.py (new file)
@@ -0,0 +1,34 @@
+from typing import List, Sequence
+
+from langchain import hub
+from langchain.agents import AgentExecutor, create_structured_chat_agent
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.messages import SystemMessage
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.tools import BaseTool
+
+from server.agent.agent_factory import create_structured_glm3_chat_agent
+
+
+def agents_registry(
+        llm: BaseLanguageModel,
+        tools: Sequence[BaseTool] = [],
+        callbacks: List[BaseCallbackHandler] = [],
+        prompt: str = None,
+        verbose: bool = False):
+    if prompt is not None:
+        prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)])
+    else:
+        prompt = hub.pull("hwchase17/structured-chat-agent")  # default prompt
+
+    # Write any optimized method here.
+    if "glm3" in llm.model_name.lower():
+        # An optimized method of langchain Agent that uses the glm3 series model
+        agent = create_structured_glm3_chat_agent(llm=llm, tools=tools)
+    else:
+        agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
+
+    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, callbacks=callbacks)
+
+    return agent_executor
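For callers, this registry replaces the per-model initialize_* helpers: pass any chat model plus tools, and it takes the GLM3 path when the model name contains "glm3", otherwise the stock structured-chat agent. Note that the glm3 branch ignores the prompt argument, and the mutable defaults for tools/callbacks are shared between calls. A minimal usage sketch; the model name and temperature are assumptions, while get_ChatOpenAI and all_tools are repo helpers imported elsewhere in this commit:

from server.agent.agent_factory.agents_registry import agents_registry
from server.agent.tools_factory.tools_registry import all_tools
from server.utils import get_ChatOpenAI

# Hypothetical model config; any chat model exposing .model_name works here.
llm = get_ChatOpenAI(model_name="chatglm3-6b", temperature=0.1)
agent_executor = agents_registry(llm=llm, tools=all_tools, verbose=True)

# AgentExecutor is itself a Runnable:
result = agent_executor.invoke({"input": "What is the weather in Beijing today?"})
print(result["output"])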
@@ -2,41 +2,45 @@
 This file is a modified version, for ChatGLM3-6B, of the original glm3_agent.py file from the langchain repo.
 """
 
 from __future__ import annotations
 
 import json
 import logging
-from typing import Any, List, Sequence, Tuple, Optional, Union
-from pydantic.schema import model_schema
+from typing import Sequence, Optional, Union
 
 from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
-from langchain.memory import ConversationBufferWindowMemory
-from langchain.agents.agent import Agent
-from langchain.chains.llm import LLMChain
-from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
+import langchain_core.prompts
+import langchain_core.messages
+from langchain_core.runnables import Runnable, RunnablePassthrough
 from langchain.agents.agent import AgentOutputParser
 from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
+from langchain.prompts.chat import ChatPromptTemplate
 from langchain.output_parsers import OutputFixingParser
 from langchain.pydantic_v1 import Field
-from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate
-from langchain.agents.agent import AgentExecutor
-from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema import AgentAction, AgentFinish, OutputParserException
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
 from langchain.pydantic_v1 import Field
 
+from pydantic import typing
+from pydantic.schema import model_schema
 
 logger = logging.getLogger(__name__)
 
 SYSTEM_PROMPT = "Answer the following questions as best as you can. You have access to the following tools:\n{tools}"
 HUMAN_MESSAGE = "Let's start! Human:{input}\n\n{agent_scratchpad}"
 
-class StructuredChatOutputParserWithRetries(AgentOutputParser):
-    """Output parser with retries for the structured chat agent."""
+class StructuredGLM3ChatOutputParser(AgentOutputParser):
+    """
+    Output parser with retries for the structured chat agent.
+    """
     base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
     """The base parser to use."""
     output_fixing_parser: Optional[OutputFixingParser] = None
     """The output fixing parser to use."""
 
     def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
         print(text)
 
         special_tokens = ["Action:", "<|observation|>"]
         first_index = min([text.find(token) if token in text else len(text) for token in special_tokens])
         text = text[:first_index]
 
         if "tool_call" in text:
             action_end = text.find("```")
             action = text[:action_end].strip()
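The parse method keys off ChatGLM3's special tokens rather than a JSON blob: it cuts the completion at the first "Action:" or "<|observation|>" marker, then looks for a tool_call code fence. The truncation step is easy to check in isolation; the sample completion below is invented:

# Invented ChatGLM3-style completion; truncation logic copied from parse() above.
text = 'get_weather\n```python\ntool_call(city="Beijing")\n```<|observation|>'
special_tokens = ["Action:", "<|observation|>"]
first_index = min([text.find(token) if token in text else len(text) for token in special_tokens])
print(repr(text[:first_index]))  # the <|observation|> marker and everything after it are dropped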
@@ -74,156 +78,64 @@ Action:
 
     @property
     def _type(self) -> str:
-        return "structured_chat_ChatGLM3_6b_with_retries"
+        return "StructuredGLM3ChatOutputParser"
 
 
-class StructuredGLM3ChatAgent(Agent):
-    """Structured Chat Agent."""
+def create_structured_glm3_chat_agent(
+        llm: BaseLanguageModel, tools: Sequence[BaseTool]
+) -> Runnable:
+    tools_json = []
+    for tool in tools:
+        tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
+        description = tool.description.split(" - ")[
+            1].strip() if tool.description and " - " in tool.description else tool.description
+        parameters = {k: {sub_k: sub_v for sub_k, sub_v in v.items() if sub_k != 'title'} for k, v in
+                      tool_schema.get("properties", {}).items()}
+        simplified_config_langchain = {
+            "name": tool.name,
+            "description": description,
+            "parameters": parameters
+        }
+        tools_json.append(simplified_config_langchain)
+    tools = "\n".join([str(tool) for tool in tools_json])
 
-    output_parser: AgentOutputParser = Field(
-        default_factory=StructuredChatOutputParserWithRetries
-    )
-    """Output parser for the agent."""
-
-    @property
-    def observation_prefix(self) -> str:
-        """Prefix to append the ChatGLM3-6B observation with."""
-        return "Observation:"
-
-    @property
-    def llm_prefix(self) -> str:
-        """Prefix to append the llm call with."""
-        return "Thought:"
-
-    def _construct_scratchpad(
-            self, intermediate_steps: List[Tuple[AgentAction, str]]
-    ) -> str:
-        agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
-        if not isinstance(agent_scratchpad, str):
-            raise ValueError("agent_scratchpad should be of type string.")
-        if agent_scratchpad:
-            return (
-                f"This was your previous work "
-                f"(but I haven't seen any of it! I only see what "
-                f"you return as final answer):\n{agent_scratchpad}"
-            )
+    prompt = ChatPromptTemplate(
+        input_variables=["input", "agent_scratchpad"],
+        input_types={'chat_history': typing.List[typing.Union[
+            langchain_core.messages.ai.AIMessage,
+            langchain_core.messages.human.HumanMessage,
+            langchain_core.messages.chat.ChatMessage,
+            langchain_core.messages.system.SystemMessage,
+            langchain_core.messages.function.FunctionMessage,
+            langchain_core.messages.tool.ToolMessage]]
+        },
+        messages=[
+            langchain_core.prompts.SystemMessagePromptTemplate(
+                prompt=langchain_core.prompts.PromptTemplate(
+                    input_variables=['tools'],
+                    template=SYSTEM_PROMPT)
+            ),
+            langchain_core.prompts.MessagesPlaceholder(
+                variable_name='chat_history',
+                optional=True
+            ),
+            langchain_core.prompts.HumanMessagePromptTemplate(
+                prompt=langchain_core.prompts.PromptTemplate(
+                    input_variables=['agent_scratchpad', 'input'],
+                    template=HUMAN_MESSAGE
+                )
+            )
-        else:
-            return agent_scratchpad
-
-    @classmethod
-    def _get_default_output_parser(
-            cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
-    ) -> AgentOutputParser:
-        return StructuredChatOutputParserWithRetries(llm=llm)
-
-    @property
-    def _stop(self) -> List[str]:
-        return ["<|observation|>"]
-
-    @classmethod
-    def create_prompt(
-            cls,
-            tools: Sequence[BaseTool],
-            prompt: str = None,
-            input_variables: Optional[List[str]] = None,
-            memory_prompts: Optional[List[BasePromptTemplate]] = None,
-    ) -> BasePromptTemplate:
-        tools_json = []
-        tool_names = []
-        for tool in tools:
-            tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
-            simplified_config_langchain = {
-                "name": tool.name,
-                "description": tool.description,
-                "parameters": tool_schema.get("properties", {})
-            }
-            tools_json.append(simplified_config_langchain)
-            tool_names.append(tool.name)
-        formatted_tools = "\n".join([
-            f"{tool['name']}: {tool['description']}, args: {tool['parameters']}"
-            for tool in tools_json
-        ])
-        formatted_tools = formatted_tools.replace("'", "\\'").replace("{", "{{").replace("}", "}}")
-        template = prompt.format(tool_names=tool_names,
-                                 tools=formatted_tools,
-                                 input="{input}",
-                                 agent_scratchpad="{agent_scratchpad}")
-
-        if input_variables is None:
-            input_variables = ["input", "agent_scratchpad"]
-        _memory_prompts = memory_prompts or []
-        messages = [
-            SystemMessagePromptTemplate.from_template(template),
-            *_memory_prompts,
-        ]
-        return ChatPromptTemplate(input_variables=input_variables, messages=messages)
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            llm: BaseLanguageModel,
-            tools: Sequence[BaseTool],
-            prompt: str = None,
-            callbacks: List[BaseCallbackHandler] = [],
-            output_parser: Optional[AgentOutputParser] = None,
-            input_variables: Optional[List[str]] = None,
-            memory_prompts: Optional[List[BasePromptTemplate]] = None,
-            **kwargs: Any,
-    ) -> Agent:
-        """Construct an agent from an LLM and tools."""
+        ]
+    ).partial(tools=tools)
 
-        cls._validate_tools(tools)
-        prompt = cls.create_prompt(
-            tools,
-            prompt=prompt,
-            input_variables=input_variables,
-            memory_prompts=memory_prompts,
-        )
-        llm_chain = LLMChain(
-            llm=llm,
-            prompt=prompt,
-            callbacks=callbacks,
-            verbose=True
-        )
-        tool_names = [tool.name for tool in tools]
-        _output_parser = output_parser or cls._get_default_output_parser(llm=llm)
-        return cls(
-            llm_chain=llm_chain,
-            allowed_tools=tool_names,
-            output_parser=_output_parser,
-            **kwargs,
-        )
-
-    @property
-    def _agent_type(self) -> str:
-        raise ValueError
-
-
-def initialize_glm3_agent(
-        tools: Sequence[BaseTool],
-        llm: BaseLanguageModel,
-        prompt: str = None,
-        callbacks: List[BaseCallbackHandler] = [],
-        memory: Optional[ConversationBufferWindowMemory] = None,
-        agent_kwargs: Optional[dict] = None,
-        *,
-        tags: Optional[Sequence[str]] = None,
-        **kwargs: Any,
-) -> AgentExecutor:
-    tags_ = list(tags) if tags else []
-    agent_kwargs = agent_kwargs or {}
-    llm.callbacks = callbacks
-    agent_obj = StructuredGLM3ChatAgent.from_llm_and_tools(
-        llm=llm,
-        tools=tools,
-        prompt=prompt,
-        **agent_kwargs
-    )
-    return AgentExecutor.from_agent_and_tools(
-        agent=agent_obj,
-        tools=tools,
-        callbacks=callbacks,
-        memory=memory,
-        tags=tags_,
-        **kwargs,
-    )
+    llm_with_stop = llm.bind(stop=["<|observation|>"])
+    agent = (
+        RunnablePassthrough.assign(
+            agent_scratchpad=lambda x: x["intermediate_steps"],
+        )
+        | prompt
+        | llm_with_stop
+        | StructuredGLM3ChatOutputParser()
+    )
+    return agent
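The rewrite drops the Agent subclass entirely: create_structured_glm3_chat_agent now returns a bare Runnable (prompt | model-with-stop | parser), and executor wiring moves to agents_registry above. The dataflow of that final pipe can be verified standalone with stand-in components; everything below is invented for illustration:

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

prompt = ChatPromptTemplate.from_template("Input: {input}\nScratchpad: {agent_scratchpad}")
fake_llm = RunnableLambda(lambda msgs: "Final Answer: done")  # stands in for llm_with_stop
fake_parser = RunnableLambda(lambda text: ("parsed", text))   # stands in for StructuredGLM3ChatOutputParser

agent = (
    RunnablePassthrough.assign(
        agent_scratchpad=lambda x: str(x["intermediate_steps"]),
    )
    | prompt
    | fake_llm
    | fake_parser
)

# AgentExecutor supplies "intermediate_steps" on every turn, exactly as the lambda expects.
print(agent.invoke({"input": "hi", "intermediate_steps": []}))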
@@ -95,5 +95,5 @@ def search_internet(query: str):
     return search_engine(query=query, config=tool_config)
 
 
 class SearchInternetInput(BaseModel):
-    query: str = Field(description="Query for Internet search")
+    query: str = Field(description="query for Internet search")
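This args_schema is what create_structured_glm3_chat_agent feeds through model_schema when rendering the tool list into the system prompt (the title keys are stripped there). A standalone sketch under the pinned pydantic 1.x, with the class re-declared locally so it runs on its own:

from pydantic import BaseModel, Field
from pydantic.schema import model_schema  # pydantic 1.x only

class SearchInternetInput(BaseModel):
    query: str = Field(description="query for Internet search")

# Prints roughly: {'query': {'title': 'Query', 'description': 'query for Internet search', 'type': 'string'}}
print(model_schema(SearchInternetInput)["properties"])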
@@ -102,7 +102,8 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
+        print("tool_start")
         pass
 
     async def on_tool_end(
         self,
@@ -1,19 +1,16 @@
 import asyncio
 import json
-from typing import List, Union, AsyncIterable, Dict
+from typing import AsyncIterable, List, Union, Dict
 
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 
-from langchain.agents import initialize_agent, AgentType, create_structured_chat_agent, AgentExecutor
-from langchain_core.messages import HumanMessage, AIMessage
 from langchain_core.output_parsers import StrOutputParser
+from langchain_core.messages import AIMessage, HumanMessage
 
 from langchain.chains import LLMChain
 from langchain.prompts.chat import ChatPromptTemplate
 from langchain.prompts import PromptTemplate
 from langchain_core.runnables import RunnableBranch
 
-from server.agent.agent_factory import initialize_glm3_agent, initialize_qwen_agent
+from server.agent.agent_factory.agents_registry import agents_registry
 from server.agent.tools_factory.tools_registry import all_tools
 from server.agent.container import container
@@ -21,7 +18,7 @@ from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
 from server.chat.utils import History
 from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
 from server.db.repository import add_message_to_db
-from server.callback_handler.agent_callback_handler import AgentStatus, AgentExecutorAsyncIteratorCallbackHandler
+from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
 
 
 def create_models_from_config(configs, callbacks, stream):
@@ -46,7 +43,6 @@ def create_models_from_config(configs, callbacks, stream):
     return models, prompts
 
 
-# Write the chain-construction logic here
 def create_models_chains(history, history_len, prompts, models, tools, callbacks, conversation_id, metadata):
     memory = None
     chat_prompt = None
@@ -78,38 +74,21 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
         | models["preprocess_model"]
         | StrOutputParser()
     )
-    if "action_model" in models and len(tools) > 0:
-        if "chatglm3" in models["action_model"].model_name.lower():
-            agent_executor = initialize_glm3_agent(
-                llm=models["action_model"],
-                tools=tools,
-                prompt=prompts["action_model"],
-                callbacks=callbacks,
-                verbose=True,
-            )
-        elif "qwen" in models["action_model"].model_name.lower():
-            agent_executor = initialize_qwen_agent(
-                llm=models["action_model"],
-                tools=tools,
-                prompt=prompts["action_model"],
-                callbacks=callbacks,
-                verbose=True,
-            )
-        else:
-            agent_executor = initialize_agent(
-                llm=models["action_model"],
-                tools=tools,
-                callbacks=callbacks,
-                agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
-                verbose=True,
-            )
+    if "action_model" in models and tools is not None:
+        agent_executor = agents_registry(
+            llm=models["action_model"],
+            callbacks=callbacks,
+            tools=tools,
+            prompt=None,
+            verbose=True
+        )
         # branch = RunnableBranch(
         #     (lambda x: "1" in x["topic"].lower(), agent_executor),
         #     chain
         # )
         # full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
-        full_chain = ({"input": lambda x: x["input"], } | agent_executor)
+        full_chain = ({"input": lambda x: x["input"]} | agent_executor)
     else:
         chain.llm.callbacks = callbacks
         full_chain = ({"input": lambda x: x["input"]} | chain)
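The {"input": lambda x: x["input"]} | agent_executor line works because langchain coerces a plain dict into a RunnableParallel, so the mapping reshapes the request before the executor sees it. A standalone illustration with a stand-in executor:

from langchain_core.runnables import RunnableLambda

echo = RunnableLambda(lambda x: f"executor got: {x}")  # stand-in for agent_executor
full_chain = {"input": lambda x: x["input"]} | echo

print(full_chain.invoke({"input": "hello", "unused": "dropped"}))
# -> executor got: {'input': 'hello'}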
@@ -155,7 +134,17 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         history=history,
         history_len=history_len,
         metadata=metadata)
-    task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}), callback.done))
+    task = asyncio.create_task(wrap_done(
+        full_chain.ainvoke(
+            {
+                "input": query,
+                "chat_history": [
+                    HumanMessage(content="今天北京的温度是多少度"),
+                    AIMessage(content="今天北京的温度是1度"),
+                ],
+
+            }
+        ), callback.done))
 
     async for chunk in callback.aiter():
         data = json.loads(chunk)
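The hard-coded HumanMessage/AIMessage pair exercises the optional chat_history MessagesPlaceholder that glm3_agent.py's prompt now declares. The ainvoke shape is reproducible standalone; the chain below is a stand-in, not the real full_chain:

import asyncio

from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableLambda

chain = RunnableLambda(
    lambda x: f"{len(x['chat_history'])} prior messages, input={x['input']!r}"
)

async def main():
    print(await chain.ainvoke({
        "input": "What should I wear in Beijing today?",
        "chat_history": [
            HumanMessage(content="今天北京的温度是多少度"),
            AIMessage(content="今天北京的温度是1度"),
        ],
    }))

asyncio.run(main())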