fix: LLMChain no output when no tools selected

This commit is contained in:
liunux4odoo 2023-12-23 23:09:08 +08:00
parent 36c90e2e2b
commit d144ff47c9
3 changed files with 26 additions and 7 deletions

View File

@ -19,7 +19,7 @@ class AgentStatus:
llm_end: int = 3 llm_end: int = 3
agent_action: int = 4 agent_action: int = 4
agent_finish: int = 5 agent_finish: int = 5
tool_begin: int = 6 tool_start: int = 6
tool_end: int = 7 tool_end: int = 7
error: int = 8 error: int = 8
@ -59,10 +59,28 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
} }
self.queue.put_nowait(dumps(data)) self.queue.put_nowait(dumps(data))
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
data = {
"status": AgentStatus.llm_start,
"text": "",
}
self.done.clear()
self.queue.put_nowait(dumps(data))
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
data = { data = {
"status": AgentStatus.llm_end, "status": AgentStatus.llm_end,
"text": "", "text": response.generations[0][0].message.content,
} }
self.queue.put_nowait(dumps(data)) self.queue.put_nowait(dumps(data))
@ -84,7 +102,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
print("tool_begin") print("tool_start")
async def on_tool_end( async def on_tool_end(
self, self,
@ -100,7 +118,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
"status": AgentStatus.tool_end, "status": AgentStatus.tool_end,
"tool_output": output, "tool_output": output,
} }
self.done.clear() # self.done.clear()
self.queue.put_nowait(dumps(data)) self.queue.put_nowait(dumps(data))
async def on_tool_error( async def on_tool_error(
@ -117,7 +135,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
"status": AgentStatus.tool_end, "status": AgentStatus.tool_end,
"text": error, "text": error,
} }
self.done.clear() # self.done.clear()
self.queue.put_nowait(dumps(data)) self.queue.put_nowait(dumps(data))
async def on_agent_action( async def on_agent_action(
@ -142,7 +160,6 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
if "Thought:" in finish.return_values["output"]: if "Thought:" in finish.return_values["output"]:
finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "") finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "")

View File

@ -113,6 +113,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
# full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch) # full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
full_chain = ({"input": lambda x: x["input"]} | agent_executor) full_chain = ({"input": lambda x: x["input"]} | agent_executor)
else: else:
chain.llm.callbacks = callbacks
full_chain = ({"input": lambda x: x["input"]} | chain) full_chain = ({"input": lambda x: x["input"]} | chain)
return full_chain return full_chain

View File

@ -15,8 +15,9 @@ from langchain import globals
async def main(): async def main():
callback = AgentExecutorAsyncIteratorCallbackHandler() callback = AgentExecutorAsyncIteratorCallbackHandler()
tools = [t.copy(update={"callbacks": [callback]}) for t in all_tools]
qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback]) qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
executor = initialize_qwen_agent(tools=all_tools, executor = initialize_qwen_agent(tools=tools,
llm=qwen_model, llm=qwen_model,
callbacks=[callback], callbacks=[callback],
) )