From 402153de09c2c749b754ec7cfee9c4813ec997f4 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Thu, 9 May 2024 17:32:02 +0800 Subject: [PATCH] =?UTF-8?q?agent=E4=BB=A3=E7=A0=81=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../server/agent/agent_factory/agents_registry.py | 9 ++++++--- chatchat-server/chatchat/server/utils.py | 5 ++++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/chatchat-server/chatchat/server/agent/agent_factory/agents_registry.py b/chatchat-server/chatchat/server/agent/agent_factory/agents_registry.py index 7c9a5c0a..7de96aaa 100644 --- a/chatchat-server/chatchat/server/agent/agent_factory/agents_registry.py +++ b/chatchat-server/chatchat/server/agent/agent_factory/agents_registry.py @@ -25,7 +25,10 @@ def agents_registry( if "glm3" in llm.model_name.lower(): # An optimized method of langchain Agent that uses the glm3 series model agent = create_structured_glm3_chat_agent(llm=llm, tools=tools) - # pass + + agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, callbacks=callbacks) + + return agent_executor elif "qwen" in llm.model_name.lower(): return create_structured_qwen_chat_agent(llm=llm, tools=tools, callbacks=callbacks) else: @@ -35,6 +38,6 @@ def agents_registry( prompt = hub.pull("hwchase17/structured-chat-agent") # default prompt agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt) - agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, callbacks=callbacks) + agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, callbacks=callbacks) - return agent_executor + return agent_executor diff --git a/chatchat-server/chatchat/server/utils.py b/chatchat-server/chatchat/server/utils.py index 8efcd0f8..ac441157 100644 --- a/chatchat-server/chatchat/server/utils.py +++ b/chatchat-server/chatchat/server/utils.py @@ -246,7 +246,10 @@ def get_OpenAIClient( construct an openai Client for specified platform or model ''' if platform_name is None: - platform_name = get_model_info(model_name=model_name, platform_name=platform_name)["platform_name"] + platform_info = get_model_info(model_name=model_name, platform_name=platform_name) + if platform_info is None: + raise RuntimeError(f"cannot find configured platform for model: {model_name}") + platform_name = platform_info.get("platform_name") platform_info = get_config_platforms().get(platform_name) assert platform_info, f"cannot find configured platform: {platform_name}" params = {