From 6f155aec1fa9972af2ae96623eb8d2b063e36a58 Mon Sep 17 00:00:00 2001
From: liunux4odoo
Date: Wed, 10 Jan 2024 17:17:47 +0800
Subject: [PATCH] Fix callbacks not taking effect
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 server/agent/agent_factory/agents_registry.py | 10 ++--
 .../agent_callback_handler.py                 |  5 +-
 server/chat/chat.py                           | 15 ++----
 tests/test_qwen_agent.py                      | 54 +++++++++++++++++--
 4 files changed, 62 insertions(+), 22 deletions(-)

diff --git a/server/agent/agent_factory/agents_registry.py b/server/agent/agent_factory/agents_registry.py
index a1d4c9c8..e2aabee6 100644
--- a/server/agent/agent_factory/agents_registry.py
+++ b/server/agent/agent_factory/agents_registry.py
@@ -18,12 +18,8 @@ def agents_registry(
         callbacks: List[BaseCallbackHandler] = [],
         prompt: str = None,
         verbose: bool = False):
-    if prompt is not None:
-        prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)])
-    else:
-        prompt = hub.pull("hwchase17/structured-chat-agent")  # default prompt
-
     # llm.callbacks = callbacks
+    llm.streaming = False  # the qwen agent does not support streaming

     # Write any optimized method here.
     if "glm3" in llm.model_name.lower():
@@ -32,6 +28,10 @@ def agents_registry(
     elif "qwen" in llm.model_name.lower():
         agent = create_structured_qwen_chat_agent(llm=llm, tools=tools)
     else:
+        if prompt is not None:
+            prompt = ChatPromptTemplate.from_messages([SystemMessage(content=prompt)])
+        else:
+            prompt = hub.pull("hwchase17/structured-chat-agent")  # default prompt
         agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
     agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose,
                                    callbacks=callbacks)
diff --git a/server/callback_handler/agent_callback_handler.py b/server/callback_handler/agent_callback_handler.py
index 6e845681..2f50020e 100644
--- a/server/callback_handler/agent_callback_handler.py
+++ b/server/callback_handler/agent_callback_handler.py
@@ -168,8 +168,9 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
             "status": AgentStatus.agent_finish,
             "text": finish.return_values["output"],
         }
-
-        self.done.set()
         self.queue.put_nowait(dumps(data))
+
+    async def on_chain_end(self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
+        self.done.set()
         self.out = True

diff --git a/server/chat/chat.py b/server/chat/chat.py
index 72834511..bedc988b 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -21,11 +21,6 @@ from server.db.repository import add_message_to_db
 from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler

-# from langchain.globals import set_debug, set_verbose
-# set_debug(True)
-# set_verbose(True)
-
-
 def create_models_from_config(configs, callbacks, stream):
     if configs is None:
         configs = {}

@@ -128,7 +123,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
     callback = AgentExecutorAsyncIteratorCallbackHandler()
     callbacks = [callback]

-    models, prompts = create_models_from_config(callbacks=[], configs=model_config, stream=stream)
+    models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config, stream=stream)
     tools = [tool for tool in all_tools if tool.name in tool_config]
     tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
     full_chain = create_models_chains(prompts=prompts,
@@ -143,13 +138,9 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
             full_chain.ainvoke(
                 {
                     "input": query,
-                    "chat_history": [
-                        HumanMessage(content="今天北京的温度是多少度"),
-                        AIMessage(content="今天北京的温度是1度"),
-                    ],
-
+                    "chat_history": [],
                 }
-            ),callback.done))
+            ), callback.done))

     async for chunk in callback.aiter():
         data = json.loads(chunk)
diff --git a/tests/test_qwen_agent.py b/tests/test_qwen_agent.py
index abbf5de4..2763e103 100644
--- a/tests/test_qwen_agent.py
+++ b/tests/test_qwen_agent.py
@@ -13,10 +13,10 @@ from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIte
 from langchain import globals

 globals.set_debug(True)
-# globals.set_verbose(True)
+globals.set_verbose(True)


-async def main():
+async def test1():
     callback = AgentExecutorAsyncIteratorCallbackHandler()
     tools = [t.copy(update={"callbacks": [callback]}) for t in all_tools]
     # qwen_model = get_ChatOpenAI("Qwen-1_8B-Chat", 0.01, streaming=True, callbacks=[callback])
@@ -38,4 +38,52 @@
     # ret = executor.invoke("chatchat项目主要issue有哪些")
     await ret

-asyncio.run(main())
+
+async def test2():
+    from server.chat.chat import chat
+
+    mc = {'preprocess_model': {
+              'qwen': {
+                  'temperature': 0.4,
+                  'max_tokens': 2048,
+                  'history_len': 100,
+                  'prompt_name': 'default',
+                  'callbacks': False}
+              },
+          'llm_model': {
+              'qwen': {
+                  'temperature': 0.9,
+                  'max_tokens': 4096,
+                  'history_len': 3,
+                  'prompt_name': 'default',
+                  'callbacks': True}
+              },
+          'action_model': {
+              'qwen': {
+                  'temperature': 0.01,
+                  'max_tokens': 4096,
+                  'prompt_name': 'qwen',
+                  'callbacks': True}
+              },
+          'postprocess_model': {
+              'qwen': {
+                  'temperature': 0.01,
+                  'max_tokens': 4096,
+                  'prompt_name': 'default',
+                  'callbacks': True}
+              }
+          }
+
+    tc = {'weather_check': {'use': False, 'api-key': 'your key'}}
+
+    async for x in (await chat("苏州天气如何", {},
+                               model_config=mc,
+                               tool_config=tc,
+                               conversation_id=None,
+                               history_len=-1,
+                               history=[],
+                               stream=True)).body_iterator:
+        print(x)
+
+
+asyncio.run(test2())
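
Note for reviewers: the heart of this fix is twofold. First, `self.done.set()` moves out of `on_agent_finish` into a new `on_chain_end` override, so the stream-termination flag only fires after the whole chain is done and the final agent message has been queued. Second, `create_models_from_config` now receives the real `callbacks` list instead of `[]`, so the handler actually gets events. The sketch below is a minimal, self-contained illustration of the ordering issue; `DemoHandler` and `fake_chain` are hypothetical stand-ins, not the project's classes, and it assumes (matching this patch's intent) that the consumer stops iterating once `done` is set and the queue is drained.

    import asyncio
    import json

    class DemoHandler:
        """Hypothetical stand-in for an async-iterator callback handler:
        a queue of events plus a `done` flag that ends iteration."""

        def __init__(self):
            self.queue: asyncio.Queue = asyncio.Queue()
            self.done = asyncio.Event()

        def on_agent_finish(self, output: str) -> None:
            # Only enqueue the final payload here; do NOT set `done` yet,
            # or the consumer may stop before draining this item.
            self.queue.put_nowait(json.dumps({"status": "agent_finish",
                                              "text": output}))

        def on_chain_end(self) -> None:
            # The chain has fully finished; now it is safe to end iteration.
            self.done.set()

        async def aiter(self):
            # Yield queued events until `done` is set AND the queue is empty,
            # so nothing enqueued before `done` can be lost.
            while not (self.done.is_set() and self.queue.empty()):
                try:
                    yield await asyncio.wait_for(self.queue.get(), timeout=0.1)
                except asyncio.TimeoutError:
                    continue

    async def fake_chain(handler: DemoHandler) -> None:
        # Stand-in for the agent executor: finish the agent, then the chain.
        handler.on_agent_finish("final answer")
        handler.on_chain_end()

    async def main() -> None:
        handler = DemoHandler()
        task = asyncio.create_task(fake_chain(handler))
        async for chunk in handler.aiter():
            print(json.loads(chunk))  # the final payload is not dropped
        await task

    asyncio.run(main())

With the pre-patch ordering (`done.set()` before the final `put_nowait`, inside `on_agent_finish`), a consumer that races `done` against the queue can exit without ever seeing the last message, which is why callbacks appeared to have no effect.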