mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-27 01:03:35 +08:00
- 修复:
- Qwen Agent 的 OutputParser 不再抛出异常,遇到非 COT 文本直接返回
- CallbackHandler 正确处理工具调用信息
- 重写 tool 定义方式:
- 添加 regist_tool 简化 tool 定义:
- 可以指定一个用户友好的名称
- 自动将函数的 __doc__ 作为 tool.description
- 支持用 Field 定义参数,不再需要额外定义 ModelSchema
- 添加 BaseToolOutput 封装 tool 返回结果,以便同时获取原始值、给LLM的字符串值
- 支持工具热加载(有待测试)
- 增加 openai 兼容的统一 chat 接口,通过 tools/tool_choice/extra_body 不同参数组合支持:
- Agent 对话
- 指定工具调用(如知识库RAG)
- LLM 对话
- 根据后端功能更新 webui
144 lines
4.3 KiB
Python
144 lines
4.3 KiB
Python
import json
|
||
import re
|
||
from typing import Any, Union, Dict, Tuple, Callable, Optional, Type
|
||
|
||
from langchain.agents import tool
|
||
from langchain_core.tools import BaseTool
|
||
|
||
from chatchat.server.pydantic_v1 import BaseModel, Extra
|
||
|
||
|
||
# Names exported by a star-import of this module.
__all__ = ["regist_tool", "BaseToolOutput"]


# Filled by `regist_tool`: maps each registered tool's name to the tool object.
_TOOLS_REGISTRY = {}


# Let BaseTool instances carry extra attributes (e.g. a human-friendly title).
BaseTool.Config.extra = Extra.allow

################################### TODO: workaround for langchain #15855
# The two functions below are installed on BaseTool so that tool parameters
# defined with pydantic Field are supported.
|
||
|
||
def _new_parse_input(
    self,
    tool_input: Union[str, Dict],
) -> Union[str, Dict[str, Any]]:
    """Validate ``tool_input`` against ``self.args_schema`` and normalise it.

    Installed on ``BaseTool`` as a replacement for ``_parse_input``.

    Args:
        self: Object exposing ``args_schema`` (a pydantic model class or
            ``None``).
        tool_input: Raw tool input, either a single string or a dict.

    Returns:
        * String input: the same string. When a schema is set, the string is
          first validated as the value of the schema's first declared field.
        * Dict input with a schema: ``args_schema.parse_obj(tool_input).dict()``,
          i.e. the whole parsed model rather than only the keys that were
          passed in.
        * Any input without a schema: ``tool_input`` unchanged.

    Raises:
        Whatever ``args_schema.validate`` / ``args_schema.parse_obj`` raise
        for invalid input.
    """
    schema = self.args_schema
    if schema is None:
        # Nothing to validate against. Returning the input here also covers
        # the dict case, which previously fell through and returned None.
        return tool_input

    if isinstance(tool_input, str):
        # A bare string is treated as the value of the first schema field.
        first_field = next(iter(schema.__fields__))
        schema.validate({first_field: tool_input})
        return tool_input

    return schema.parse_obj(tool_input).dict()
|
||
|
||
|
||
def _new_to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
    """Split a tool input into positional and keyword arguments.

    Installed on ``BaseTool`` as a replacement for ``_to_args_and_kwargs``.

    Args:
        self: Unused; present because this is installed as a method.
        tool_input: Either a single string or a dict of arguments.

    Returns:
        An ``(args, kwargs)`` pair:

        * string input -> ``((tool_input,), {})``
        * dict whose ``"args"`` entry is ``None`` or a tuple -> that tuple
          (``()`` for ``None``) plus a new dict holding the other entries
        * any other dict -> ``((), tool_input)``

        ``tool_input`` itself is never modified.
    """
    # For backwards compatibility, a plain string is passed as the single
    # positional argument.
    if isinstance(tool_input, str):
        return (tool_input,), {}

    # For a tool defined with `*args` parameters the args_schema has a field
    # named `args`; it should be expanded to the actual *args
    # (e.g. test_tools.test_named_tool_decorator_return_direct.search_api).
    if "args" in tool_input:
        var_args = tool_input["args"]
        if var_args is None or isinstance(var_args, tuple):
            # Build a fresh dict instead of popping from the caller's dict.
            kwargs = {key: value for key, value in tool_input.items() if key != "args"}
            return (() if var_args is None else var_args), kwargs

    return (), tool_input
