agent代码优化

This commit is contained in:
glide-the 2024-05-09 17:32:02 +08:00
parent 8be50af990
commit 402153de09
2 changed files with 10 additions and 4 deletions

View File

@ -25,7 +25,10 @@ def agents_registry(
if "glm3" in llm.model_name.lower():
# An optimized method of langchain Agent that uses the glm3 series model
agent = create_structured_glm3_chat_agent(llm=llm, tools=tools)
# pass
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, callbacks=callbacks)
return agent_executor
elif "qwen" in llm.model_name.lower():
return create_structured_qwen_chat_agent(llm=llm, tools=tools, callbacks=callbacks)
else:
@ -35,6 +38,6 @@ def agents_registry(
prompt = hub.pull("hwchase17/structured-chat-agent") # default prompt
agent = create_structured_chat_agent(llm=llm, tools=tools, prompt=prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, callbacks=callbacks)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, callbacks=callbacks)
return agent_executor
return agent_executor

View File

@ -246,7 +246,10 @@ def get_OpenAIClient(
construct an openai Client for specified platform or model
'''
if platform_name is None:
platform_name = get_model_info(model_name=model_name, platform_name=platform_name)["platform_name"]
platform_info = get_model_info(model_name=model_name, platform_name=platform_name)
if platform_info is None:
raise RuntimeError(f"cannot find configured platform for model: {model_name}")
platform_name = platform_info.get("platform_name")
platform_info = get_config_platforms().get(platform_name)
assert platform_info, f"cannot find configured platform: {platform_name}"
params = {