diff --git a/server/callback_handler/agent_callback_handler.py b/server/callback_handler/agent_callback_handler.py
index b477ad2e..4b7f0b97 100644
--- a/server/callback_handler/agent_callback_handler.py
+++ b/server/callback_handler/agent_callback_handler.py
@@ -19,7 +19,7 @@ class AgentStatus:
     llm_end: int = 3
     agent_action: int = 4
     agent_finish: int = 5
-    tool_begin: int = 6
+    tool_start: int = 6
     tool_end: int = 7
     error: int = 8
@@ -59,10 +59,28 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         }
         self.queue.put_nowait(dumps(data))

+    async def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        data = {
+            "status": AgentStatus.llm_start,
+            "text": "",
+        }
+        self.done.clear()
+        self.queue.put_nowait(dumps(data))
+
     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         data = {
             "status": AgentStatus.llm_end,
-            "text": "",
+            "text": response.generations[0][0].message.content,
         }
         self.queue.put_nowait(dumps(data))
@@ -84,7 +102,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
-        print("tool_begin")
+        print("tool_start")

     async def on_tool_end(
         self,
@@ -100,7 +118,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
             "status": AgentStatus.tool_end,
             "tool_output": output,
         }
-        self.done.clear()
+        # self.done.clear()
         self.queue.put_nowait(dumps(data))

     async def on_tool_error(
         self,
@@ -117,7 +135,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
             "status": AgentStatus.tool_end,
             "text": error,
         }
-        self.done.clear()
+        # self.done.clear()
         self.queue.put_nowait(dumps(data))

     async def on_agent_action(
@@ -142,7 +160,6 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
-
         if "Thought:" in finish.return_values["output"]:
             finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "")
diff --git a/server/chat/chat.py b/server/chat/chat.py
index b8e68be8..e2b6ae44 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -113,6 +113,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
         # full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
         full_chain = ({"input": lambda x: x["input"]} | agent_executor)
     else:
+        chain.llm.callbacks = callbacks
         full_chain = ({"input": lambda x: x["input"]} | chain)
     return full_chain
diff --git a/tests/test_qwen_agent.py b/tests/test_qwen_agent.py
index 7f0e1d0d..a5a3dc94 100644
--- a/tests/test_qwen_agent.py
+++ b/tests/test_qwen_agent.py
@@ -15,8 +15,9 @@ from langchain import globals

 async def main():
     callback = AgentExecutorAsyncIteratorCallbackHandler()
+    tools = [t.copy(update={"callbacks": [callback]}) for t in all_tools]
     qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
-    executor = initialize_qwen_agent(tools=all_tools,
+    executor = initialize_qwen_agent(tools=tools,
                                      llm=qwen_model,
                                      callbacks=[callback],
                                      )
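
Reviewer note: `done` is the `asyncio.Event` that LangChain's `AsyncIteratorCallbackHandler.aiter()` polls to decide when the stream is over, so clearing it once per model turn in `on_chat_model_start` (instead of re-clearing it in the tool callbacks, now commented out) keeps the iterator from being reopened after the run has finished. Below is a minimal consumer sketch, assuming the handler keeps the inherited `aiter()` and that `dumps` emits JSON (both implied by, but not shown in, this diff); `make_run` is a hypothetical coroutine that kicks off the agent with the callback attached, the way `chat.py` and the test do:

```python
import asyncio
import json

# Import path matches the file touched by this diff.
from server.callback_handler.agent_callback_handler import (
    AgentExecutorAsyncIteratorCallbackHandler,
    AgentStatus,
)

async def stream_events(make_run) -> None:
    callback = AgentExecutorAsyncIteratorCallbackHandler()
    task = asyncio.create_task(make_run(callback))
    # aiter() yields queued events until the handler sets `done` at end of run.
    async for chunk in callback.aiter():
        event = json.loads(chunk)
        if event["status"] == AgentStatus.llm_start:
            print("model turn started")
        elif event["status"] == AgentStatus.llm_end:
            print("model text:", event["text"])
        elif event["status"] == AgentStatus.tool_end:
            # Per this diff, tool errors also arrive as tool_end with "text".
            print("tool output:", event.get("tool_output", event.get("text")))
    await task
```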
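
Reviewer note on the test change: LangChain tools are pydantic models, so `t.copy(update={"callbacks": [callback]})` creates a fresh per-request tool with its `callbacks` field overridden rather than mutating the shared `all_tools` list; the `chain.llm.callbacks = callbacks` line in `chat.py` does the analogous wiring for the plain (non-agent) chain path. A standalone sketch of the same pattern, using a hypothetical `echo` tool that is not part of this repo and assuming the pydantic v1-style `.copy(update=...)` the test itself relies on:

```python
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.tools import Tool

def _echo(text: str) -> str:
    return text

tool = Tool.from_function(func=_echo, name="echo", description="Echo the input back.")
callback = AsyncIteratorCallbackHandler()

# copy(update=...) returns a new model instance; the original stays untouched,
# so concurrent requests never see each other's callbacks.
wired = tool.copy(update={"callbacks": [callback]})
assert tool.callbacks is None
assert wired.callbacks == [callback]
```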