|
||
|
||
|
||
# Install the two replacement methods defined above on BaseTool.
BaseTool._to_args_and_kwargs = _new_to_args_and_kwargs
BaseTool._parse_input = _new_parse_input
###############################
|
||
|
||
|
||
def regist_tool(
    *args: Any,
    title: str = "",
    description: str = "",
    return_direct: bool = False,
    args_schema: Optional[Type[BaseModel]] = None,
    infer_schema: bool = True,
) -> Union[Callable, BaseTool]:
    """Wrap langchain's ``tool`` decorator and add the tool to the registry.

    ``args``, ``return_direct``, ``args_schema`` and ``infer_schema`` are
    forwarded to ``tool`` unchanged.

    Args:
        *args: Positional arguments for ``tool``.
        title: Human-friendly name stored as ``tool.title``. Defaults to the
            tool name with each ``_``-separated part capitalised and joined
            (``search_api`` -> ``SearchApi``).
        description: Text for ``tool.description``. When empty, the
            ``__doc__`` of the tool's ``func`` (or, failing that, of its
            ``coroutine``) is used. Runs of newlines plus following
            indentation are collapsed to single spaces. If no text is
            available, the description set by ``tool`` is left as is.
        return_direct: Forwarded to ``tool``.
        args_schema: Forwarded to ``tool``.
        infer_schema: Forwarded to ``tool``.

    Returns:
        The registered ``BaseTool`` when ``tool(*args, ...)`` produces a tool
        directly. Otherwise (e.g. no positional arguments, or only a name
        string) a decorator that builds, registers and returns the tool for
        the function it is applied to.
    """

    def _register(t: BaseTool) -> BaseTool:
        """Store ``t`` in the registry and set its description and title."""
        _TOOLS_REGISTRY[t.name] = t

        # Locals instead of `nonlocal`: a decorator returned by a single
        # regist_tool(...) call may be applied to several functions, and the
        # first function's docstring/title must not leak into later ones.
        text = description
        if not text:
            if t.func is not None:
                text = t.func.__doc__
            elif t.coroutine is not None:
                text = t.coroutine.__doc__
        if text:
            # `text` may be None (no docstring); in that case keep whatever
            # description the tool already has instead of failing in re.split.
            t.description = " ".join(re.split(r"\n+\s*", text))

        # Default title for humans, derived from the snake_case tool name.
        t.title = title or "".join(part.capitalize() for part in t.name.split("_"))
        return t

    made = tool(
        *args,
        return_direct=return_direct,
        args_schema=args_schema,
        infer_schema=infer_schema,
    )

    # Decide by what `tool` actually returned rather than by len(args):
    # with only a name string it returns a decorator, not a tool.
    if isinstance(made, BaseTool):
        return _register(made)

    def wrapper(def_func: Callable) -> BaseTool:
        return _register(made(def_func))

    return wrapper
|
||
|
||
|
||
class BaseToolOutput:
    """Wrap a tool's return value for both the LLM and ordinary callers.

    The LLM requires a tool's output to be a ``str``, while code that uses
    the tool elsewhere wants the structured data back. Wrapping the return
    value in this class serves both: ``.data`` holds the original value and
    ``str(obj)`` produces the text.

    The base class simply stringifies the data, or dumps it as JSON when
    ``format="json"``. Subclass it to define a custom conversion.

    Args:
        data: The tool's original return value.
        format: ``"json"`` to render ``data`` with ``json.dumps``; any other
            value renders it with ``str``.
        data_alias: Optional extra attribute name under which ``data`` can
            also be read (``obj.<data_alias>``). It always reflects the
            current ``obj.data``.
        **extras: Additional values, kept as the dict ``obj.extras``.
    """

    def __init__(
        self,
        data: Any,
        format: str = "",
        data_alias: str = "",
        **extras: Any,
    ) -> None:
        self.data = data
        self.format = format
        self.extras = extras
        # A property assigned to an instance is never invoked as a
        # descriptor, so the alias is resolved in __getattr__ instead.
        self._data_alias = data_alias

    def __getattr__(self, name: str) -> Any:
        """Resolve the ``data_alias`` name; called only when normal lookup fails."""
        # Go through __dict__ so a half-initialised instance cannot recurse.
        state = self.__dict__
        alias = state.get("_data_alias")
        if alias and name == alias and "data" in state:
            return state["data"]
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )

    def __str__(self) -> str:
        """Return the text form of ``data`` according to ``format``."""
        if self.format == "json":
            return json.dumps(self.data, ensure_ascii=False, indent=2)
        return str(self.data)
|