fix callback handler

This commit is contained in:
liunux4odoo 2023-12-20 22:42:40 +08:00
parent 6f04e15aed
commit e2a46a1d0f
6 changed files with 143 additions and 198 deletions

View File

@ -19,7 +19,7 @@ from langchain.output_parsers import OutputFixingParser
from langchain.pydantic_v1 import Field from langchain.pydantic_v1 import Field
from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate
from langchain.agents.agent import AgentExecutor from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain_core.callbacks import Callbacks from langchain_core.callbacks import Callbacks
@ -169,7 +169,7 @@ class StructuredGLM3ChatAgent(Agent):
llm: BaseLanguageModel, llm: BaseLanguageModel,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
prompt: str = None, prompt: str = None,
callback_manager: Optional[BaseCallbackManager] = None, callbacks: List[BaseCallbackHandler] = [],
output_parser: Optional[AgentOutputParser] = None, output_parser: Optional[AgentOutputParser] = None,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE, human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
input_variables: Optional[List[str]] = None, input_variables: Optional[List[str]] = None,
@ -187,7 +187,7 @@ class StructuredGLM3ChatAgent(Agent):
llm_chain = LLMChain( llm_chain = LLMChain(
llm=llm, llm=llm,
prompt=prompt, prompt=prompt,
callback_manager=callback_manager, callbacks=callbacks,
verbose=True verbose=True
) )
tool_names = [tool.name for tool in tools] tool_names = [tool.name for tool in tools]
@ -208,6 +208,7 @@ def initialize_glm3_agent(
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
llm: BaseLanguageModel, llm: BaseLanguageModel,
prompt: str = None, prompt: str = None,
callbacks: List[BaseCallbackHandler] = [],
memory: Optional[ConversationBufferWindowMemory] = None, memory: Optional[ConversationBufferWindowMemory] = None,
agent_kwargs: Optional[dict] = None, agent_kwargs: Optional[dict] = None,
*, *,
@ -216,6 +217,7 @@ def initialize_glm3_agent(
) -> AgentExecutor: ) -> AgentExecutor:
tags_ = list(tags) if tags else [] tags_ = list(tags) if tags else []
agent_kwargs = agent_kwargs or {} agent_kwargs = agent_kwargs or {}
llm.callbacks=callbacks
agent_obj = StructuredGLM3ChatAgent.from_llm_and_tools( agent_obj = StructuredGLM3ChatAgent.from_llm_and_tools(
llm=llm, llm=llm,
tools=tools, tools=tools,
@ -225,6 +227,7 @@ def initialize_glm3_agent(
return AgentExecutor.from_agent_and_tools( return AgentExecutor.from_agent_and_tools(
agent=agent_obj, agent=agent_obj,
tools=tools, tools=tools,
callbacks=callbacks,
memory=memory, memory=memory,
tags=tags_, tags=tags_,
intermediate_steps=[], intermediate_steps=[],

View File

@ -14,7 +14,7 @@ from langchain.pydantic_v1 import Field
from langchain.schema import (AgentAction, AgentFinish, OutputParserException, from langchain.schema import (AgentAction, AgentFinish, OutputParserException,
HumanMessage, SystemMessage, AIMessage) HumanMessage, SystemMessage, AIMessage)
from langchain.agents.agent import AgentExecutor from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain.tools.render import format_tool_to_openai_function from langchain.tools.render import format_tool_to_openai_function
@ -107,16 +107,16 @@ class QwenChatAgent(LLMSingleActionAgent):
@classmethod @classmethod
def from_llm_and_tools( def from_llm_and_tools(
cls, cls,
llm: BaseLanguageModel, llm: BaseLanguageModel,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
prompt: str = None, prompt: str = None,
callback_manager: Optional[BaseCallbackManager] = None, callbacks: List[BaseCallbackHandler] = [],
output_parser: Optional[AgentOutputParser] = None, output_parser: Optional[AgentOutputParser] = None,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE, human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
input_variables: Optional[List[str]] = None, input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BaseChatPromptTemplate]] = None, memory_prompts: Optional[List[BaseChatPromptTemplate]] = None,
**kwargs: Any, **kwargs: Any,
) -> QwenChatAgent: ) -> QwenChatAgent:
"""Construct an agent from an LLM and tools.""" """Construct an agent from an LLM and tools."""
cls._validate_tools(tools) cls._validate_tools(tools)
@ -129,7 +129,7 @@ class QwenChatAgent(LLMSingleActionAgent):
llm_chain = LLMChain( llm_chain = LLMChain(
llm=llm, llm=llm,
prompt=prompt, prompt=prompt,
callback_manager=callback_manager, callbacks=callbacks,
) )
tool_names = [tool.name for tool in tools] tool_names = [tool.name for tool in tools]
output_parser = output_parser or QwenChatAgentOutputParser() output_parser = output_parser or QwenChatAgentOutputParser()
@ -147,34 +147,33 @@ class QwenChatAgent(LLMSingleActionAgent):
def initialize_qwen_agent( def initialize_qwen_agent(
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
llm: BaseLanguageModel, llm: BaseLanguageModel,
prompt: str = None, prompt: str = None,
callback_manager: Optional[BaseCallbackManager] = None, callbacks: List[BaseCallbackHandler] = [],
memory: Optional[ConversationBufferWindowMemory] = None, memory: Optional[ConversationBufferWindowMemory] = None,
agent_kwargs: Optional[dict] = None, agent_kwargs: Optional[dict] = None,
*, *,
return_direct: Optional[bool] = None, return_direct: Optional[bool] = None,
tags: Optional[Sequence[str]] = None, tags: Optional[Sequence[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> AgentExecutor: ) -> AgentExecutor:
tags_ = list(tags) if tags else [] tags_ = list(tags) if tags else []
agent_kwargs = agent_kwargs or {} agent_kwargs = agent_kwargs or {}
if isinstance(return_direct, bool): # can make all tools return directly if isinstance(return_direct, bool): # can make all tools return directly
tools = [t.copy(update={"return_direct": return_direct}) for t in tools] tools = [t.copy(update={"return_direct": return_direct}) for t in tools]
llm.callbacks=callbacks
agent_obj = QwenChatAgent.from_llm_and_tools( agent_obj = QwenChatAgent.from_llm_and_tools(
llm=llm, llm=llm,
tools=tools, tools=tools,
prompt=prompt, prompt=prompt,
callback_manager=callback_manager, **agent_kwargs,
**agent_kwargs
) )
return AgentExecutor.from_agent_and_tools( return AgentExecutor.from_agent_and_tools(
agent=agent_obj, agent=agent_obj,
tools=tools, tools=tools,
callback_manager=callback_manager, callbacks=callbacks,
memory=memory, memory=memory,
tags=tags_, tags=tags_,
intermediate_steps=[], intermediate_steps=[],

View File

@ -13,83 +13,50 @@ def dumps(obj: Dict) -> str:
return json.dumps(obj, ensure_ascii=False) return json.dumps(obj, ensure_ascii=False)
class Status: class AgentStatus:
start: int = 1 llm_start: int = 1
running: int = 2 llm_new_token: int = 2
complete: int = 3 llm_end: int = 3
agent_action: int = 4 agent_action: int = 4
agent_finish: int = 5 agent_finish: int = 5
error: int = 6 error: int = 6
tool_finish: int = 7
class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
self.done = asyncio.Event() self.done = asyncio.Event()
self.cur_tool = {}
self.out = True self.out = True
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID, async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
parent_run_id: UUID | None = None, tags: List[str] | None = None, data = {
metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None: "status" : AgentStatus.llm_start,
self.cur_tool = { "text" : "",
"tool_name": serialized["name"],
"input_str": input_str,
"output_str": "",
"status": Status.agent_action,
"run_id": run_id.hex,
"llm_token": "",
"final_answer": "",
"error": "",
} }
self.queue.put_nowait(dumps(self.cur_tool)) self.done.clear()
self.queue.put_nowait(dumps(data))
async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
tags: List[str] | None = None, **kwargs: Any) -> None:
self.out = True
self.cur_tool.update(
status=Status.tool_finish,
output_str=output.replace("Answer:", ""),
)
self.queue.put_nowait(dumps(self.cur_tool))
async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
self.cur_tool.update(
status=Status.error,
error=str(error),
)
self.queue.put_nowait(dumps(self.cur_tool))
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
special_tokens = ["Action", "<|observation|>"] special_tokens = ["Action", "<|observation|>"]
for stoken in special_tokens: for stoken in special_tokens:
if stoken in token: if stoken in token:
before_action = token.split(stoken)[0] before_action = token.split(stoken)[0]
self.cur_tool.update( data = {
status=Status.running, "status" : AgentStatus.llm_new_token,
llm_token=before_action + "\n", "text": before_action + "\n",
) }
self.queue.put_nowait(dumps(self.cur_tool)) self.queue.put_nowait(dumps(data))
self.out = False self.out = False
break break
if token is not None and token != "" and self.out: if token is not None and token != "" and self.out:
self.cur_tool.update( data = {
status=Status.running, "status" : AgentStatus.llm_new_token,
llm_token=token, "text" : token,
) }
self.queue.put_nowait(dumps(self.cur_tool)) self.queue.put_nowait(dumps(data))
async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
self.cur_tool.update(
status=Status.start,
llm_token="",
)
self.queue.put_nowait(dumps(self.cur_tool))
async def on_chat_model_start( async def on_chat_model_start(
self, self,
@ -102,26 +69,27 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
self.cur_tool.update( data = {
status=Status.start, "status" : AgentStatus.llm_start,
llm_token="", "text" : "",
) }
self.queue.put_nowait(dumps(self.cur_tool)) self.done.clear()
self.queue.put_nowait(dumps(data))
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
self.cur_tool.update( data = {
status=Status.complete, "status" : AgentStatus.llm_end,
llm_token="", "text" : response.generations[0][0].message.content,
) }
self.out = True self.out = True
self.queue.put_nowait(dumps(self.cur_tool)) self.queue.put_nowait(dumps(data))
async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None: async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
self.cur_tool.update( data = {
status=Status.error, "status" : AgentStatus.error,
error=str(error), "text" : str(error),
) }
self.queue.put_nowait(dumps(self.cur_tool)) self.queue.put_nowait(dumps(data))
async def on_agent_action( async def on_agent_action(
self, self,
@ -132,12 +100,13 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
self.cur_tool.update( data = {
status=Status.agent_action, "status" : AgentStatus.agent_action,
tool_name=action.tool, "tool_name" : action.tool,
tool_input=action.tool_input, "tool_input" : action.tool_input,
) "text": action.log,
self.queue.put_nowait(dumps(self.cur_tool)) }
self.queue.put_nowait(dumps(data))
async def on_agent_finish( async def on_agent_finish(
self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None, self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
@ -147,8 +116,9 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
if "Thought:" in finish.return_values["output"]: if "Thought:" in finish.return_values["output"]:
finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "") finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "")
self.cur_tool.update( data = {
status=Status.agent_finish, "status" : AgentStatus.agent_finish,
agent_finish=finish.return_values["output"], "text" : finish.return_values["output"],
) }
self.queue.put_nowait(dumps(self.cur_tool)) self.done.set()
self.queue.put_nowait(dumps(data))

View File

@ -20,22 +20,23 @@ from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from server.chat.utils import History from server.chat.utils import History
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db from server.db.repository import add_message_to_db
from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler from server.callback_handler.agent_callback_handler import AgentStatus, AgentExecutorAsyncIteratorCallbackHandler
def create_models_from_config(configs, callbacks): def create_models_from_config(configs, callbacks, stream):
if configs is None: if configs is None:
configs = {} configs = {}
models = {} models = {}
prompts = {} prompts = {}
for model_type, model_configs in configs.items(): for model_type, model_configs in configs.items():
for model_name, params in model_configs.items(): for model_name, params in model_configs.items():
callback = callbacks if params.get('callbacks', False) else None callbacks = callbacks if params.get('callbacks', False) else None
model_instance = get_ChatOpenAI( model_instance = get_ChatOpenAI(
model_name=model_name, model_name=model_name,
temperature=params.get('temperature', 0.5), temperature=params.get('temperature', 0.5),
max_tokens=params.get('max_tokens', 1000), max_tokens=params.get('max_tokens', 1000),
callbacks=callback callbacks=callbacks,
streaming=stream,
) )
models[model_type] = model_instance models[model_type] = model_instance
prompt_name = params.get('prompt_name', 'default') prompt_name = params.get('prompt_name', 'default')
@ -83,7 +84,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
tools=tools, tools=tools,
prompt=prompts["action_model"], prompt=prompts["action_model"],
memory=memory, memory=memory,
# callback_manager=BaseCallbackManager(handlers=callbacks), callbacks=callbacks,
verbose=True, verbose=True,
) )
elif "qwen" in models["action_model"].model_name.lower(): elif "qwen" in models["action_model"].model_name.lower():
@ -92,7 +93,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
tools=tools, tools=tools,
prompt=prompts["action_model"], prompt=prompts["action_model"],
memory=memory, memory=memory,
# callback_manager=BaseCallbackManager(handlers=callbacks), callbacks=callbacks,
verbose=True, verbose=True,
) )
else: else:
@ -131,7 +132,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
] ]
] ]
), ),
stream: bool = Body(False, description="流式输出"), stream: bool = Body(True, description="流式输出"),
model_config: Dict = Body({}, description="LLM 模型配置"), model_config: Dict = Body({}, description="LLM 模型配置"),
tool_config: Dict = Body({}, description="工具配置"), tool_config: Dict = Body({}, description="工具配置"),
): ):
@ -142,11 +143,11 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
conversation_id=conversation_id conversation_id=conversation_id
) if conversation_id else None ) if conversation_id else None
callback = CustomAsyncIteratorCallbackHandler() callback = AgentExecutorAsyncIteratorCallbackHandler()
callbacks = [callback] callbacks = [callback]
# 从配置中选择模型 # 从配置中选择模型
models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config) models, prompts = create_models_from_config(callbacks=[], configs=model_config, stream=stream)
# 从配置中选择工具 # 从配置中选择工具
tools = [tool for tool in all_tools if tool.name in tool_config] tools = [tool for tool in all_tools if tool.name in tool_config]
@ -164,56 +165,13 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
# Execute Chain # Execute Chain
task = asyncio.create_task( task = asyncio.create_task(
wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done)) wrap_done(full_chain.ainvoke({"input": query}), callback.done))
if stream:
async for chunk in callback.aiter(): async for chunk in callback.aiter():
data = json.loads(chunk) data = json.loads(chunk)
if data["status"] == Status.start: data["message_id"] = message_id
continue yield json.dumps(data, ensure_ascii=False)
elif data["status"] == Status.agent_action:
tool_info = {
"tool_name": data["tool_name"],
"tool_input": data["tool_input"]
}
yield json.dumps({"agent_action": tool_info, "message_id": message_id}, ensure_ascii=False)
elif data["status"] == Status.agent_finish:
yield json.dumps({"agent_finish": data["agent_finish"], "message_id": message_id},
ensure_ascii=False)
else:
yield json.dumps({"text": data["llm_token"], "message_id": message_id}, ensure_ascii=False)
else:
text = ""
agent_finish = ""
tool_info = None
async for chunk in callback.aiter():
data = json.loads(chunk)
if data["status"] == Status.agent_action:
tool_info = {
"tool_name": data["tool_name"],
"tool_input": data["tool_input"]
}
if data["status"] == Status.agent_finish:
agent_finish = data["agent_finish"]
else:
text += data["llm_token"]
if tool_info:
yield json.dumps(
{
"text": text,
"agent_action": tool_info,
"agent_finish": agent_finish,
"message_id": message_id
},
ensure_ascii=False
)
else:
yield json.dumps(
{
"text": text,
"message_id": message_id
},
ensure_ascii=False
)
await task await task
return EventSourceResponse(chat_iterator()) return EventSourceResponse(chat_iterator())

