diff --git a/requirements.txt b/requirements.txt
index 550c2b56..871699c2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,13 +1,15 @@
 # API requirements
 
-# On Windows system, install the cuda version manually from https://pytorch.org/
+# Torch requirements; install the CUDA version manually from https://pytorch.org/
 torch>=2.1.2
 torchvision>=0.16.2
 torchaudio>=2.1.2
 
+# LangChain 0.1.x requirements
 langchain>=0.1.0
 langchain_openai>=0.0.2
 langchain-community>=1.0.0
+langchainhub>=0.1.14
 
 pydantic==1.10.13
 fschat==0.2.35
diff --git a/server/agent/agent_factory/__init__.py b/server/agent/agent_factory/__init__.py
index 8052fe5c..2b3a31c3 100644
--- a/server/agent/agent_factory/__init__.py
+++ b/server/agent/agent_factory/__init__.py
@@ -1,2 +1,2 @@
-from .glm3_agent import initialize_glm3_agent
+from .glm3_agent import create_structured_glm3_chat_agent
 from .qwen_agent import initialize_qwen_agent
diff --git a/server/agent/agent_factory/agents_registry.py b/server/agent/agent_factory/agents_registry.py
new file mode 100644
index 00000000..5094e32d
--- /dev/null
+++ b/server/agent/agent_factory/agents_registry.py
@@ -0,0 +1,34 @@
+from typing import List, Optional, Sequence
+
+from langchain import hub
+from langchain.agents import AgentExecutor, create_structured_chat_agent
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.messages import SystemMessage
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.tools import BaseTool
+
+from server.agent.agent_factory import create_structured_glm3_chat_agent
+
+
+def agents_registry(
+        llm: BaseLanguageModel,
+        tools: Sequence[BaseTool] = [],
+        callbacks: List[BaseCallbackHandler] = [],
+        prompt: Optional[str] = None,
+        verbose: bool = False) -> AgentExecutor:
+    if prompt is not None:
+        prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)])
+    else:
+        prompt = hub.pull("hwchase17/structured-chat-agent")  # default prompt
+
+    # Register model-specific optimized agent implementations here.
+    if "glm3" in llm.model_name.lower():
+        # Use the LangChain agent implementation optimized for the GLM3 series of models.
+        agent = create_structured_glm3_chat_agent(llm=llm, tools=tools)
+    else:
+        agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
+
+    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, callbacks=callbacks)
+
+    return agent_executor
diff --git a/server/agent/agent_factory/glm3_agent.py b/server/agent/agent_factory/glm3_agent.py
index 6a9a2e89..35ab19c7 100644
--- a/server/agent/agent_factory/glm3_agent.py
+++ b/server/agent/agent_factory/glm3_agent.py
@@ -2,41 +2,43 @@
 This file is a modified version, for ChatGLM3-6B, of the original glm3_agent.py file from the LangChain repo.
""" -from __future__ import annotations import json import logging -from typing import Any, List, Sequence, Tuple, Optional, Union -from pydantic.schema import model_schema +from typing import Sequence, Optional, Union -from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser -from langchain.memory import ConversationBufferWindowMemory -from langchain.agents.agent import Agent -from langchain.chains.llm import LLMChain -from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate +import langchain_core.prompts +import langchain_core.messages +from langchain_core.runnables import Runnable, RunnablePassthrough from langchain.agents.agent import AgentOutputParser +from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser +from langchain.prompts.chat import ChatPromptTemplate from langchain.output_parsers import OutputFixingParser -from langchain.pydantic_v1 import Field -from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate -from langchain.agents.agent import AgentExecutor -from langchain.callbacks.base import BaseCallbackHandler +from langchain.schema import AgentAction, AgentFinish, OutputParserException from langchain.schema.language_model import BaseLanguageModel from langchain.tools.base import BaseTool +from langchain.pydantic_v1 import Field + +from pydantic import typing +from pydantic.schema import model_schema logger = logging.getLogger(__name__) +SYSTEM_PROMPT = "Answer the following questions as best as you can. You have access to the following tools:\n{tools}" +HUMAN_MESSAGE = "Let's start! Human:{input}\n\n{agent_scratchpad}" -class StructuredChatOutputParserWithRetries(AgentOutputParser): - """Output parser with retries for the structured chat agent.""" +class StructuredGLM3ChatOutputParser(AgentOutputParser): + """ + Output parser with retries for the structured chat agent. 
+ """ base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser) - """The base parser to use.""" output_fixing_parser: Optional[OutputFixingParser] = None - """The output fixing parser to use.""" def parse(self, text: str) -> Union[AgentAction, AgentFinish]: + print(text) + special_tokens = ["Action:", "<|observation|>"] first_index = min([text.find(token) if token in text else len(text) for token in special_tokens]) text = text[:first_index] + if "tool_call" in text: action_end = text.find("```") action = text[:action_end].strip() @@ -74,156 +78,64 @@ Action: @property def _type(self) -> str: - return "structured_chat_ChatGLM3_6b_with_retries" + return "StructuredGLM3ChatOutputParser" -class StructuredGLM3ChatAgent(Agent): - """Structured Chat Agent.""" +def create_structured_glm3_chat_agent( + llm: BaseLanguageModel, tools: Sequence[BaseTool] +) -> Runnable: + tools_json = [] + for tool in tools: + tool_schema = model_schema(tool.args_schema) if tool.args_schema else {} + description = tool.description.split(" - ")[ + 1].strip() if tool.description and " - " in tool.description else tool.description + parameters = {k: {sub_k: sub_v for sub_k, sub_v in v.items() if sub_k != 'title'} for k, v in + tool_schema.get("properties", {}).items()} + simplified_config_langchain = { + "name": tool.name, + "description": description, + "parameters": parameters + } + tools_json.append(simplified_config_langchain) + tools = "\n".join([str(tool) for tool in tools_json]) - output_parser: AgentOutputParser = Field( - default_factory=StructuredChatOutputParserWithRetries - ) - """Output parser for the agent.""" - - @property - def observation_prefix(self) -> str: - """Prefix to append the ChatGLM3-6B observation with.""" - return "Observation:" - - @property - def llm_prefix(self) -> str: - """Prefix to append the llm call with.""" - return "Thought:" - - def _construct_scratchpad( - self, intermediate_steps: List[Tuple[AgentAction, str]] - ) -> str: - agent_scratchpad = super()._construct_scratchpad(intermediate_steps) - if not isinstance(agent_scratchpad, str): - raise ValueError("agent_scratchpad should be of type string.") - if agent_scratchpad: - return ( - f"This was your previous work " - f"(but I haven't seen any of it! 
I only see what " - f"you return as final answer):\n{agent_scratchpad}" + prompt = ChatPromptTemplate( + input_variables=["input", "agent_scratchpad"], + input_types={'chat_history': typing.List[typing.Union[ + langchain_core.messages.ai.AIMessage, + langchain_core.messages.human.HumanMessage, + langchain_core.messages.chat.ChatMessage, + langchain_core.messages.system.SystemMessage, + langchain_core.messages.function.FunctionMessage, + langchain_core.messages.tool.ToolMessage]] + }, + messages=[ + langchain_core.prompts.SystemMessagePromptTemplate( + prompt=langchain_core.prompts.PromptTemplate( + input_variables=['tools'], + template=SYSTEM_PROMPT) + ), + langchain_core.prompts.MessagesPlaceholder( + variable_name='chat_history', + optional=True + ), + langchain_core.prompts.HumanMessagePromptTemplate( + prompt=langchain_core.prompts.PromptTemplate( + input_variables=['agent_scratchpad', 'input'], + template=HUMAN_MESSAGE + ) ) - else: - return agent_scratchpad - - @classmethod - def _get_default_output_parser( - cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any - ) -> AgentOutputParser: - return StructuredChatOutputParserWithRetries(llm=llm) - - @property - def _stop(self) -> List[str]: - return ["<|observation|>"] - - @classmethod - def create_prompt( - cls, - tools: Sequence[BaseTool], - prompt: str = None, - input_variables: Optional[List[str]] = None, - memory_prompts: Optional[List[BasePromptTemplate]] = None, - ) -> BasePromptTemplate: - tools_json = [] - tool_names = [] - for tool in tools: - tool_schema = model_schema(tool.args_schema) if tool.args_schema else {} - simplified_config_langchain = { - "name": tool.name, - "description": tool.description, - "parameters": tool_schema.get("properties", {}) - } - tools_json.append(simplified_config_langchain) - tool_names.append(tool.name) - formatted_tools = "\n".join([ - f"{tool['name']}: {tool['description']}, args: {tool['parameters']}" - for tool in tools_json - ]) - formatted_tools = formatted_tools.replace("'", "\\'").replace("{", "{{").replace("}", "}}") - template = prompt.format(tool_names=tool_names, - tools=formatted_tools, - input="{input}", - agent_scratchpad="{agent_scratchpad}") - - if input_variables is None: - input_variables = ["input", "agent_scratchpad"] - _memory_prompts = memory_prompts or [] - messages = [ - SystemMessagePromptTemplate.from_template(template), - *_memory_prompts, ] - return ChatPromptTemplate(input_variables=input_variables, messages=messages) - @classmethod - def from_llm_and_tools( - cls, - llm: BaseLanguageModel, - tools: Sequence[BaseTool], - prompt: str = None, - callbacks: List[BaseCallbackHandler] = [], - output_parser: Optional[AgentOutputParser] = None, - input_variables: Optional[List[str]] = None, - memory_prompts: Optional[List[BasePromptTemplate]] = None, - **kwargs: Any, - ) -> Agent: - """Construct an agent from an LLM and tools.""" + ).partial(tools=tools) - cls._validate_tools(tools) - prompt = cls.create_prompt( - tools, - prompt=prompt, - input_variables=input_variables, - memory_prompts=memory_prompts, - ) - llm_chain = LLMChain( - llm=llm, - prompt=prompt, - callbacks=callbacks, - verbose=True - ) - tool_names = [tool.name for tool in tools] - _output_parser = output_parser or cls._get_default_output_parser(llm=llm) - return cls( - llm_chain=llm_chain, - allowed_tools=tool_names, - output_parser=_output_parser, - **kwargs, - ) - - @property - def _agent_type(self) -> str: - raise ValueError - - -def initialize_glm3_agent( - tools: Sequence[BaseTool], - llm: 
-    prompt: str = None,
-    callbacks: List[BaseCallbackHandler] = [],
-    memory: Optional[ConversationBufferWindowMemory] = None,
-    agent_kwargs: Optional[dict] = None,
-    *,
-    tags: Optional[Sequence[str]] = None,
-    **kwargs: Any,
-) -> AgentExecutor:
-    tags_ = list(tags) if tags else []
-    agent_kwargs = agent_kwargs or {}
-    llm.callbacks = callbacks
-    agent_obj = StructuredGLM3ChatAgent.from_llm_and_tools(
-        llm=llm,
-        tools=tools,
-        prompt=prompt,
-        **agent_kwargs
-    )
-    return AgentExecutor.from_agent_and_tools(
-        agent=agent_obj,
-        tools=tools,
-        callbacks=callbacks,
-        memory=memory,
-        tags=tags_,
-        **kwargs,
+    llm_with_stop = llm.bind(stop=["<|observation|>"])
+    agent = (
+            RunnablePassthrough.assign(
+                agent_scratchpad=lambda x: x["intermediate_steps"],
+            )
+            | prompt
+            | llm_with_stop
+            | StructuredGLM3ChatOutputParser()
     )
+    return agent
diff --git a/server/agent/tools_factory/search_internet.py b/server/agent/tools_factory/search_internet.py
index 68a5c6b1..fb36ecf3 100644
--- a/server/agent/tools_factory/search_internet.py
+++ b/server/agent/tools_factory/search_internet.py
@@ -95,5 +95,5 @@ def search_internet(query: str):
     return search_engine(query=query, config=tool_config)
 
 class SearchInternetInput(BaseModel):
-    query: str = Field(description="Query for Internet search")
+    query: str = Field(description="query for Internet search")
 
diff --git a/server/callback_handler/agent_callback_handler.py b/server/callback_handler/agent_callback_handler.py
index 4b7f0b97..6e845681 100644
--- a/server/callback_handler/agent_callback_handler.py
+++ b/server/callback_handler/agent_callback_handler.py
@@ -102,7 +102,8 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
-        print("tool_start")
+        pass
+
 
     async def on_tool_end(
         self,
diff --git a/server/chat/chat.py b/server/chat/chat.py
index a9f6c60b..ca6295e1 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -1,19 +1,16 @@
 import asyncio
 import json
-from typing import List, Union, AsyncIterable, Dict
+from typing import AsyncIterable, List, Union, Dict
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-
-from langchain.agents import initialize_agent, AgentType, create_structured_chat_agent, AgentExecutor
-from langchain_core.messages import HumanMessage, AIMessage
 from langchain_core.output_parsers import StrOutputParser
+from langchain_core.messages import AIMessage, HumanMessage
+
 from langchain.chains import LLMChain
 from langchain.prompts.chat import ChatPromptTemplate
 from langchain.prompts import PromptTemplate
-from langchain_core.runnables import RunnableBranch
-
-from server.agent.agent_factory import initialize_glm3_agent, initialize_qwen_agent
+from server.agent.agent_factory.agents_registry import agents_registry
 from server.agent.tools_factory.tools_registry import all_tools
 from server.agent.container import container
 
@@ -21,7 +18,7 @@ from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
 from server.chat.utils import History
 from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
 from server.db.repository import add_message_to_db
-from server.callback_handler.agent_callback_handler import AgentStatus, AgentExecutorAsyncIteratorCallbackHandler
+from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
 
 
 def create_models_from_config(configs, callbacks, stream):
@@ -46,7 +43,6 @@ def create_models_from_config(configs, callbacks, stream):
     return models, prompts
 
-# Write the chain-construction logic here
 def create_models_chains(history, history_len, prompts, models, tools, callbacks, conversation_id, metadata):
     memory = None
     chat_prompt = None
@@ -78,38 +74,21 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
             | models["preprocess_model"]
             | StrOutputParser()
         )
-    if "action_model" in models and len(tools) > 0:
-        if "chatglm3" in models["action_model"].model_name.lower():
-            agent_executor = initialize_glm3_agent(
-                llm=models["action_model"],
-                tools=tools,
-                prompt=prompts["action_model"],
-                callbacks=callbacks,
-                verbose=True,
-            )
-        elif "qwen" in models["action_model"].model_name.lower():
-            agent_executor = initialize_qwen_agent(
-                llm=models["action_model"],
-                tools=tools,
-                prompt=prompts["action_model"],
-                callbacks=callbacks,
-                verbose=True,
-            )
-        else:
-            agent_executor = initialize_agent(
-                llm=models["action_model"],
-                tools=tools,
-                callbacks=callbacks,
-                agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
-                verbose=True,
-            )
+    if "action_model" in models and tools:
+        agent_executor = agents_registry(
+            llm=models["action_model"],
+            callbacks=callbacks,
+            tools=tools,
+            prompt=None,
+            verbose=True
+        )
     #     branch = RunnableBranch(
     #         (lambda x: "1" in x["topic"].lower(), agent_executor),
     #         chain
     #     )
     #     full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
-    full_chain = ({"input": lambda x: x["input"], } | agent_executor)
+        full_chain = ({"input": lambda x: x["input"]} | agent_executor)
     else:
         chain.llm.callbacks = callbacks
         full_chain = ({"input": lambda x: x["input"]} | chain)
@@ -155,7 +134,16 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         history=history,
         history_len=history_len,
         metadata=metadata)
-    task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}), callback.done))
+    task = asyncio.create_task(wrap_done(
+        full_chain.ainvoke(
+            {
+                "input": query,
+                "chat_history": [  # NOTE: hardcoded demo history; should come from the stored conversation
+                    HumanMessage(content="今天北京的温度是多少度"),
+                    AIMessage(content="今天北京的温度是1度"),
+                ],
+            }
+        ), callback.done))
 
     async for chunk in callback.aiter():
         data = json.loads(chunk)
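
For reference: the tool-serialization loop in create_structured_glm3_chat_agent reduces each tool to the flat name/description/parameters layout that the GLM3 system prompt is trained on. A minimal sketch of what that loop produces for a hypothetical calculator tool follows; the tool name, schema, and function are illustrative assumptions, not part of this patch.

from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import StructuredTool
from pydantic.schema import model_schema

class CalculatorInput(BaseModel):
    expression: str = Field(description="arithmetic expression to evaluate")

def calculate(expression: str) -> str:
    # Illustrative only; eval() is unsafe on untrusted input.
    return str(eval(expression))

calculator = StructuredTool.from_function(
    func=calculate,
    name="calculator",
    description="calculator - evaluate an arithmetic expression",
    args_schema=CalculatorInput,
)

# model_schema() yields the full pydantic v1 JSON schema; the loop keeps only the
# per-property dicts (minus their "title" keys) and takes the text after " - "
# from the conventional "name - description" tool description.
schema = model_schema(calculator.args_schema)
# The resulting tools_json entry:
# {"name": "calculator",
#  "description": "evaluate an arithmetic expression",
#  "parameters": {"expression": {"description": "arithmetic expression to evaluate",
#                                "type": "string"}}}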
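And a minimal end-to-end sketch of the new entry point, mirroring what create_models_chains now does: agents_registry routes any model whose model_name contains "glm3" to the GLM3-optimized agent, and every other model to LangChain's structured chat agent built from the hub prompt. The model name, endpoint, API key, and question below are assumptions for illustration.

from langchain_openai import ChatOpenAI

from server.agent.agent_factory.agents_registry import agents_registry
from server.agent.tools_factory.tools_registry import all_tools

# Assumed local OpenAI-compatible endpoint serving ChatGLM3-6B.
llm = ChatOpenAI(
    model_name="chatglm3-6b",
    openai_api_base="http://127.0.0.1:20000/v1",
    openai_api_key="EMPTY",
)

# "glm3" in the model name selects create_structured_glm3_chat_agent;
# anything else falls back to create_structured_chat_agent.
agent_executor = agents_registry(llm=llm, tools=all_tools, verbose=True)

# Compose and invoke the executor the same way chat.py builds full_chain.
full_chain = {"input": lambda x: x["input"]} | agent_executor
result = full_chain.invoke({"input": "What is the weather like in Beijing today?"})
print(result["output"])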