fix callback handler

liunux4odoo 2023-12-20 22:42:40 +08:00
parent 6f04e15aed
commit e2a46a1d0f
6 changed files with 143 additions and 198 deletions

View File

@@ -19,7 +19,7 @@ from langchain.output_parsers import OutputFixingParser
 from langchain.pydantic_v1 import Field
 from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate
 from langchain.agents.agent import AgentExecutor
-from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
 from langchain_core.callbacks import Callbacks
@@ -169,7 +169,7 @@ class StructuredGLM3ChatAgent(Agent):
             llm: BaseLanguageModel,
             tools: Sequence[BaseTool],
             prompt: str = None,
-            callback_manager: Optional[BaseCallbackManager] = None,
+            callbacks: List[BaseCallbackHandler] = [],
             output_parser: Optional[AgentOutputParser] = None,
             human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
             input_variables: Optional[List[str]] = None,
@@ -187,7 +187,7 @@ class StructuredGLM3ChatAgent(Agent):
         llm_chain = LLMChain(
             llm=llm,
             prompt=prompt,
-            callback_manager=callback_manager,
+            callbacks=callbacks,
             verbose=True
         )
         tool_names = [tool.name for tool in tools]
@@ -208,6 +208,7 @@ def initialize_glm3_agent(
         tools: Sequence[BaseTool],
         llm: BaseLanguageModel,
         prompt: str = None,
+        callbacks: List[BaseCallbackHandler] = [],
         memory: Optional[ConversationBufferWindowMemory] = None,
         agent_kwargs: Optional[dict] = None,
         *,
@@ -216,6 +217,7 @@ def initialize_glm3_agent(
 ) -> AgentExecutor:
     tags_ = list(tags) if tags else []
     agent_kwargs = agent_kwargs or {}
+    llm.callbacks=callbacks
     agent_obj = StructuredGLM3ChatAgent.from_llm_and_tools(
         llm=llm,
         tools=tools,
@@ -225,6 +227,7 @@ def initialize_glm3_agent(
     return AgentExecutor.from_agent_and_tools(
         agent=agent_obj,
         tools=tools,
+        callbacks=callbacks,
         memory=memory,
         tags=tags_,
         intermediate_steps=[],
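
Both agent factories now accept a plain list of handler instances instead of a BaseCallbackManager; LangChain forwards such a list anywhere it takes a `callbacks` argument. A minimal sketch of the new wiring (the PrintTokensHandler below is illustrative, not part of the commit):

    from langchain.callbacks.base import BaseCallbackHandler

    class PrintTokensHandler(BaseCallbackHandler):
        # illustrative handler: print each streamed token as it arrives
        def on_llm_new_token(self, token: str, **kwargs) -> None:
            print(token, end="", flush=True)

    # executor = initialize_glm3_agent(
    #     tools=tools,                       # your BaseTool instances
    #     llm=llm,                           # any BaseLanguageModel
    #     callbacks=[PrintTokensHandler()],  # replaces callback_manager=...
    # )

The same list is also attached to the model itself (`llm.callbacks=callbacks`), so token-level events reach the handler even when the chain is invoked without per-call callbacks.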

View File

@@ -14,7 +14,7 @@ from langchain.pydantic_v1 import Field
 from langchain.schema import (AgentAction, AgentFinish, OutputParserException,
                               HumanMessage, SystemMessage, AIMessage)
 from langchain.agents.agent import AgentExecutor
-from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
 from langchain.tools.render import format_tool_to_openai_function
@@ -107,16 +107,16 @@ class QwenChatAgent(LLMSingleActionAgent):
     @classmethod
     def from_llm_and_tools(
-        cls,
-        llm: BaseLanguageModel,
-        tools: Sequence[BaseTool],
-        prompt: str = None,
-        callback_manager: Optional[BaseCallbackManager] = None,
-        output_parser: Optional[AgentOutputParser] = None,
-        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
-        input_variables: Optional[List[str]] = None,
-        memory_prompts: Optional[List[BaseChatPromptTemplate]] = None,
-        **kwargs: Any,
+            cls,
+            llm: BaseLanguageModel,
+            tools: Sequence[BaseTool],
+            prompt: str = None,
+            callbacks: List[BaseCallbackHandler] = [],
+            output_parser: Optional[AgentOutputParser] = None,
+            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
+            input_variables: Optional[List[str]] = None,
+            memory_prompts: Optional[List[BaseChatPromptTemplate]] = None,
+            **kwargs: Any,
     ) -> QwenChatAgent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
@@ -129,7 +129,7 @@ class QwenChatAgent(LLMSingleActionAgent):
         llm_chain = LLMChain(
             llm=llm,
             prompt=prompt,
-            callback_manager=callback_manager,
+            callbacks=callbacks,
         )
         tool_names = [tool.name for tool in tools]
         output_parser = output_parser or QwenChatAgentOutputParser()
@@ -147,34 +147,33 @@
 def initialize_qwen_agent(
-    tools: Sequence[BaseTool],
-    llm: BaseLanguageModel,
-    prompt: str = None,
-    callback_manager: Optional[BaseCallbackManager] = None,
-    memory: Optional[ConversationBufferWindowMemory] = None,
-    agent_kwargs: Optional[dict] = None,
-    *,
-    return_direct: Optional[bool] = None,
-    tags: Optional[Sequence[str]] = None,
-    **kwargs: Any,
+        tools: Sequence[BaseTool],
+        llm: BaseLanguageModel,
+        prompt: str = None,
+        callbacks: List[BaseCallbackHandler] = [],
+        memory: Optional[ConversationBufferWindowMemory] = None,
+        agent_kwargs: Optional[dict] = None,
+        *,
+        return_direct: Optional[bool] = None,
+        tags: Optional[Sequence[str]] = None,
+        **kwargs: Any,
 ) -> AgentExecutor:
     tags_ = list(tags) if tags else []
     agent_kwargs = agent_kwargs or {}
     if isinstance(return_direct, bool):  # can make all tools return directly
         tools = [t.copy(update={"return_direct": return_direct}) for t in tools]
+    llm.callbacks=callbacks
     agent_obj = QwenChatAgent.from_llm_and_tools(
         llm=llm,
         tools=tools,
         prompt=prompt,
-        callback_manager=callback_manager,
-        **agent_kwargs
+        **agent_kwargs,
     )
     return AgentExecutor.from_agent_and_tools(
         agent=agent_obj,
         tools=tools,
-        callback_manager=callback_manager,
+        callbacks=callbacks,
         memory=memory,
         tags=tags_,
         intermediate_steps=[],

View File

@@ -13,83 +13,50 @@ def dumps(obj: Dict) -> str:
     return json.dumps(obj, ensure_ascii=False)

-class Status:
-    start: int = 1
-    running: int = 2
-    complete: int = 3
-    agent_action: int = 4
-    agent_finish: int = 5
-    error: int = 6
-    tool_finish: int = 7
+class AgentStatus:
+    llm_start: int = 1
+    llm_new_token: int = 2
+    llm_end: int = 3
+    agent_action: int = 4
+    agent_finish: int = 5
+    error: int = 6
+    tool_finish: int = 7

-class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
+class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
     def __init__(self):
         super().__init__()
         self.queue = asyncio.Queue()
         self.done = asyncio.Event()
-        self.cur_tool = {}
         self.out = True

-    async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
-                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
-                            metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
-        self.cur_tool = {
-            "tool_name": serialized["name"],
-            "input_str": input_str,
-            "output_str": "",
-            "status": Status.agent_action,
-            "run_id": run_id.hex,
-            "llm_token": "",
-            "final_answer": "",
-            "error": "",
-        }
-        self.queue.put_nowait(dumps(self.cur_tool))
+    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
+        data = {
+            "status": AgentStatus.llm_start,
+            "text": "",
+        }
+        self.done.clear()
+        self.queue.put_nowait(dumps(data))

-    async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
-                          tags: List[str] | None = None, **kwargs: Any) -> None:
-        self.out = True
-        self.cur_tool.update(
-            status=Status.tool_finish,
-            output_str=output.replace("Answer:", ""),
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))

-    async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
-                            parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
-        self.cur_tool.update(
-            status=Status.error,
-            error=str(error),
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))

     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         special_tokens = ["Action", "<|observation|>"]
         for stoken in special_tokens:
             if stoken in token:
                 before_action = token.split(stoken)[0]
-                self.cur_tool.update(
-                    status=Status.running,
-                    llm_token=before_action + "\n",
-                )
-                self.queue.put_nowait(dumps(self.cur_tool))
+                data = {
+                    "status": AgentStatus.llm_new_token,
+                    "text": before_action + "\n",
+                }
+                self.queue.put_nowait(dumps(data))
                 self.out = False
                 break

         if token is not None and token != "" and self.out:
-            self.cur_tool.update(
-                status=Status.running,
-                llm_token=token,
-            )
-            self.queue.put_nowait(dumps(self.cur_tool))
-
-    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
-        self.cur_tool.update(
-            status=Status.start,
-            llm_token="",
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+            data = {
+                "status": AgentStatus.llm_new_token,
+                "text": token,
+            }
+            self.queue.put_nowait(dumps(data))

     async def on_chat_model_start(
         self,
@@ -102,26 +69,27 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
-        self.cur_tool.update(
-            status=Status.start,
-            llm_token="",
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+        data = {
+            "status": AgentStatus.llm_start,
+            "text": "",
+        }
+        self.done.clear()
+        self.queue.put_nowait(dumps(data))

     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        self.cur_tool.update(
-            status=Status.complete,
-            llm_token="",
-        )
+        data = {
+            "status": AgentStatus.llm_end,
+            "text": response.generations[0][0].message.content,
+        }
         self.out = True
-        self.queue.put_nowait(dumps(self.cur_tool))
+        self.queue.put_nowait(dumps(data))

     async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
-        self.cur_tool.update(
-            status=Status.error,
-            error=str(error),
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+        data = {
+            "status": AgentStatus.error,
+            "text": str(error),
+        }
+        self.queue.put_nowait(dumps(data))

     async def on_agent_action(
         self,
@@ -132,12 +100,13 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
-        self.cur_tool.update(
-            status=Status.agent_action,
-            tool_name=action.tool,
-            tool_input=action.tool_input,
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+        data = {
+            "status": AgentStatus.agent_action,
+            "tool_name": action.tool,
+            "tool_input": action.tool_input,
+            "text": action.log,
+        }
+        self.queue.put_nowait(dumps(data))

     async def on_agent_finish(
         self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
@@ -147,8 +116,9 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         if "Thought:" in finish.return_values["output"]:
             finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "")
-        self.cur_tool.update(
-            status=Status.agent_finish,
-            agent_finish=finish.return_values["output"],
-        )
-        self.queue.put_nowait(dumps(self.cur_tool))
+        data = {
+            "status": AgentStatus.agent_finish,
+            "text": finish.return_values["output"],
+        }
+        self.done.set()
+        self.queue.put_nowait(dumps(data))
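
The rewrite replaces the stateful cur_tool dict with flat event records: each callback serializes {"status": ..., "text": ...} (plus tool_name/tool_input for agent actions) onto an asyncio.Queue, and the done event gates the end of the aiter() loop inherited from AsyncIteratorCallbackHandler. A self-contained sketch of that queue protocol, with a stub producer standing in for the real agent run:

    import asyncio
    import json

    async def fake_agent(queue: asyncio.Queue, done: asyncio.Event):
        # stands in for on_llm_new_token / on_agent_finish firing during a run
        for token in ["Hello", " ", "world"]:
            queue.put_nowait(json.dumps({"status": 2, "text": token}))  # AgentStatus.llm_new_token
        queue.put_nowait(json.dumps({"status": 5, "text": "Hello world"}))  # AgentStatus.agent_finish
        done.set()

    async def main():
        queue, done = asyncio.Queue(), asyncio.Event()
        asyncio.create_task(fake_agent(queue, done))
        # drain until the producer signals completion and the queue is empty
        while not (done.is_set() and queue.empty()):
            event = json.loads(await queue.get())
            print(event["status"], repr(event["text"]))

    asyncio.run(main())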

View File

@@ -20,22 +20,23 @@ from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
 from server.chat.utils import History
 from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
 from server.db.repository import add_message_to_db
-from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
+from server.callback_handler.agent_callback_handler import AgentStatus, AgentExecutorAsyncIteratorCallbackHandler

-def create_models_from_config(configs, callbacks):
+def create_models_from_config(configs, callbacks, stream):
     if configs is None:
         configs = {}
     models = {}
     prompts = {}
     for model_type, model_configs in configs.items():
         for model_name, params in model_configs.items():
-            callback = callbacks if params.get('callbacks', False) else None
+            callbacks = callbacks if params.get('callbacks', False) else None
             model_instance = get_ChatOpenAI(
                 model_name=model_name,
                 temperature=params.get('temperature', 0.5),
                 max_tokens=params.get('max_tokens', 1000),
-                callbacks=callback
+                callbacks=callbacks,
+                streaming=stream,
             )
             models[model_type] = model_instance
             prompt_name = params.get('prompt_name', 'default')
@@ -83,7 +84,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
             tools=tools,
             prompt=prompts["action_model"],
             memory=memory,
-            # callback_manager=BaseCallbackManager(handlers=callbacks),
+            callbacks=callbacks,
             verbose=True,
         )
     elif "qwen" in models["action_model"].model_name.lower():
@@ -92,7 +93,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
             tools=tools,
             prompt=prompts["action_model"],
             memory=memory,
-            # callback_manager=BaseCallbackManager(handlers=callbacks),
+            callbacks=callbacks,
             verbose=True,
         )
     else:
@@ -131,7 +132,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                 ]
             ]
         ),
-        stream: bool = Body(False, description="流式输出"),
+        stream: bool = Body(True, description="流式输出"),
         model_config: Dict = Body({}, description="LLM 模型配置"),
         tool_config: Dict = Body({}, description="工具配置"),
 ):
@@ -142,11 +143,11 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         conversation_id=conversation_id
     ) if conversation_id else None

-    callback = CustomAsyncIteratorCallbackHandler()
+    callback = AgentExecutorAsyncIteratorCallbackHandler()
     callbacks = [callback]

     # pick the models from the config
-    models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
+    models, prompts = create_models_from_config(callbacks=[], configs=model_config, stream=stream)

     # pick the tools from the config
     tools = [tool for tool in all_tools if tool.name in tool_config]
@@ -164,56 +165,13 @@
         # Execute Chain
         task = asyncio.create_task(
-            wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
-        if stream:
-            async for chunk in callback.aiter():
-                data = json.loads(chunk)
-                if data["status"] == Status.start:
-                    continue
-                elif data["status"] == Status.agent_action:
-                    tool_info = {
-                        "tool_name": data["tool_name"],
-                        "tool_input": data["tool_input"]
-                    }
-                    yield json.dumps({"agent_action": tool_info, "message_id": message_id}, ensure_ascii=False)
-                elif data["status"] == Status.agent_finish:
-                    yield json.dumps({"agent_finish": data["agent_finish"], "message_id": message_id},
-                                     ensure_ascii=False)
-                else:
-                    yield json.dumps({"text": data["llm_token"], "message_id": message_id}, ensure_ascii=False)
-        else:
-            text = ""
-            agent_finish = ""
-            tool_info = None
-            async for chunk in callback.aiter():
-                data = json.loads(chunk)
-                if data["status"] == Status.agent_action:
-                    tool_info = {
-                        "tool_name": data["tool_name"],
-                        "tool_input": data["tool_input"]
-                    }
-                if data["status"] == Status.agent_finish:
-                    agent_finish = data["agent_finish"]
-                else:
-                    text += data["llm_token"]
-            if tool_info:
-                yield json.dumps(
-                    {
-                        "text": text,
-                        "agent_action": tool_info,
-                        "agent_finish": agent_finish,
-                        "message_id": message_id
-                    },
-                    ensure_ascii=False
-                )
-            else:
-                yield json.dumps(
-                    {
-                        "text": text,
-                        "message_id": message_id
-                    },
-                    ensure_ascii=False
-                )
+            wrap_done(full_chain.ainvoke({"input": query}), callback.done))
+        async for chunk in callback.aiter():
+            data = json.loads(chunk)
+            data["message_id"] = message_id
+            yield json.dumps(data, ensure_ascii=False)
         await task

     return EventSourceResponse(chat_iterator())
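
With the stream/non-stream branching removed, the endpoint now relays every handler event verbatim, with the message_id attached, as one server-sent event each. A hypothetical client sketch; the host, port, route, and request fields are assumptions for illustration, not part of the commit:

    import json

    import httpx  # assumed HTTP client; any SSE-capable client works

    # read the response line by line; each "data:" line carries one handler event
    with httpx.stream("POST", "http://127.0.0.1:7861/chat/chat",
                      json={"query": "你好", "stream": True}) as resp:
        for line in resp.iter_lines():
            if line.startswith("data: "):
                event = json.loads(line[len("data: "):])
                print(event.get("status"), event.get("text", ""))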

View File

@@ -2,19 +2,31 @@ import sys
 from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent))
+import asyncio
 from server.utils import get_ChatOpenAI
 from server.agent.tools_factory.tools_registry import all_tools
 from server.agent.agent_factory.qwen_agent import initialize_qwen_agent
+from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
 from langchain import globals

-globals.set_debug(True)
-globals.set_verbose(True)
+# globals.set_debug(True)
+# globals.set_verbose(True)

-qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=False)
-executor = initialize_qwen_agent(tools=all_tools, llm=qwen_model)
+async def main():
+    callback = AgentExecutorAsyncIteratorCallbackHandler()
+    qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
+    executor = initialize_qwen_agent(tools=all_tools,
+                                     llm=qwen_model,
+                                     callbacks=[callback],
+                                     )

-# ret = executor.invoke("苏州今天冷吗")
-ret = executor.invoke("从知识库samples中查询chatchat项目简介")
-# ret = executor.invoke("chatchat项目主要issue有哪些")
-print(ret)
+    # ret = executor.invoke("苏州今天冷吗")
+    ret = asyncio.create_task(executor.ainvoke("苏州今天冷吗"))
+    async for chunk in callback.aiter():
+        print(chunk)
+    # ret = executor.invoke("从知识库samples中查询chatchat项目简介")
+    # ret = executor.invoke("chatchat项目主要issue有哪些")
+    print(ret)

+asyncio.run(main())

View File

@@ -11,6 +11,7 @@ import os
 import re
 import time
 from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG)
+from server.callback_handler.agent_callback_handler import AgentStatus
 import uuid
 from typing import List, Dict
@@ -271,10 +272,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                 f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
                 unsafe_allow_html=True)

-        chat_box.ai_say("正在思考...")
+        chat_box.ai_say(["正在思考...",
+                         Markdown(title="tool call", in_expander=True, expanded=True, state="running"),
+                         Markdown()])
         text = ""
         message_id = ""
         element_index = 0
+        tool_called = False

         for d in api.chat_chat(query=prompt,
                                metadata=files_upload,
@@ -283,33 +286,33 @@
                                conversation_id=conversation_id,
                                tool_config=selected_tool_configs,
                                ):
             try:
                 d = json.loads(d)
             except:
                 pass
             message_id = d.get("message_id", "")
             metadata = {
                 "message_id": message_id,
             }
-            if error_msg := check_error_msg(d):
-                st.error(error_msg)
-            if chunk := d.get("agent_action"):
-                chat_box.insert_msg(Markdown("...", in_expander=True, title="Tools", state="complete"))
-                element_index = 1
+            if not tool_called:  # split the LLM output at the tool call so it is not rendered twice
+                element_index = 0
+            else:
+                element_index = -1
+            if d["status"] == AgentStatus.error:
+                st.error(d["text"])
+            elif d["status"] == AgentStatus.agent_action:
                 formatted_data = {
-                    "action": chunk["tool_name"],
-                    "action_input": chunk["tool_input"]
+                    "action": d["tool_name"],
+                    "action_input": d["tool_input"]
                 }
                 formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
-                text += f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"
-                chat_box.update_msg(text, element_index=element_index, metadata=metadata)
-            if chunk := d.get("text"):
-                text += chunk
-                chat_box.update_msg(text, element_index=element_index, metadata=metadata)
-            if chunk := d.get("agent_finish"):
-                element_index = 0
-                text = chunk
-                chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
+                chat_box.update_msg(Markdown(f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"), element_index=1)
+                tool_called = True
+                text = ""
+            elif d["status"] == AgentStatus.llm_new_token:
+                text += d["text"]
+                chat_box.update_msg(text, streaming=True, element_index=element_index, metadata=metadata)
+            elif d["status"] == AgentStatus.llm_end:
+                chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
+            elif d["status"] == AgentStatus.agent_finish:
+                chat_box.update_msg(element_index=1, state="complete", expanded=False)
+                chat_box.update_msg(Markdown(d["text"]), streaming=False, element_index=-1)

         if os.path.exists("tmp/image.jpg"):
             with open("tmp/image.jpg", "rb") as image_file: