修复 callback 无效的问题

This commit is contained in:
liunux4odoo 2024-01-10 17:17:47 +08:00
parent 7257521e10
commit 6f155aec1f
4 changed files with 62 additions and 22 deletions

View File

@ -18,12 +18,8 @@ def agents_registry(
callbacks: List[BaseCallbackHandler] = [],
prompt: str = None,
verbose: bool = False):
if prompt is not None:
prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)])
else:
prompt = hub.pull("hwchase17/structured-chat-agent") # default prompt
# llm.callbacks = callbacks
llm.streaming = False # qwen agent not support streaming
# Write any optimized method here.
if "glm3" in llm.model_name.lower():
@ -32,6 +28,10 @@ def agents_registry(
elif "qwen" in llm.model_name.lower():
agent = create_structured_qwen_chat_agent(llm=llm, tools=tools)
else:
if prompt is not None:
prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)])
else:
prompt = hub.pull("hwchase17/structured-chat-agent") # default prompt
agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, callbacks=callbacks)

View File

@ -168,8 +168,9 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
"status": AgentStatus.agent_finish,
"text": finish.return_values["output"],
}
self.done.set()
self.queue.put_nowait(dumps(data))
async def on_chain_end(self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
self.done.set()
self.out = True

View File

@ -21,11 +21,6 @@ from server.db.repository import add_message_to_db
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
# from langchain.globals import set_debug, set_verbose
# set_debug(True)
# set_verbose(True)
def create_models_from_config(configs, callbacks, stream):
if configs is None:
configs = {}
@ -128,7 +123,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
callback = AgentExecutorAsyncIteratorCallbackHandler()
callbacks = [callback]
models, prompts = create_models_from_config(callbacks=[], configs=model_config, stream=stream)
models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config, stream=stream)
tools = [tool for tool in all_tools if tool.name in tool_config]
tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
full_chain = create_models_chains(prompts=prompts,
@ -143,13 +138,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
full_chain.ainvoke(
{
"input": query,
"chat_history": [
HumanMessage(content="今天北京的温度是多少度"),
AIMessage(content="今天北京的温度是1度"),
],
"chat_history": [],
}
),callback.done))
), callback.done))
async for chunk in callback.aiter():
data = json.loads(chunk)

View File

@ -13,10 +13,10 @@ from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIte
from langchain import globals
globals.set_debug(True)
# globals.set_verbose(True)
globals.set_verbose(True)
async def main():
async def test1():
callback = AgentExecutorAsyncIteratorCallbackHandler()
tools = [t.copy(update={"callbacks": [callback]}) for t in all_tools]
# qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
@ -38,4 +38,52 @@ async def main():
# ret = executor.invoke("chatchat项目主要issue有哪些")
await ret
asyncio.run(main())
async def test2():
from server.chat.chat import chat
mc={'preprocess_model': {
'qwen': {
'temperature': 0.4,
'max_tokens': 2048,
'history_len': 100,
'prompt_name': 'default',
'callbacks': False}
},
'llm_model': {
'qwen': {
'temperature': 0.9,
'max_tokens': 4096,
'history_len': 3,
'prompt_name': 'default',
'callbacks': True}
},
'action_model': {
'qwen': {
'temperature': 0.01,
'max_tokens': 4096,
'prompt_name': 'qwen',
'callbacks': True}
},
'postprocess_model': {
'qwen': {
'temperature': 0.01,
'max_tokens': 4096,
'prompt_name': 'default',
'callbacks': True}
}
}
tc={'weather_check': {'use': False, 'api-key': 'your key'}}
async for x in (await chat("苏州天气如何",{},
model_config=mc,
tool_config=tc,
conversation_id=None,
history_len=-1,
history=[],
stream=True)).body_iterator:
print(x)
asyncio.run(test2())