Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-01-28 17:53:33 +08:00)

Commit e2a46a1d0f — "fix callback handler" (parent 6f04e15aed)
@@ -19,7 +19,7 @@ from langchain.output_parsers import OutputFixingParser
 from langchain.pydantic_v1 import Field
 from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate
 from langchain.agents.agent import AgentExecutor
-from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
 from langchain_core.callbacks import Callbacks
@@ -169,7 +169,7 @@ class StructuredGLM3ChatAgent(Agent):
             llm: BaseLanguageModel,
             tools: Sequence[BaseTool],
             prompt: str = None,
-            callback_manager: Optional[BaseCallbackManager] = None,
+            callbacks: List[BaseCallbackHandler] = [],
             output_parser: Optional[AgentOutputParser] = None,
             human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
             input_variables: Optional[List[str]] = None,
@@ -187,7 +187,7 @@ class StructuredGLM3ChatAgent(Agent):
         llm_chain = LLMChain(
             llm=llm,
             prompt=prompt,
-            callback_manager=callback_manager,
+            callbacks=callbacks,
             verbose=True
         )
         tool_names = [tool.name for tool in tools]
@@ -208,6 +208,7 @@ def initialize_glm3_agent(
         tools: Sequence[BaseTool],
         llm: BaseLanguageModel,
         prompt: str = None,
+        callbacks: List[BaseCallbackHandler] = [],
        memory: Optional[ConversationBufferWindowMemory] = None,
         agent_kwargs: Optional[dict] = None,
         *,
@@ -216,6 +217,7 @@ def initialize_glm3_agent(
 ) -> AgentExecutor:
     tags_ = list(tags) if tags else []
     agent_kwargs = agent_kwargs or {}
+    llm.callbacks=callbacks
     agent_obj = StructuredGLM3ChatAgent.from_llm_and_tools(
         llm=llm,
         tools=tools,
@@ -225,6 +227,7 @@ def initialize_glm3_agent(
     return AgentExecutor.from_agent_and_tools(
         agent=agent_obj,
         tools=tools,
+        callbacks=callbacks,
         memory=memory,
         tags=tags_,
         intermediate_steps=[],
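Taken together, the hunks above add a `callbacks` list to `initialize_glm3_agent` and push it onto the LLM, the inner `LLMChain`, and the returned `AgentExecutor`. A minimal usage sketch follows; it is not part of the commit, and the `glm3_agent` module path and the model name are assumptions (the repo's own Qwen test script further down follows the same pattern):

```python
# Usage sketch only: stream AgentStatus events from a GLM3 agent run.
# Assumptions: module path server.agent.agent_factory.glm3_agent and the
# model name "chatglm3-6b"; prompt is left at its default.
import asyncio

from server.utils import get_ChatOpenAI
from server.agent.tools_factory.tools_registry import all_tools
from server.agent.agent_factory.glm3_agent import initialize_glm3_agent
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler


async def run_glm3_agent(query: str) -> None:
    callback = AgentExecutorAsyncIteratorCallbackHandler()
    # streaming=True so on_llm_new_token fires for every generated token
    llm = get_ChatOpenAI("chatglm3-6b", 0.01, streaming=True, callbacks=[callback])
    executor = initialize_glm3_agent(tools=all_tools, llm=llm, callbacks=[callback])

    task = asyncio.create_task(executor.ainvoke({"input": query}))
    async for chunk in callback.aiter():  # JSON strings tagged with AgentStatus
        print(chunk)
    await task


asyncio.run(run_glm3_agent("苏州今天冷吗"))
```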
@@ -14,7 +14,7 @@ from langchain.pydantic_v1 import Field
 from langchain.schema import (AgentAction, AgentFinish, OutputParserException,
                               HumanMessage, SystemMessage, AIMessage)
 from langchain.agents.agent import AgentExecutor
-from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
 from langchain.tools.render import format_tool_to_openai_function
@@ -107,16 +107,16 @@ class QwenChatAgent(LLMSingleActionAgent):

     @classmethod
     def from_llm_and_tools(
-            cls,
-            llm: BaseLanguageModel,
-            tools: Sequence[BaseTool],
-            prompt: str = None,
-            callback_manager: Optional[BaseCallbackManager] = None,
-            output_parser: Optional[AgentOutputParser] = None,
-            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
-            input_variables: Optional[List[str]] = None,
-            memory_prompts: Optional[List[BaseChatPromptTemplate]] = None,
-            **kwargs: Any,
+            cls,
+            llm: BaseLanguageModel,
+            tools: Sequence[BaseTool],
+            prompt: str = None,
+            callbacks: List[BaseCallbackHandler] = [],
+            output_parser: Optional[AgentOutputParser] = None,
+            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
+            input_variables: Optional[List[str]] = None,
+            memory_prompts: Optional[List[BaseChatPromptTemplate]] = None,
+            **kwargs: Any,
     ) -> QwenChatAgent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
@@ -129,7 +129,7 @@ class QwenChatAgent(LLMSingleActionAgent):
         llm_chain = LLMChain(
             llm=llm,
             prompt=prompt,
-            callback_manager=callback_manager,
+            callbacks=callbacks,
         )
         tool_names = [tool.name for tool in tools]
         output_parser = output_parser or QwenChatAgentOutputParser()
@@ -147,34 +147,33 @@ class QwenChatAgent(LLMSingleActionAgent):


 def initialize_qwen_agent(
-        tools: Sequence[BaseTool],
-        llm: BaseLanguageModel,
-        prompt: str = None,
-        callback_manager: Optional[BaseCallbackManager] = None,
-        memory: Optional[ConversationBufferWindowMemory] = None,
-        agent_kwargs: Optional[dict] = None,
-        *,
-        return_direct: Optional[bool] = None,
-        tags: Optional[Sequence[str]] = None,
-        **kwargs: Any,
+        tools: Sequence[BaseTool],
+        llm: BaseLanguageModel,
+        prompt: str = None,
+        callbacks: List[BaseCallbackHandler] = [],
+        memory: Optional[ConversationBufferWindowMemory] = None,
+        agent_kwargs: Optional[dict] = None,
+        *,
+        return_direct: Optional[bool] = None,
+        tags: Optional[Sequence[str]] = None,
+        **kwargs: Any,
 ) -> AgentExecutor:
     tags_ = list(tags) if tags else []
     agent_kwargs = agent_kwargs or {}

     if isinstance(return_direct, bool):  # can make all tools return directly
         tools = [t.copy(update={"return_direct": return_direct}) for t in tools]

+    llm.callbacks=callbacks
     agent_obj = QwenChatAgent.from_llm_and_tools(
         llm=llm,
         tools=tools,
         prompt=prompt,
-        callback_manager=callback_manager,
-        **agent_kwargs
+        **agent_kwargs,
     )
     return AgentExecutor.from_agent_and_tools(
         agent=agent_obj,
         tools=tools,
-        callback_manager=callback_manager,
+        callbacks=callbacks,
         memory=memory,
         tags=tags_,
         intermediate_steps=[],
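The Qwen agent gets the same treatment: the `callback_manager` parameter is dropped in favour of a plain `callbacks` list handed to `LLMChain` and `AgentExecutor.from_agent_and_tools`. A standalone sketch of why constructor `callbacks` are sufficient follows; the handler and the `FakeListLLM` stand-in are illustrative and not code from this repo:

```python
# Sketch: callbacks passed at construction are attached to the chain and fire
# on every run, which is what this commit switches to. FakeListLLM replaces a
# real model so the snippet runs offline.
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import LLMChain
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate


class PrintingHandler(BaseCallbackHandler):
    def on_llm_start(self, serialized, prompts, **kwargs):
        print("llm_start:", prompts[0])

    def on_llm_end(self, response, **kwargs):
        print("llm_end:", response.generations[0][0].text)


llm = FakeListLLM(responses=["Final Answer: 42"])
chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate.from_template("{input}"),
    callbacks=[PrintingHandler()],  # the style this commit adopts
)
chain.invoke({"input": "what is the answer?"})
```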
@@ -13,83 +13,50 @@ def dumps(obj: Dict) -> str:
     return json.dumps(obj, ensure_ascii=False)


-class Status:
-    start: int = 1
-    running: int = 2
-    complete: int = 3
+class AgentStatus:
+    llm_start: int = 1
+    llm_new_token: int = 2
+    llm_end: int = 3
     agent_action: int = 4
     agent_finish: int = 5
     error: int = 6
     tool_finish: int = 7


-class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
+class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
     def __init__(self):
         super().__init__()
         self.queue = asyncio.Queue()
         self.done = asyncio.Event()
         self.cur_tool = {}
         self.out = True

-    async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
-                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
-                            metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
-        self.cur_tool = {
-            "tool_name": serialized["name"],
-            "input_str": input_str,
-            "output_str": "",
-            "status": Status.agent_action,
-            "run_id": run_id.hex,
-            "llm_token": "",
-            "final_answer": "",
-            "error": "",
+    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
+        data = {
+            "status" : AgentStatus.llm_start,
+            "text" : "",
         }
-        self.queue.put_nowait(dumps(self.cur_tool))
+        self.done.clear()
+        self.queue.put_nowait(dumps(data))

-    async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
-                          tags: List[str] | None = None, **kwargs: Any) -> None:
-
-        self.out = True
-        self.cur_tool.update(
-            status=Status.tool_finish,
-            output_str=output.replace("Answer:", ""),
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
-
-    async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
-                            parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
-        self.cur_tool.update(
-            status=Status.error,
-            error=str(error),
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
-
     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         special_tokens = ["Action", "<|observation|>"]
         for stoken in special_tokens:
             if stoken in token:
                 before_action = token.split(stoken)[0]
-                self.cur_tool.update(
-                    status=Status.running,
-                    llm_token=before_action + "\n",
-                )
-                self.queue.put_nowait(dumps(self.cur_tool))
+                data = {
+                    "status" : AgentStatus.llm_new_token,
+                    "text": before_action + "\n",
+                }
+                self.queue.put_nowait(dumps(data))
                 self.out = False
                 break

         if token is not None and token != "" and self.out:
-            self.cur_tool.update(
-                status=Status.running,
-                llm_token=token,
-            )
-            self.queue.put_nowait(dumps(self.cur_tool))
-
-    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
-        self.cur_tool.update(
-            status=Status.start,
-            llm_token="",
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+            data = {
+                "status" : AgentStatus.llm_new_token,
+                "text" : token,
+            }
+            self.queue.put_nowait(dumps(data))

     async def on_chat_model_start(
         self,
@@ -102,26 +69,27 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
-        self.cur_tool.update(
-            status=Status.start,
-            llm_token="",
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+        data = {
+            "status" : AgentStatus.llm_start,
+            "text" : "",
+        }
+        self.done.clear()
+        self.queue.put_nowait(dumps(data))

     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        self.cur_tool.update(
-            status=Status.complete,
-            llm_token="",
-        )
+        data = {
+            "status" : AgentStatus.llm_end,
+            "text" : response.generations[0][0].message.content,
+        }
         self.out = True
-        self.queue.put_nowait(dumps(self.cur_tool))
+        self.queue.put_nowait(dumps(data))

     async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
-        self.cur_tool.update(
-            status=Status.error,
-            error=str(error),
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+        data = {
+            "status" : AgentStatus.error,
+            "text" : str(error),
+        }
+        self.queue.put_nowait(dumps(data))

     async def on_agent_action(
         self,
@@ -132,12 +100,13 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
-        self.cur_tool.update(
-            status=Status.agent_action,
-            tool_name=action.tool,
-            tool_input=action.tool_input,
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+        data = {
+            "status" : AgentStatus.agent_action,
+            "tool_name" : action.tool,
+            "tool_input" : action.tool_input,
+            "text": action.log,
+        }
+        self.queue.put_nowait(dumps(data))

     async def on_agent_finish(
         self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
@@ -147,8 +116,9 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         if "Thought:" in finish.return_values["output"]:
             finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "")

-        self.cur_tool.update(
-            status=Status.agent_finish,
-            agent_finish=finish.return_values["output"],
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+        data = {
+            "status" : AgentStatus.agent_finish,
+            "text" : finish.return_values["output"],
+        }
+        self.done.set()
+        self.queue.put_nowait(dumps(data))
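The reworked handler drops the per-tool `cur_tool` dict in favour of flat JSON events tagged with `AgentStatus`, all pushed through the queue it inherits from `AsyncIteratorCallbackHandler` and terminated by the `done` event set in `on_agent_finish`. A toy sketch of just that queue-and-event streaming mechanism, with no LangChain dependency and status values chosen to mirror `AgentStatus`:

```python
# Standalone illustration of the producer/consumer pattern the handler uses:
# events go into an asyncio.Queue, and consumption stops once `done` is set
# and the queue has drained.
import asyncio
import json


class EventStream:
    def __init__(self):
        self.queue: asyncio.Queue = asyncio.Queue()
        self.done = asyncio.Event()

    def emit(self, status: int, **fields) -> None:
        self.queue.put_nowait(json.dumps({"status": status, **fields}, ensure_ascii=False))

    async def aiter(self):
        # Keep yielding while the producer is active or items remain queued.
        while not self.done.is_set() or not self.queue.empty():
            yield await self.queue.get()


async def main():
    stream = EventStream()

    async def produce():
        stream.emit(1, text="")             # llm_start
        stream.emit(2, text="Hello")        # llm_new_token
        stream.emit(5, text="Hello world")  # agent_finish
        stream.done.set()

    asyncio.create_task(produce())
    async for chunk in stream.aiter():
        print(json.loads(chunk))


asyncio.run(main())
```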
@@ -20,22 +20,23 @@ from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
 from server.chat.utils import History
 from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
 from server.db.repository import add_message_to_db
-from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
+from server.callback_handler.agent_callback_handler import AgentStatus, AgentExecutorAsyncIteratorCallbackHandler


-def create_models_from_config(configs, callbacks):
+def create_models_from_config(configs, callbacks, stream):
     if configs is None:
         configs = {}
     models = {}
     prompts = {}
     for model_type, model_configs in configs.items():
         for model_name, params in model_configs.items():
-            callback = callbacks if params.get('callbacks', False) else None
+            callbacks = callbacks if params.get('callbacks', False) else None
             model_instance = get_ChatOpenAI(
                 model_name=model_name,
                 temperature=params.get('temperature', 0.5),
                 max_tokens=params.get('max_tokens', 1000),
-                callbacks=callback
+                callbacks=callbacks,
+                streaming=stream,
             )
             models[model_type] = model_instance
             prompt_name = params.get('prompt_name', 'default')
@@ -83,7 +84,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
             tools=tools,
             prompt=prompts["action_model"],
             memory=memory,
-            # callback_manager=BaseCallbackManager(handlers=callbacks),
+            callbacks=callbacks,
             verbose=True,
         )
     elif "qwen" in models["action_model"].model_name.lower():
@@ -92,7 +93,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
             tools=tools,
             prompt=prompts["action_model"],
             memory=memory,
-            # callback_manager=BaseCallbackManager(handlers=callbacks),
+            callbacks=callbacks,
             verbose=True,
         )
     else:
@@ -131,7 +132,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                     ]
                 ]
             ),
-            stream: bool = Body(False, description="流式输出"),
+            stream: bool = Body(True, description="流式输出"),
             model_config: Dict = Body({}, description="LLM 模型配置"),
             tool_config: Dict = Body({}, description="工具配置"),
             ):
@@ -142,11 +143,11 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         conversation_id=conversation_id
     ) if conversation_id else None

-    callback = CustomAsyncIteratorCallbackHandler()
+    callback = AgentExecutorAsyncIteratorCallbackHandler()
     callbacks = [callback]

     # 从配置中选择模型
-    models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
+    models, prompts = create_models_from_config(callbacks=[], configs=model_config, stream=stream)

     # 从配置中选择工具
     tools = [tool for tool in all_tools if tool.name in tool_config]
@@ -164,56 +165,13 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         # Execute Chain

         task = asyncio.create_task(
-            wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
-        if stream:
-            async for chunk in callback.aiter():
-                data = json.loads(chunk)
-                if data["status"] == Status.start:
-                    continue
-                elif data["status"] == Status.agent_action:
-                    tool_info = {
-                        "tool_name": data["tool_name"],
-                        "tool_input": data["tool_input"]
-                    }
-                    yield json.dumps({"agent_action": tool_info, "message_id": message_id}, ensure_ascii=False)
-                elif data["status"] == Status.agent_finish:
-                    yield json.dumps({"agent_finish": data["agent_finish"], "message_id": message_id},
-                                     ensure_ascii=False)
-                else:
-                    yield json.dumps({"text": data["llm_token"], "message_id": message_id}, ensure_ascii=False)
-        else:
-            text = ""
-            agent_finish = ""
-            tool_info = None
-            async for chunk in callback.aiter():
-                data = json.loads(chunk)
-                if data["status"] == Status.agent_action:
-                    tool_info = {
-                        "tool_name": data["tool_name"],
-                        "tool_input": data["tool_input"]
-                    }
-                if data["status"] == Status.agent_finish:
-                    agent_finish = data["agent_finish"]
-                else:
-                    text += data["llm_token"]
-            if tool_info:
-                yield json.dumps(
-                    {
-                        "text": text,
-                        "agent_action": tool_info,
-                        "agent_finish": agent_finish,
-                        "message_id": message_id
-                    },
-                    ensure_ascii=False
-                )
-            else:
-                yield json.dumps(
-                    {
-                        "text": text,
-                        "message_id": message_id
-                    },
-                    ensure_ascii=False
-                )
+            wrap_done(full_chain.ainvoke({"input": query}), callback.done))
+
+        async for chunk in callback.aiter():
+            data = json.loads(chunk)
+            data["message_id"] = message_id
+            yield json.dumps(data, ensure_ascii=False)

         await task

     return EventSourceResponse(chat_iterator())
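With the handler now emitting self-describing events, the `/chat` endpoint no longer branches on `stream` or on individual status codes; it simply stamps each event with `message_id` and forwards it over SSE, leaving dispatch to the consumer. A hypothetical client-side dispatcher over those events (not part of the commit; field names and status values are copied from the handler above):

```python
# Sketch of a consumer for the forwarded events.
import json


class AgentStatus:  # mirrors the values defined in the handler
    llm_start = 1
    llm_new_token = 2
    llm_end = 3
    agent_action = 4
    agent_finish = 5
    error = 6
    tool_finish = 7


def handle_sse_event(raw: str) -> None:
    d = json.loads(raw)
    status = d["status"]
    if status == AgentStatus.llm_new_token:
        print(d["text"], end="", flush=True)          # stream tokens as they arrive
    elif status == AgentStatus.agent_action:
        print(f'\n[tool call] {d["tool_name"]}({d["tool_input"]})')
    elif status == AgentStatus.agent_finish:
        print(f'\n[final answer] {d["text"]}')
    elif status == AgentStatus.error:
        print(f'\n[error] {d["text"]}')


handle_sse_event('{"status": 2, "text": "Hello", "message_id": "demo"}')
```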
@@ -2,19 +2,31 @@ import sys
 from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent))

+import asyncio
 from server.utils import get_ChatOpenAI
 from server.agent.tools_factory.tools_registry import all_tools
 from server.agent.agent_factory.qwen_agent import initialize_qwen_agent
+from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
 from langchain import globals

-globals.set_debug(True)
-globals.set_verbose(True)
+# globals.set_debug(True)
+# globals.set_verbose(True)


-qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=False)
-executor = initialize_qwen_agent(tools=all_tools, llm=qwen_model)
+async def main():
+    callback = AgentExecutorAsyncIteratorCallbackHandler()
+    qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
+    executor = initialize_qwen_agent(tools=all_tools,
+                                     llm=qwen_model,
+                                     callbacks=[callback],
+                                     )

-# ret = executor.invoke("苏州今天冷吗")
-ret = executor.invoke("从知识库samples中查询chatchat项目简介")
-# ret = executor.invoke("chatchat项目主要issue有哪些")
-print(ret)
+    # ret = executor.invoke("苏州今天冷吗")
+    ret = asyncio.create_task(executor.ainvoke("苏州今天冷吗"))
+    async for chunk in callback.aiter():
+        print(chunk)
+    # ret = executor.invoke("从知识库samples中查询chatchat项目简介")
+    # ret = executor.invoke("chatchat项目主要issue有哪些")
+    print(ret)
+
+asyncio.run(main())
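Note that in the rewritten test, `ret` is the `asyncio.Task` returned by `create_task`, so the final `print(ret)` shows the task object rather than the agent's answer; awaiting it after the `aiter()` loop, as the `/chat` endpoint does with `await task`, would surface the final result and any exception. A possible tail for `main()` (an assumption, not part of the commit):

```python
    # Hypothetical tail for main(): resolve the task once the stream is drained,
    # mirroring the `await task` in the chat endpoint above.
    result = await ret
    print(result)
```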
@@ -11,6 +11,7 @@ import os
 import re
 import time
 from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG)
+from server.callback_handler.agent_callback_handler import AgentStatus
 import uuid
 from typing import List, Dict

@@ -271,10 +272,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                 f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
                 unsafe_allow_html=True)

-        chat_box.ai_say("正在思考...")
+        chat_box.ai_say(["正在思考...",
+                         Markdown(title="tool call", in_expander=True, expanded=True, state="running"),
+                         Markdown()])
         text = ""
         message_id = ""
+        element_index = 0
+        tool_called = False

         for d in api.chat_chat(query=prompt,
                                metadata=files_upload,
@@ -283,33 +286,33 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                                conversation_id=conversation_id,
                                tool_config=selected_tool_configs,
                                ):
             try:
                 d = json.loads(d)
             except:
                 pass
             message_id = d.get("message_id", "")
             metadata = {
                 "message_id": message_id,
             }
             if error_msg := check_error_msg(d):
                 st.error(error_msg)
-            if chunk := d.get("agent_action"):
-                chat_box.insert_msg(Markdown("...", in_expander=True, title="Tools", state="complete"))
-                element_index = 1
+            if not tool_called:  # 避免工具调用之后重复输出,将LLM输出分为工具调用前后分别处理
+                element_index = 0
+            else:
+                element_index = -1
+            if d["status"] == AgentStatus.error:
+                st.error(d["text"])
+            elif d["status"] == AgentStatus.agent_action:
                 formatted_data = {
-                    "action": chunk["tool_name"],
-                    "action_input": chunk["tool_input"]
+                    "action": d["tool_name"],
+                    "action_input": d["tool_input"]
                 }
                 formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
-                text += f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"
-                chat_box.update_msg(text, element_index=element_index, metadata=metadata)
-            if chunk := d.get("text"):
-                text += chunk
-                chat_box.update_msg(text, element_index=element_index, metadata=metadata)
-            if chunk := d.get("agent_finish"):
-                element_index = 0
-                text = chunk
-                chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
+                chat_box.update_msg(Markdown(f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"), element_index=1)
+                tool_called = True
+                text = ""
+            elif d["status"] == AgentStatus.llm_new_token:
+                text += d["text"]
+                chat_box.update_msg(text, streaming=True, element_index=element_index, metadata=metadata)
+            elif d["status"] == AgentStatus.llm_end:
+                chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
+            elif d["status"] == AgentStatus.agent_finish:
+                chat_box.update_msg(element_index=1, state="complete", expanded=False)
+                chat_box.update_msg(Markdown(d["text"]), streaming=False, element_index=-1)

         if os.path.exists("tmp/image.jpg"):
             with open("tmp/image.jpg", "rb") as image_file: