Rewrite the qwen agent per langchain 0.1

This commit is contained in:
parent df1a508e10
commit 7257521e10
@@ -1,2 +1,2 @@
 from .glm3_agent import create_structured_glm3_chat_agent
-from .qwen_agent import initialize_qwen_agent
+from .qwen_agent import create_structured_qwen_chat_agent
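Note: the package export swaps initialize_qwen_agent for create_structured_qwen_chat_agent, so the qwen factory now mirrors the glm3 one and hands back an agent Runnable rather than a ready-made AgentExecutor.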
@@ -8,7 +8,8 @@ from langchain_core.messages import SystemMessage
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.tools import BaseTool
 
-from server.agent.agent_factory import create_structured_glm3_chat_agent
+from server.agent.agent_factory import (create_structured_glm3_chat_agent,
+                                        create_structured_qwen_chat_agent)
 
 
 def agents_registry(
@@ -22,10 +23,14 @@ def agents_registry(
     else:
         prompt = hub.pull("hwchase17/structured-chat-agent")  # default prompt
 
+    # llm.callbacks = callbacks
+
     # Write any optimized method here.
     if "glm3" in llm.model_name.lower():
         # An optimized method of langchain Agent that uses the glm3 series model
         agent = create_structured_glm3_chat_agent(llm=llm, tools=tools)
+    elif "qwen" in llm.model_name.lower():
+        agent = create_structured_qwen_chat_agent(llm=llm, tools=tools)
     else:
         agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
 
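For reference, the branch taken depends only on a case-insensitive substring match against llm.model_name. A minimal sketch of the dispatch rule, with invented model names:

# Illustrative sketch of the registry's dispatch rule (model names invented).
for name in ["chatglm3-6b", "Qwen-1_8B-Chat", "gpt-4"]:
    if "glm3" in name.lower():
        factory = "create_structured_glm3_chat_agent"
    elif "qwen" in name.lower():
        factory = "create_structured_qwen_chat_agent"
    else:
        factory = "create_structured_chat_agent"
    print(name, "->", factory)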
@@ -1,26 +1,24 @@
 from __future__ import annotations
-from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
-from langchain.memory import ConversationBufferWindowMemory
-from typing import Any, List, Sequence, Tuple, Optional, Union
-import re
-from langchain.agents import Tool
-from langchain.agents.agent import LLMSingleActionAgent, AgentOutputParser
-from langchain.chains.llm import LLMChain
-from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate, BaseChatPromptTemplate
+
 import json
 import logging
 from langchain.pydantic_v1 import Field
+from operator import itemgetter
+import re
+from typing import List, Sequence, Union
+
+from langchain_core.runnables import Runnable, RunnablePassthrough
+from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.prompts.chat import BaseChatPromptTemplate
 from langchain.schema import (AgentAction, AgentFinish, OutputParserException,
                               HumanMessage, SystemMessage, AIMessage)
 from langchain.agents.agent import AgentExecutor
-from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
 from langchain.tools.render import format_tool_to_openai_function
 
 from server.utils import get_prompt_template
 
 HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
 
 logger = logging.getLogger(__name__)
@@ -28,7 +26,7 @@ class QwenChatAgentPromptTemplate(BaseChatPromptTemplate):
     # The template to use
     template: str
     # The list of tools available
-    tools: List[Tool]
+    tools: List[BaseTool]
 
     def format_messages(self, **kwargs) -> str:
         # Get the intermediate steps (AgentAction, Observation tuples)
@@ -52,8 +50,8 @@ class QwenChatAgentPromptTemplate(BaseChatPromptTemplate):
         return [HumanMessage(content=formatted)]
 
 
-class QwenChatAgentOutputParser(StructuredChatOutputParser):
-    """Output parser with retries for the structured chat agent."""
+class QwenChatAgentOutputParserCustom(StructuredChatOutputParser):
+    """Output parser with retries for the structured chat agent with custom qwen prompt."""
 
     def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
         if s := re.findall(r"\nAction:\s*(.+)\nAction\sInput:\s*(.+)", text, flags=re.DOTALL):
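As a quick illustration of the text shape this custom parser expects, here the regex above is applied to an invented Qwen-style completion:

import re

# Invented sample completion in the "Action: ... / Action Input: ..." shape
# targeted by QwenChatAgentOutputParserCustom.
text = "Thought: I should check the weather.\nAction: weather_search\nAction Input: 苏州"
print(re.findall(r"\nAction:\s*(.+)\nAction\sInput:\s*(.+)", text, flags=re.DOTALL))
# -> [('weather_search', '苏州')]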
@@ -67,111 +65,51 @@ class QwenChatAgentOutputParser(StructuredChatOutputParser):
 
     @property
     def _type(self) -> str:
-        return "structured_chat_qwen_with_retries"
+        return "StructuredQWenChatOutputParserCustom"
 
 
-class QwenChatAgent(LLMSingleActionAgent):
-    """Structured Chat Agent."""
+class QwenChatAgentOutputParserLC(StructuredChatOutputParser):
+    """Output parser with retries for the structured chat agent with standard lc prompt."""
 
-    output_parser: AgentOutputParser = Field(
-        default_factory=QwenChatAgentOutputParser
-    )
-    """Output parser for the agent."""
+    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
+        if s := re.findall(r"\nAction:\s*```(.+)```", text, flags=re.DOTALL):
+            action = json.loads(s[0])
+            tool = action.get("action")
+            if tool == "Final Answer":
+                return AgentFinish({"output": action.get("action_input", "")}, log=text)
+            else:
+                return AgentAction(tool=tool, tool_input=action.get("action_input", {}), log=text)
+        else:
+            raise OutputParserException(f"Could not parse LLM output: {text}")
 
     @property
-    def observation_prefix(self) -> str:
-        """Prefix to append the qwen observation with."""
-        return "\nObservation:"
+    def _type(self) -> str:
+        return "StructuredQWenChatOutputParserLC"
 
-    @property
-    def llm_prefix(self) -> str:
-        """Prefix to append the llm call with."""
-        return "\nThought:"
-
-    @classmethod
-    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
-        pass
-
-    @classmethod
-    def create_prompt(
-        cls,
-        tools: Sequence[BaseTool],
-        prompt: str = None,
-        input_variables: Optional[List[str]] = None,
-        memory_prompts: Optional[List[QwenChatAgentPromptTemplate]] = None,
-    ) -> QwenChatAgentPromptTemplate:
-        template = get_prompt_template("action_model", "qwen")
-        return QwenChatAgentPromptTemplate(input_variables=["input", "intermediate_steps"],
-                                           template=template,
-                                           tools=tools)
-
-    @classmethod
-    def from_llm_and_tools(
-        cls,
+
+def create_structured_qwen_chat_agent(
     llm: BaseLanguageModel,
     tools: Sequence[BaseTool],
-        prompt: str = None,
     callbacks: List[BaseCallbackHandler] = [],
-        output_parser: Optional[AgentOutputParser] = None,
-        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
-        input_variables: Optional[List[str]] = None,
-        memory_prompts: Optional[List[BaseChatPromptTemplate]] = None,
-        **kwargs: Any,
-    ) -> QwenChatAgent:
-        """Construct an agent from an LLM and tools."""
-        cls._validate_tools(tools)
-        prompt = cls.create_prompt(
-            tools,
-            prompt=prompt,
-            input_variables=input_variables,
-            memory_prompts=memory_prompts,
-        )
-        llm_chain = LLMChain(
-            llm=llm,
-            prompt=prompt,
-            callbacks=callbacks,
-        )
-        tool_names = [tool.name for tool in tools]
-        output_parser = output_parser or QwenChatAgentOutputParser()
-        return cls(
-            llm_chain=llm_chain,
-            allowed_tools=tool_names,
-            output_parser=output_parser,
-            stop=["\nObservation:"],
-            **kwargs,
-        )
+    use_custom_prompt: bool = True,
+) -> Runnable:
+    if use_custom_prompt:
+        prompt = "qwen"
+        output_parser = QwenChatAgentOutputParserCustom()
+    else:
+        prompt = "structured-chat-agent"
+        output_parser = QwenChatAgentOutputParserLC()
 
-    @property
-    def _agent_type(self) -> str:
-        return "qwen_chat_agent"
-
-
-def initialize_qwen_agent(
-    tools: Sequence[BaseTool],
-    llm: BaseLanguageModel,
-    prompt: str = None,
-    callbacks: List[BaseCallbackHandler] = [],
-    memory: Optional[ConversationBufferWindowMemory] = None,
-    agent_kwargs: Optional[dict] = None,
-    *,
-    tags: Optional[Sequence[str]] = None,
-    **kwargs: Any,
-) -> AgentExecutor:
-    tags_ = list(tags) if tags else []
-    agent_kwargs = agent_kwargs or {}
-    llm.callbacks = callbacks
-    agent_obj = QwenChatAgent.from_llm_and_tools(
-        llm=llm,
-        tools=tools,
-        prompt=prompt,
-        **agent_kwargs,
-    )
-    return AgentExecutor.from_agent_and_tools(
-        agent=agent_obj,
-        tools=tools,
-        callbacks=callbacks,
-        memory=memory,
-        tags=tags_,
-        intermediate_steps=[],
-        **kwargs,
-    )
+    template = get_prompt_template("action_model", prompt)
+    prompt = QwenChatAgentPromptTemplate(input_variables=["input", "intermediate_steps"],
+                                         template=template,
+                                         tools=tools)
+
+    agent = (
+        RunnablePassthrough.assign(
+            agent_scratchpad=itemgetter("intermediate_steps")
+        )
+        | prompt
+        | llm.bind(stop="\nObservation:")
+        | output_parser
+    )
+    return agent
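The factory now returns a plain Runnable pipeline (scratchpad assembly, then prompt, then the llm bound with the "\nObservation:" stop token, then the output parser) instead of an AgentExecutor, so callers wrap it themselves. A minimal sketch, assuming llm and tools are prepared as in the test file below:

# Hedged sketch: wrapping the returned Runnable in an AgentExecutor.
from langchain.agents import AgentExecutor

agent = create_structured_qwen_chat_agent(llm=llm, tools=tools)
executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
print(executor.invoke({"input": "苏州今天冷吗"}))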
@@ -21,6 +21,11 @@ from server.db.repository import add_message_to_db
 from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
 
 
+# from langchain.globals import set_debug, set_verbose
+# set_debug(True)
+# set_verbose(True)
+
+
 def create_models_from_config(configs, callbacks, stream):
     if configs is None:
         configs = {}
@@ -3,31 +3,39 @@ from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent))
 
 import asyncio
 from langchain.agents import AgentExecutor
+from langchain_openai.chat_models import ChatOpenAI
+# from langchain.chat_models.openai import ChatOpenAI
 from server.utils import get_ChatOpenAI
 from server.agent.tools_factory.tools_registry import all_tools
-from server.agent.agent_factory.qwen_agent import initialize_qwen_agent
+from server.agent.agent_factory.qwen_agent import create_structured_qwen_chat_agent
 from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
+from langchain import globals
 
-# globals.set_debug(True)
+globals.set_debug(True)
 # globals.set_verbose(True)
 
 
 async def main():
     callback = AgentExecutorAsyncIteratorCallbackHandler()
     tools = [t.copy(update={"callbacks": [callback]}) for t in all_tools]
-    qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
-    executor = initialize_qwen_agent(tools=tools,
-                                     llm=qwen_model,
-                                     callbacks=[callback],
-                                     )
+    # qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
+    qwen_model = ChatOpenAI(base_url="http://127.0.0.1:9997/v1",
+                            api_key="empty",
+                            streaming=False,
+                            temperature=0.01,
+                            model="qwen",
+                            callbacks=[callback],
+                            )
+    agent = create_structured_qwen_chat_agent(tools=tools, llm=qwen_model)
+    executor = AgentExecutor(agent=agent, tools=tools, verbose=True, callbacks=[callback])
 
     # ret = executor.invoke("苏州今天冷吗")
-    ret = asyncio.create_task(executor.ainvoke("苏州今天冷吗"))
+    ret = asyncio.create_task(executor.ainvoke({"input": "苏州今天冷吗"}))
     async for chunk in callback.aiter():
         print(chunk)
     # ret = executor.invoke("从知识库samples中查询chatchat项目简介")
     # ret = executor.invoke("chatchat项目主要issue有哪些")
-    print(ret)
+    await ret
 
 asyncio.run(main())
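Note the invocation pattern in the test: executor.ainvoke is scheduled as a task first, the loop then streams chunks from callback.aiter() while the agent runs, and the task is awaited only once the iterator is exhausted. The ChatOpenAI endpoint at http://127.0.0.1:9997/v1 assumes a local OpenAI-compatible server (Xinference listens on that port by default) serving a model named "qwen".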