diff --git a/server/agent/agent_factory/glm3_agent.py b/server/agent/agent_factory/glm3_agent.py
index 52a16ae3..f9bf3aee 100644
--- a/server/agent/agent_factory/glm3_agent.py
+++ b/server/agent/agent_factory/glm3_agent.py
@@ -19,7 +19,7 @@ from langchain.output_parsers import OutputFixingParser
 from langchain.pydantic_v1 import Field
 from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate
 from langchain.agents.agent import AgentExecutor
-from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
 from langchain_core.callbacks import Callbacks
@@ -169,7 +169,7 @@ class StructuredGLM3ChatAgent(Agent):
             llm: BaseLanguageModel,
             tools: Sequence[BaseTool],
             prompt: str = None,
-            callback_manager: Optional[BaseCallbackManager] = None,
+            callbacks: List[BaseCallbackHandler] = [],
             output_parser: Optional[AgentOutputParser] = None,
             human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
             input_variables: Optional[List[str]] = None,
@@ -187,7 +187,7 @@ class StructuredGLM3ChatAgent(Agent):
         llm_chain = LLMChain(
             llm=llm,
             prompt=prompt,
-            callback_manager=callback_manager,
+            callbacks=callbacks,
             verbose=True
         )
         tool_names = [tool.name for tool in tools]
@@ -208,6 +208,7 @@ def initialize_glm3_agent(
         tools: Sequence[BaseTool],
         llm: BaseLanguageModel,
         prompt: str = None,
+        callbacks: List[BaseCallbackHandler] = [],
         memory: Optional[ConversationBufferWindowMemory] = None,
         agent_kwargs: Optional[dict] = None,
         *,
@@ -216,6 +217,7 @@ def initialize_glm3_agent(
 ) -> AgentExecutor:
     tags_ = list(tags) if tags else []
     agent_kwargs = agent_kwargs or {}
+    llm.callbacks=callbacks
     agent_obj = StructuredGLM3ChatAgent.from_llm_and_tools(
         llm=llm,
         tools=tools,
@@ -225,6 +227,7 @@ def initialize_glm3_agent(
     return AgentExecutor.from_agent_and_tools(
         agent=agent_obj,
         tools=tools,
+        callbacks=callbacks,
         memory=memory,
         tags=tags_,
         intermediate_steps=[],
diff --git a/server/agent/agent_factory/qwen_agent.py b/server/agent/agent_factory/qwen_agent.py
index 5b081c53..f31f1ae5 100644
--- a/server/agent/agent_factory/qwen_agent.py
+++ b/server/agent/agent_factory/qwen_agent.py
@@ -14,7 +14,7 @@ from langchain.pydantic_v1 import Field
 from langchain.schema import (AgentAction, AgentFinish, OutputParserException,
                               HumanMessage, SystemMessage, AIMessage)
 from langchain.agents.agent import AgentExecutor
-from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
 from langchain.tools.render import format_tool_to_openai_function
@@ -107,16 +107,16 @@ class QwenChatAgent(LLMSingleActionAgent):
 
     @classmethod
     def from_llm_and_tools(
-            cls,
-            llm: BaseLanguageModel,
-            tools: Sequence[BaseTool],
-            prompt: str = None,
-            callback_manager: Optional[BaseCallbackManager] = None,
-            output_parser: Optional[AgentOutputParser] = None,
-            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
-            input_variables: Optional[List[str]] = None,
-            memory_prompts: Optional[List[BaseChatPromptTemplate]] = None,
-            **kwargs: Any,
+        cls,
+        llm: BaseLanguageModel,
+        tools: Sequence[BaseTool],
+        prompt: str = None,
+        callbacks: List[BaseCallbackHandler] = [],
+        output_parser: Optional[AgentOutputParser] = None,
+        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
+        input_variables: Optional[List[str]] = None,
+        memory_prompts: Optional[List[BaseChatPromptTemplate]] = None,
+        **kwargs: Any,
     ) -> QwenChatAgent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
@@ -129,7 +129,7 @@ class QwenChatAgent(LLMSingleActionAgent):
         llm_chain = LLMChain(
             llm=llm,
             prompt=prompt,
-            callback_manager=callback_manager,
+            callbacks=callbacks,
         )
         tool_names = [tool.name for tool in tools]
         output_parser = output_parser or QwenChatAgentOutputParser()
@@ -147,34 +147,33 @@ class QwenChatAgent(LLMSingleActionAgent):
 
 
 def initialize_qwen_agent(
-        tools: Sequence[BaseTool],
-        llm: BaseLanguageModel,
-        prompt: str = None,
-        callback_manager: Optional[BaseCallbackManager] = None,
-        memory: Optional[ConversationBufferWindowMemory] = None,
-        agent_kwargs: Optional[dict] = None,
-        *,
-        return_direct: Optional[bool] = None,
-        tags: Optional[Sequence[str]] = None,
-        **kwargs: Any,
+    tools: Sequence[BaseTool],
+    llm: BaseLanguageModel,
+    prompt: str = None,
+    callbacks: List[BaseCallbackHandler] = [],
+    memory: Optional[ConversationBufferWindowMemory] = None,
+    agent_kwargs: Optional[dict] = None,
+    *,
+    return_direct: Optional[bool] = None,
+    tags: Optional[Sequence[str]] = None,
+    **kwargs: Any,
 ) -> AgentExecutor:
     tags_ = list(tags) if tags else []
     agent_kwargs = agent_kwargs or {}
     if isinstance(return_direct, bool):  # can make all tools return directly
         tools = [t.copy(update={"return_direct": return_direct}) for t in tools]
-
+    llm.callbacks=callbacks
     agent_obj = QwenChatAgent.from_llm_and_tools(
         llm=llm,
         tools=tools,
         prompt=prompt,
-        callback_manager=callback_manager,
-        **agent_kwargs
+        **agent_kwargs,
     )
     return AgentExecutor.from_agent_and_tools(
         agent=agent_obj,
         tools=tools,
-        callback_manager=callback_manager,
+        callbacks=callbacks,
         memory=memory,
         tags=tags_,
         intermediate_steps=[],
diff --git a/server/callback_handler/agent_callback_handler.py b/server/callback_handler/agent_callback_handler.py
index 655bf6e5..23949b5e 100644
--- a/server/callback_handler/agent_callback_handler.py
+++ b/server/callback_handler/agent_callback_handler.py
@@ -13,83 +13,50 @@ def dumps(obj: Dict) -> str:
     return json.dumps(obj, ensure_ascii=False)
 
 
-class Status:
-    start: int = 1
-    running: int = 2
-    complete: int = 3
+class AgentStatus:
+    llm_start: int = 1
+    llm_new_token: int = 2
+    llm_end: int = 3
     agent_action: int = 4
     agent_finish: int = 5
     error: int = 6
-    tool_finish: int = 7
 
 
-class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
+class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
     def __init__(self):
         super().__init__()
         self.queue = asyncio.Queue()
         self.done = asyncio.Event()
-        self.cur_tool = {}
         self.out = True
 
-    async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
-                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
-                            metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
-        self.cur_tool = {
-            "tool_name": serialized["name"],
-            "input_str": input_str,
-            "output_str": "",
-            "status": Status.agent_action,
-            "run_id": run_id.hex,
-            "llm_token": "",
-            "final_answer": "",
-            "error": "",
+    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
+        data = {
+            "status" : AgentStatus.llm_start,
+            "text" : "",
         }
-        self.queue.put_nowait(dumps(self.cur_tool))
+        self.done.clear()
+        self.queue.put_nowait(dumps(data))
 
-    async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
-                          tags: List[str] | None = None, **kwargs: Any) -> None:
-
-        self.out = True
-        self.cur_tool.update(
-            status=Status.tool_finish,
-            output_str=output.replace("Answer:", ""),
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
-
-    async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
-                            parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
-        self.cur_tool.update(
-            status=Status.error,
-            error=str(error),
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
 
     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         special_tokens = ["Action", "<|observation|>"]
         for stoken in special_tokens:
             if stoken in token:
                 before_action = token.split(stoken)[0]
-                self.cur_tool.update(
-                    status=Status.running,
-                    llm_token=before_action + "\n",
-                )
-                self.queue.put_nowait(dumps(self.cur_tool))
+                data = {
+                    "status" : AgentStatus.llm_new_token,
+                    "text": before_action + "\n",
+                }
+                self.queue.put_nowait(dumps(data))
                 self.out = False
                 break
 
         if token is not None and token != "" and self.out:
-            self.cur_tool.update(
-                status=Status.running,
-                llm_token=token,
-            )
-            self.queue.put_nowait(dumps(self.cur_tool))
-
-    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
-        self.cur_tool.update(
-            status=Status.start,
-            llm_token="",
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+            data = {
+                "status" : AgentStatus.llm_new_token,
+                "text" : token,
+            }
+            self.queue.put_nowait(dumps(data))
 
     async def on_chat_model_start(
         self,
@@ -102,26 +69,27 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
-        self.cur_tool.update(
-            status=Status.start,
-            llm_token="",
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+        data = {
+            "status" : AgentStatus.llm_start,
+            "text" : "",
+        }
+        self.done.clear()
+        self.queue.put_nowait(dumps(data))
 
     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        self.cur_tool.update(
-            status=Status.complete,
-            llm_token="",
-        )
+        data = {
+            "status" : AgentStatus.llm_end,
+            "text" : response.generations[0][0].message.content,
+        }
         self.out = True
-        self.queue.put_nowait(dumps(self.cur_tool))
+        self.queue.put_nowait(dumps(data))
 
     async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
-        self.cur_tool.update(
-            status=Status.error,
-            error=str(error),
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+        data = {
+            "status" : AgentStatus.error,
+            "text" : str(error),
+        }
+        self.queue.put_nowait(dumps(data))
 
     async def on_agent_action(
         self,
@@ -132,12 +100,13 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
-        self.cur_tool.update(
-            status=Status.agent_action,
-            tool_name=action.tool,
-            tool_input=action.tool_input,
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+        data = {
+            "status" : AgentStatus.agent_action,
+            "tool_name" : action.tool,
+            "tool_input" : action.tool_input,
+            "text": action.log,
+        }
+        self.queue.put_nowait(dumps(data))
 
     async def on_agent_finish(
         self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
@@ -147,8 +116,9 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         if "Thought:" in finish.return_values["output"]:
             finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "")
 
-        self.cur_tool.update(
-            status=Status.agent_finish,
-            agent_finish=finish.return_values["output"],
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+        data = {
+            "status" : AgentStatus.agent_finish,
+            "text" : finish.return_values["output"],
+        }
+        self.done.set()
+        self.queue.put_nowait(dumps(data))
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 23a7624c..74933ba2 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -20,22 +20,23 @@ from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
 from server.chat.utils import History
 from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
 from server.db.repository import add_message_to_db
-from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
+from server.callback_handler.agent_callback_handler import AgentStatus, AgentExecutorAsyncIteratorCallbackHandler
 
 
-def create_models_from_config(configs, callbacks):
+def create_models_from_config(configs, callbacks, stream):
     if configs is None:
         configs = {}
     models = {}
     prompts = {}
     for model_type, model_configs in configs.items():
         for model_name, params in model_configs.items():
-            callback = callbacks if params.get('callbacks', False) else None
+            callbacks = callbacks if params.get('callbacks', False) else None
             model_instance = get_ChatOpenAI(
                 model_name=model_name,
                 temperature=params.get('temperature', 0.5),
                 max_tokens=params.get('max_tokens', 1000),
-                callbacks=callback
+                callbacks=callbacks,
+                streaming=stream,
             )
             models[model_type] = model_instance
             prompt_name = params.get('prompt_name', 'default')
@@ -83,7 +84,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
             tools=tools,
             prompt=prompts["action_model"],
             memory=memory,
-            # callback_manager=BaseCallbackManager(handlers=callbacks),
+            callbacks=callbacks,
             verbose=True,
         )
     elif "qwen" in models["action_model"].model_name.lower():
@@ -92,7 +93,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
             tools=tools,
             prompt=prompts["action_model"],
             memory=memory,
-            # callback_manager=BaseCallbackManager(handlers=callbacks),
+            callbacks=callbacks,
             verbose=True,
         )
     else:
@@ -131,7 +132,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                        ]
                    ]
                ),
-               stream: bool = Body(False, description="流式输出"),
+               stream: bool = Body(True, description="流式输出"),
               model_config: Dict = Body({}, description="LLM 模型配置"),
               tool_config: Dict = Body({}, description="工具配置"),
               ):
@@ -142,11 +143,11 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
             conversation_id=conversation_id
         ) if conversation_id else None
 
-        callback = CustomAsyncIteratorCallbackHandler()
+        callback = AgentExecutorAsyncIteratorCallbackHandler()
         callbacks = [callback]
 
         # 从配置中选择模型
-        models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
+        models, prompts = create_models_from_config(callbacks=[], configs=model_config, stream=stream)
 
         # 从配置中选择工具
         tools = [tool for tool in all_tools if tool.name in tool_config]
@@ -164,56 +165,13 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
 
         # Execute Chain
         task = asyncio.create_task(
-            wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
-        if stream:
-            async for chunk in callback.aiter():
-                data = json.loads(chunk)
-                if data["status"] == Status.start:
-                    continue
-                elif data["status"] == Status.agent_action:
-                    tool_info = {
-                        "tool_name": data["tool_name"],
-                        "tool_input": data["tool_input"]
-                    }
-                    yield json.dumps({"agent_action": tool_info, "message_id": message_id}, ensure_ascii=False)
-                elif data["status"] == Status.agent_finish:
-                    yield json.dumps({"agent_finish": data["agent_finish"], "message_id": message_id},
-                                     ensure_ascii=False)
-                else:
-                    yield json.dumps({"text": data["llm_token"], "message_id": message_id}, ensure_ascii=False)
-        else:
-            text = ""
-            agent_finish = ""
-            tool_info = None
-            async for chunk in callback.aiter():
-                data = json.loads(chunk)
-                if data["status"] == Status.agent_action:
-                    tool_info = {
-                        "tool_name": data["tool_name"],
-                        "tool_input": data["tool_input"]
-                    }
-                if data["status"] == Status.agent_finish:
-                    agent_finish = data["agent_finish"]
-                else:
-                    text += data["llm_token"]
-            if tool_info:
-                yield json.dumps(
-                    {
-                        "text": text,
-                        "agent_action": tool_info,
-                        "agent_finish": agent_finish,
-                        "message_id": message_id
-                    },
-                    ensure_ascii=False
-                )
-            else:
-                yield json.dumps(
-                    {
-                        "text": text,
-                        "message_id": message_id
-                    },
-                    ensure_ascii=False
-                )
+            wrap_done(full_chain.ainvoke({"input": query}), callback.done))
+
+        async for chunk in callback.aiter():
+            data = json.loads(chunk)
+            data["message_id"] = message_id
+            yield json.dumps(data, ensure_ascii=False)
+
         await task
 
     return EventSourceResponse(chat_iterator())
diff --git a/tests/test_qwen_agent.py b/tests/test_qwen_agent.py
index 057d76ac..7f0e1d0d 100644
--- a/tests/test_qwen_agent.py
+++ b/tests/test_qwen_agent.py
@@ -2,19 +2,31 @@ import sys
 from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent))
 
+import asyncio
 from server.utils import get_ChatOpenAI
 from server.agent.tools_factory.tools_registry import all_tools
 from server.agent.agent_factory.qwen_agent import initialize_qwen_agent
+from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
 from langchain import globals
 
-globals.set_debug(True)
-globals.set_verbose(True)
+# globals.set_debug(True)
+# globals.set_verbose(True)
 
 
-qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=False)
-executor = initialize_qwen_agent(tools=all_tools, llm=qwen_model)
+async def main():
+    callback = AgentExecutorAsyncIteratorCallbackHandler()
+    qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
+    executor = initialize_qwen_agent(tools=all_tools,
+                                     llm=qwen_model,
+                                     callbacks=[callback],
+                                     )
 
-# ret = executor.invoke("苏州今天冷吗")
-ret = executor.invoke("从知识库samples中查询chatchat项目简介")
-# ret = executor.invoke("chatchat项目主要issue有哪些")
-print(ret)
+    # ret = executor.invoke("苏州今天冷吗")
+    ret = asyncio.create_task(executor.ainvoke("苏州今天冷吗"))
+    async for chunk in callback.aiter():
+        print(chunk)
+    # ret = executor.invoke("从知识库samples中查询chatchat项目简介")
+    # ret = executor.invoke("chatchat项目主要issue有哪些")
+    print(ret)
+
+asyncio.run(main())
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 4e48efdc..9ed2df7f 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -11,6 +11,7 @@ import os
 import re
 import time
 from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG)
+from server.callback_handler.agent_callback_handler import AgentStatus
 import uuid
 from typing import List, Dict
 
@@ -271,10 +272,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                 f'',
                 unsafe_allow_html=True)
 
-            chat_box.ai_say("正在思考...")
+            chat_box.ai_say(["正在思考...",
+                             Markdown(title="tool call", in_expander=True, expanded=True,state="running"),
+                             Markdown()])
             text = ""
             message_id = ""
-            element_index = 0
+            tool_called = False
 
             for d in api.chat_chat(query=prompt,
                                    metadata=files_upload,
@@ -283,33 +286,33 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                                    conversation_id=conversation_id,
                                    tool_config=selected_tool_configs,
                                    ):
-                try:
-                    d = json.loads(d)
-                except:
-                    pass
                 message_id = d.get("message_id", "")
                 metadata = {
                     "message_id": message_id,
                 }
-                if error_msg := check_error_msg(d):
-                    st.error(error_msg)
-                if chunk := d.get("agent_action"):
-                    chat_box.insert_msg(Markdown("...", in_expander=True, title="Tools", state="complete"))
-                    element_index = 1
+                if not tool_called:  # 避免工具调用之后重复输出,将LLM输出分为工具调用前后分别处理
+                    element_index = 0
+                else:
+                    element_index = -1
+                if d["status"] == AgentStatus.error:
+                    st.error(d["text"])
+                elif d["status"] == AgentStatus.agent_action:
                     formatted_data = {
-                        "action": chunk["tool_name"],
-                        "action_input": chunk["tool_input"]
+                        "action": d["tool_name"],
+                        "action_input": d["tool_input"]
                     }
                     formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
-                    text += f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"
-                    chat_box.update_msg(text, element_index=element_index, metadata=metadata)
-                if chunk := d.get("text"):
-                    text += chunk
-                    chat_box.update_msg(text, element_index=element_index, metadata=metadata)
-                if chunk := d.get("agent_finish"):
-                    element_index = 0
-                    text = chunk
-                    chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
+                    chat_box.update_msg(Markdown(f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"), element_index=1)
+                    tool_called = True
+                    text = ""
+                elif d["status"] == AgentStatus.llm_new_token:
+                    text += d["text"]
+                    chat_box.update_msg(text, streaming=True, element_index=element_index, metadata=metadata)
+                elif d["status"] == AgentStatus.llm_end:
+                    chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
+                elif d["status"] == AgentStatus.agent_finish:
+                    chat_box.update_msg(element_index=1, state="complete", expanded=False)
+                    chat_box.update_msg(Markdown(d["text"]), streaming=False, element_index=-1)
 
             if os.path.exists("tmp/image.jpg"):
                 with open("tmp/image.jpg", "rb") as image_file: