liunux4odoo 42aa900566
优化工具定义;添加 openai 兼容的统一 chat 接口 (#3570)
- 修复:
    - Qwen Agent 的 OutputParser 不再抛出异常,遇到非 COT 文本直接返回
    - CallbackHandler 正确处理工具调用信息

- 重写 tool 定义方式:
    - 添加 regist_tool 简化 tool 定义:
        - 可以指定一个用户友好的名称
        - 自动将函数的 __doc__ 作为 tool.description
        - 支持用 Field 定义参数,不再需要额外定义 ModelSchema
        - 添加 BaseToolOutput 封装 tool 返回结果,以便同时获取原始值、给LLM的字符串值
        - 支持工具热加载(有待测试)

- 增加 openai 兼容的统一 chat 接口,通过 tools/tool_choice/extra_body 不同参数组合支持:
    - Agent 对话
    - 指定工具调用(如知识库RAG)
    - LLM 对话

- 根据后端功能更新 webui
2024-03-29 11:55:32 +08:00

144 lines
4.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import re
from typing import Any, Union, Dict, Tuple, Callable, Optional, Type
from langchain.agents import tool
from langchain_core.tools import BaseTool
from chatchat.server.pydantic_v1 import BaseModel, Extra
# Names exported by `from <this module> import *`.
__all__ = ["regist_tool", "BaseToolOutput"]
# Registry of tools keyed by tool name; regist_tool adds every tool it builds.
_TOOLS_REGISTRY: Dict[str, BaseTool] = {}
# patch BaseTool to support extra fields e.g. a title
# (regist_tool assigns `t.title`, which is not a declared BaseTool field)
BaseTool.Config.extra = Extra.allow
################################### TODO: workaround to langchain #15855
# patch BaseTool to support tool parameters defined using pydantic Field
def _new_parse_input(
    self,
    tool_input: Union[str, Dict],
) -> Union[str, Dict[str, Any]]:
    """Validate tool input against ``self.args_schema`` and normalise it.

    Intended as a replacement for ``BaseTool._parse_input`` so that tool
    parameters declared through pydantic ``Field`` are honoured.

    Args:
        tool_input: Raw tool input, either a single string or a dict of
            named arguments.

    Returns:
        * string input: the same string (after validating it as the first
          schema field when a schema is present);
        * dict input with a schema: ``dict()`` of the parsed schema model;
        * dict input without a schema: the dict unchanged.

    Raises:
        Whatever the schema's ``validate`` / ``parse_obj`` raise for input
        that does not satisfy the schema.
    """
    schema = self.args_schema

    if isinstance(tool_input, str):
        if schema is not None:
            # A bare string is treated as the value of the first declared field.
            first_field = next(iter(schema.__fields__.keys()))
            schema.validate({first_field: tool_input})
        return tool_input

    if schema is None:
        # Nothing to validate against. Return the input as-is instead of
        # falling off the end of the function and handing ``None`` to
        # ``_to_args_and_kwargs``.
        return tool_input

    return schema.parse_obj(tool_input).dict()
def _new_to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
    """Split parsed tool input into ``(positional_args, keyword_args)``.

    A string input becomes the single positional argument (kept for
    backwards compatibility). For any other input the mapping itself is
    returned as the keyword arguments.

    Tools defined with a ``*args`` parameter get an ``args`` field in their
    schema; when that entry is ``None`` or a tuple it is removed from the
    mapping (in place) and, for a tuple, expanded into the positional
    arguments. An ``args`` entry of any other type is left in the mapping.
    """
    if isinstance(tool_input, str):
        return (tool_input,), {}

    positional: Tuple = ()
    if "args" in tool_input:
        star_args = tool_input["args"]
        if star_args is None or isinstance(star_args, tuple):
            # Drop the synthetic field so it is not passed on as a keyword.
            tool_input.pop("args")
            if star_args is not None:
                positional = star_args
    return positional, tool_input
# Monkey-patch: replace BaseTool's input parsing and argument splitting
# with the two functions defined above.
BaseTool._parse_input = _new_parse_input
BaseTool._to_args_and_kwargs = _new_to_args_and_kwargs
###############################
def regist_tool(
    *args: Any,
    title: str = "",
    description: str = "",
    return_direct: bool = False,
    args_schema: Optional[Type[BaseModel]] = None,
    infer_schema: bool = True,
) -> Union[Callable, BaseTool]:
    """Wrap langchain's ``tool`` decorator and register the resulting tool.

    Every tool built through this function is stored in ``_TOOLS_REGISTRY``
    under its name, gets a single-line description and a human-friendly
    ``title``.

    Args:
        *args: Positional arguments forwarded unchanged to ``tool`` (for
            example the decorated function, or a tool name).
        title: Human-friendly name stored on the tool as ``title``. When
            empty it is derived from the tool name
            (``search_api`` -> ``SearchApi``).
        description: Tool description. When empty, the ``__doc__`` of the
            tool's function (or coroutine) is used.
        return_direct: Forwarded to ``tool``.
        args_schema: Forwarded to ``tool``.
        infer_schema: Forwarded to ``tool``.

    Returns:
        The registered ``BaseTool`` when ``tool`` builds one right away,
        otherwise a decorator that builds and registers the tool once it is
        applied to a function.
    """

    def _parse_tool(t: BaseTool) -> BaseTool:
        """Register ``t`` and fill in its description and title."""
        _TOOLS_REGISTRY[t.name] = t

        # Use locals instead of rebinding the enclosing ``description`` /
        # ``title``: rebinding would leak the first tool's values into every
        # later tool built by the same decorator object.
        text = description
        if not text:
            if t.func is not None:
                text = t.func.__doc__
            elif t.coroutine is not None:
                text = t.coroutine.__doc__
        # ``__doc__`` is None for undocumented callables; treat it as empty.
        # Line breaks plus the indentation after them collapse to one space.
        t.description = " ".join(re.split(r"\n+\s*", text or ""))

        # Default title for humans: CamelCase version of the tool name.
        t.title = title or "".join(x.capitalize() for x in t.name.split("_"))
        return t

    built = tool(
        *args,
        return_direct=return_direct,
        args_schema=args_schema,
        infer_schema=infer_schema,
    )
    if isinstance(built, BaseTool):
        return _parse_tool(built)

    # ``tool`` handed back a decorator rather than a finished tool (no
    # positional arguments, or only a tool name): defer registration until
    # the decorator is applied to the function.
    def wrapper(def_func: Callable) -> BaseTool:
        return _parse_tool(built(def_func))

    return wrapper
class BaseToolOutput:
    """Wrap a tool's return value so it serves both the LLM and other callers.

    An LLM requires a tool's output to be a ``str``, while code that uses
    the tool elsewhere wants the structured data back. Wrapping the return
    value in this class satisfies both: ``data`` keeps the original value
    and ``str()`` produces the text form.

    The base class simply stringifies the value, or dumps it as JSON when
    ``format="json"``. Subclass it to define a custom conversion.
    """

    def __init__(
        self,
        data: Any,
        format: str = "",
        data_alias: str = "",
        **extras: Any,
    ) -> None:
        """Store the wrapped value and its presentation options.

        Args:
            data: The tool's original return value.
            format: ``"json"`` to render ``str()`` as JSON; any other value
                falls back to ``str(data)``.
            data_alias: Optional extra attribute name under which ``data``
                can also be read.
            **extras: Arbitrary additional values, kept in ``extras``.
        """
        self.data = data
        self.format = format
        self.extras = extras
        self._data_alias = data_alias

    def __getattr__(self, name: str) -> Any:
        """Resolve the ``data_alias`` attribute to the current ``data``.

        Only called when normal attribute lookup fails. A ``property``
        assigned to an instance is never invoked (descriptors work on the
        class only), so the alias is resolved here instead.
        """
        # Read through __dict__ so a half-initialised instance (e.g. during
        # copy/unpickle) cannot recurse back into __getattr__.
        alias = self.__dict__.get("_data_alias")
        if alias and name == alias and name != "data":
            return self.data
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )

    def __str__(self) -> str:
        """Return the text form of ``data`` that is given to the LLM."""
        if self.format == "json":
            return json.dumps(self.data, ensure_ascii=False, indent=2)
        return str(self.data)