fix: LLMChain no output when no tools selected

liunux4odoo 2023-12-23 23:09:08 +08:00
parent 36c90e2e2b
commit d144ff47c9
3 changed files with 26 additions and 7 deletions

View File

@@ -19,7 +19,7 @@ class AgentStatus:
     llm_end: int = 3
     agent_action: int = 4
     agent_finish: int = 5
-    tool_begin: int = 6
+    tool_start: int = 6
     tool_end: int = 7
     error: int = 8
@@ -59,10 +59,28 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         }
         self.queue.put_nowait(dumps(data))

+    async def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        data = {
+            "status": AgentStatus.llm_start,
+            "text": "",
+        }
+        self.done.clear()
+        self.queue.put_nowait(dumps(data))
+
     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         data = {
             "status": AgentStatus.llm_end,
-            "text": "",
+            "text": response.generations[0][0].message.content,
         }
         self.queue.put_nowait(dumps(data))
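
Moving self.done.clear() into the new on_chat_model_start (and commenting it out in the tool callbacks below) resets the done event exactly when a model run begins, so a plain LLMChain run with no tools no longer starts with an already-set event and an empty stream. A minimal consumer sketch, assuming the handler inherits aiter() from LangChain's AsyncIteratorCallbackHandler and that dumps serializes to JSON:

    import json

    async def consume(handler) -> None:
        # Drain the handler's queue; each chunk is one serialized status
        # dict such as {"status": AgentStatus.llm_end, "text": "..."}.
        async for chunk in handler.aiter():
            event = json.loads(chunk)
            if event["status"] == AgentStatus.llm_end:  # defined above
                print(event["text"])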
@@ -84,7 +102,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
-        print("tool_begin")
+        print("tool_start")

     async def on_tool_end(
         self,
@@ -100,7 +118,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
             "status": AgentStatus.tool_end,
             "tool_output": output,
         }
-        self.done.clear()
+        # self.done.clear()
         self.queue.put_nowait(dumps(data))

     async def on_tool_error(
@@ -117,7 +135,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
             "status": AgentStatus.tool_end,
             "text": error,
         }
-        self.done.clear()
+        # self.done.clear()
         self.queue.put_nowait(dumps(data))

     async def on_agent_action(
@@ -142,7 +160,6 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
-        if "Thought:" in finish.return_values["output"]:
-            finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "")
+        finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "")
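
With the guard gone, the "Thought:" prefix is stripped unconditionally; str.replace is a no-op when the substring is absent, so dropping the check does not change behavior:

    >>> "Thought: the answer is 4".replace("Thought:", "")
    ' the answer is 4'
    >>> "the answer is 4".replace("Thought:", "")
    'the answer is 4'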

View File

@@ -113,6 +113,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
         # full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
         full_chain = ({"input": lambda x: x["input"]} | agent_executor)
     else:
+        chain.llm.callbacks = callbacks
         full_chain = ({"input": lambda x: x["input"]} | chain)
     return full_chain
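
This added line is the fix named in the commit title: on the no-tools branch, the bare chain's LLM never had the streaming callbacks attached, so nothing was pushed onto the handler's queue. A minimal sketch of the pattern (the prompt and model construction are illustrative, not the project's exact code):

    from langchain.callbacks import AsyncIteratorCallbackHandler
    from langchain.chains import LLMChain
    from langchain.chat_models import ChatOpenAI
    from langchain.prompts import PromptTemplate

    callback = AsyncIteratorCallbackHandler()
    llm = ChatOpenAI(streaming=True)  # callbacks not attached here
    chain = LLMChain(llm=llm, prompt=PromptTemplate.from_template("{input}"))

    # Without this assignment the chain still runs, but no tokens ever
    # reach the handler, because only the LLM's callbacks feed its queue.
    chain.llm.callbacks = [callback]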

View File

@@ -15,8 +15,9 @@ from langchain import globals
 async def main():
     callback = AgentExecutorAsyncIteratorCallbackHandler()
+    tools = [t.copy(update={"callbacks": [callback]}) for t in all_tools]
     qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
-    executor = initialize_qwen_agent(tools=all_tools,
+    executor = initialize_qwen_agent(tools=tools,
                                      llm=qwen_model,
                                      callbacks=[callback],
                                      )
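
A sketch of how the test might drive the executor while draining the callback concurrently; executor and callback are the objects built above, and acall is the async entry point LangChain chains exposed at the time (ainvoke also works on recent versions):

    import asyncio

    async def run_and_stream(executor, callback) -> None:
        # Start the agent run and stream its status events at the same
        # time; the run feeds the callback's queue as it progresses.
        task = asyncio.create_task(executor.acall({"input": "What is 2 plus 2?"}))
        async for chunk in callback.aiter():
            print(chunk)  # one serialized status event per iteration
        await task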