View File

@ -2,19 +2,31 @@ import sys
from pathlib import Path from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent)) sys.path.append(str(Path(__file__).parent.parent))
import asyncio
from server.utils import get_ChatOpenAI from server.utils import get_ChatOpenAI
from server.agent.tools_factory.tools_registry import all_tools from server.agent.tools_factory.tools_registry import all_tools
from server.agent.agent_factory.qwen_agent import initialize_qwen_agent from server.agent.agent_factory.qwen_agent import initialize_qwen_agent
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
from langchain import globals from langchain import globals
globals.set_debug(True) # globals.set_debug(True)
globals.set_verbose(True) # globals.set_verbose(True)
qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=False) async def main():
executor = initialize_qwen_agent(tools=all_tools, llm=qwen_model) callback = AgentExecutorAsyncIteratorCallbackHandler()
qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
executor = initialize_qwen_agent(tools=all_tools,
llm=qwen_model,
callbacks=[callback],
)
# ret = executor.invoke("苏州今天冷吗") # ret = executor.invoke("苏州今天冷吗")
ret = executor.invoke("从知识库samples中查询chatchat项目简介") ret = asyncio.create_task(executor.ainvoke("苏州今天冷吗"))
# ret = executor.invoke("chatchat项目主要issue有哪些") async for chunk in callback.aiter():
print(ret) print(chunk)
# ret = executor.invoke("从知识库samples中查询chatchat项目简介")
# ret = executor.invoke("chatchat项目主要issue有哪些")
print(ret)
asyncio.run(main())

View File

@ -11,6 +11,7 @@ import os
import re import re
import time import time
from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG) from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG)
from server.callback_handler.agent_callback_handler import AgentStatus
import uuid import uuid
from typing import List, Dict from typing import List, Dict
@ -271,10 +272,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>', f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
unsafe_allow_html=True) unsafe_allow_html=True)
chat_box.ai_say("正在思考...") chat_box.ai_say(["正在思考...",
Markdown(title="tool call", in_expander=True, expanded=True,state="running"),
Markdown()])
text = "" text = ""
message_id = "" message_id = ""
element_index = 0 tool_called = False
for d in api.chat_chat(query=prompt, for d in api.chat_chat(query=prompt,
metadata=files_upload, metadata=files_upload,
@ -283,33 +286,33 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
conversation_id=conversation_id, conversation_id=conversation_id,
tool_config=selected_tool_configs, tool_config=selected_tool_configs,
): ):
try:
d = json.loads(d)
except:
pass
message_id = d.get("message_id", "") message_id = d.get("message_id", "")
metadata = { metadata = {
"message_id": message_id, "message_id": message_id,
} }
if error_msg := check_error_msg(d): if not tool_called: # 避免工具调用之后重复输出将LLM输出分为工具调用前后分别处理
st.error(error_msg) element_index = 0
if chunk := d.get("agent_action"): else:
chat_box.insert_msg(Markdown("...", in_expander=True, title="Tools", state="complete")) element_index = -1
element_index = 1 if d["status"] == AgentStatus.error:
st.error(d["text"])
elif d["status"] == AgentStatus.agent_action:
formatted_data = { formatted_data = {
"action": chunk["tool_name"], "action": d["tool_name"],
"action_input": chunk["tool_input"] "action_input": d["tool_input"]
} }
formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False) formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
text += f"\n```\nInput Params:\n" + formatted_json + f"\n```\n" chat_box.update_msg(Markdown(f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"), element_index=1)
chat_box.update_msg(text, element_index=element_index, metadata=metadata) tool_called = True
if chunk := d.get("text"): text = ""
text += chunk elif d["status"] == AgentStatus.llm_new_token:
chat_box.update_msg(text, element_index=element_index, metadata=metadata) text += d["text"]
if chunk := d.get("agent_finish"): chat_box.update_msg(text, streaming=True, element_index=element_index, metadata=metadata)
element_index = 0 elif d["status"] == AgentStatus.llm_end:
text = chunk chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata) elif d["status"] == AgentStatus.agent_finish:
chat_box.update_msg(element_index=1, state="complete", expanded=False)
chat_box.update_msg(Markdown(d["text"]), streaming=False, element_index=-1)
if os.path.exists("tmp/image.jpg"): if os.path.exists("tmp/image.jpg"):
with open("tmp/image.jpg", "rb") as image_file: with open("tmp/image.jpg", "rb") as image_file: