- 重写 tool 部分: (#3553)

- 简化 tool 的定义方式
    - 所有 tool 和 tool_config 支持热加载
    - 修复:json_schema_extra warning
This commit is contained in:
liunux4odoo 2024-03-28 13:08:51 +08:00 committed by GitHub
parent f9f9d4b9fb
commit 9818bd2a88
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 348 additions and 280 deletions

View File

@ -12,12 +12,12 @@ langchain.verbose = False
# 通常情况下不需要更改以下内容
# 用户数据根目录
DATA_PATH = (Path(__file__).absolute().parent.parent) # / "data")
DATA_PATH = str(Path(__file__).absolute().parent.parent / "data")
if not os.path.exists(DATA_PATH):
os.mkdir(DATA_PATH)
# nltk 模型存储路径
NLTK_DATA_PATH = os.path.join(DATA_PATH, "data/nltk_data")
NLTK_DATA_PATH = os.path.join(DATA_PATH, "nltk_data")
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@ -29,12 +29,12 @@ logging.basicConfig(format=LOG_FORMAT)
# 日志存储路径
LOG_PATH = os.path.join(DATA_PATH, "data/logs")
LOG_PATH = os.path.join(DATA_PATH, "logs")
if not os.path.exists(LOG_PATH):
os.mkdir(LOG_PATH)
# 模型生成内容(图片、视频、音频等)保存位置
MEDIA_PATH = os.path.join(DATA_PATH, "data/media")
MEDIA_PATH = os.path.join(DATA_PATH, "media")
if not os.path.exists(MEDIA_PATH):
os.mkdir(MEDIA_PATH)
os.mkdir(os.path.join(MEDIA_PATH, "image"))
@ -42,6 +42,6 @@ if not os.path.exists(MEDIA_PATH):
os.mkdir(os.path.join(MEDIA_PATH, "video"))
# 临时文件目录,主要用于文件对话
BASE_TEMP_DIR = os.path.join(DATA_PATH, "data/temp")
BASE_TEMP_DIR = os.path.join(DATA_PATH, "temp")
if not os.path.exists(BASE_TEMP_DIR):
os.mkdir(BASE_TEMP_DIR)

View File

@ -1,6 +1,6 @@
import os
from chatchat.configs.basic_config import DATA_PATH
from .basic_config import DATA_PATH
# 默认使用的知识库
@ -51,7 +51,7 @@ KB_INFO = {
# 通常情况下不需要更改以下内容
# 知识库默认存储路径
KB_ROOT_PATH = os.path.join(DATA_PATH, "data/knowledge_base")
KB_ROOT_PATH = os.path.join(DATA_PATH, "knowledge_base")
if not os.path.exists(KB_ROOT_PATH):
os.mkdir(KB_ROOT_PATH)

View File

@ -151,6 +151,7 @@ MODEL_PLATFORMS = [
"image_models": [],
"multimodal_models": [],
},
{
"platform_name": "ollama",
"platform_type": "ollama",
@ -168,6 +169,7 @@ MODEL_PLATFORMS = [
"image_models": [],
"multimodal_models": [],
},
# {
# "platform_name": "loom",
# "platform_type": "loom",
@ -184,3 +186,89 @@ MODEL_PLATFORMS = [
]
LOOM_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "loom.yaml")
# 工具配置项
TOOL_CONFIG = {
"search_local_knowledgebase": {
"use": False,
"top_k": 3,
"score_threshold": 1,
"conclude_prompt": {
"with_result":
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题"'
'不允许在答案中添加编造成分,答案请使用中文。 </指令>\n'
'<已知信息>{{ context }}</已知信息>\n'
'<问题>{{ question }}</问题>\n',
"without_result":
'请你根据我的提问回答我的问题:\n'
'{{ question }}\n'
'请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n',
}
},
"search_internet": {
"use": False,
"search_engine_name": "bing",
"search_engine_config":
{
"bing": {
"result_len": 3,
"bing_search_url": "https://api.bing.microsoft.com/v7.0/search",
"bing_key": "",
},
"metaphor": {
"result_len": 3,
"metaphor_api_key": "",
"split_result": False,
"chunk_size": 500,
"chunk_overlap": 0,
},
"duckduckgo": {
"result_len": 3
}
},
"top_k": 10,
"verbose": "Origin",
"conclude_prompt":
"<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 "
"</指令>\n<已知信息>{{ context }}</已知信息>\n"
"<问题>\n"
"{{ question }}\n"
"</问题>\n"
},
"arxiv": {
"use": False,
},
"shell": {
"use": False,
},
"weather_check": {
"use": False,
"api_key": "S8vrB4U_-c5mvAMiK",
},
"search_youtube": {
"use": False,
},
"wolfram": {
"use": False,
"appid": "",
},
"calculate": {
"use": False,
},
"vqa_processor": {
"use": False,
"model_path": "your model path",
"tokenizer_path": "your tokenizer path",
"device": "cuda:1"
},
"aqa_processor": {
"use": False,
"model_path": "your model path",
"tokenizer_path": "yout tokenizer path",
"device": "cuda:2"
},
"text2images": {
"use": False,
},
}

View File

@ -83,7 +83,7 @@ PROMPT_TEMPLATES = {
'Question: {input}\n\n'
'{agent_scratchpad}\n',
"qwen":
'Answer the following question as best you can. You have access to the following APIs:\n\n'
'Answer the following questions as best you can. You have access to the following APIs:\n\n'
'{tools}\n\n'
'Use the following format:\n\n'
'Question: the input question you must answer\n'
@ -122,88 +122,3 @@ PROMPT_TEMPLATES = {
"default": "{{input}}",
}
}
TOOL_CONFIG = {
"search_local_knowledgebase": {
"use": False,
"top_k": 3,
"score_threshold": 1,
"conclude_prompt": {
"with_result":
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题"'
'不允许在答案中添加编造成分,答案请使用中文。 </指令>\n'
'<已知信息>{{ context }}</已知信息>\n'
'<问题>{{ question }}</问题>\n',
"without_result":
'请你根据我的提问回答我的问题:\n'
'{{ question }}\n'
'请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n',
}
},
"search_internet": {
"use": False,
"search_engine_name": "bing",
"search_engine_config":
{
"bing": {
"result_len": 3,
"bing_search_url": "https://api.bing.microsoft.com/v7.0/search",
"bing_key": "",
},
"metaphor": {
"result_len": 3,
"metaphor_api_key": "",
"split_result": False,
"chunk_size": 500,
"chunk_overlap": 0,
},
"duckduckgo": {
"result_len": 3
}
},
"top_k": 10,
"verbose": "Origin",
"conclude_prompt":
"<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 "
"</指令>\n<已知信息>{{ context }}</已知信息>\n"
"<问题>\n"
"{{ question }}\n"
"</问题>\n"
},
"arxiv": {
"use": False,
},
"shell": {
"use": False,
},
"weather_check": {
"use": False,
"api-key": "S8vrB4U_-c5mvAMiK",
},
"search_youtube": {
"use": False,
},
"wolfram": {
"use": False,
},
"calculate": {
"use": False,
},
"vqa_processor": {
"use": False,
"model_path": "your model path",
"tokenizer_path": "your tokenizer path",
"device": "cuda:1"
},
"aqa_processor": {
"use": False,
"model_path": "your model path",
"tokenizer_path": "yout tokenizer path",
"device": "cuda:2"
},
"text2images": {
"use": False,
},
}

View File

@ -23,4 +23,3 @@ API_SERVER = {
"host": DEFAULT_BIND_HOST,
"port": 7861,
}

View File

@ -1,4 +1,5 @@
from chatchat.configs import TOOL_CONFIG, logger
from chatchat.configs import logger
from chatchat.server.utils import get_tool_config
class ModelContainer:
@ -11,36 +12,38 @@ class ModelContainer:
self.audio_tokenizer = None
self.audio_model = None
if TOOL_CONFIG["vqa_processor"]["use"]:
vqa_config = get_tool_config("vqa_processor")
if vqa_config["use"]:
try:
from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
import torch
self.vision_tokenizer = LlamaTokenizer.from_pretrained(
TOOL_CONFIG["vqa_processor"]["tokenizer_path"],
vqa_config["tokenizer_path"],
trust_remote_code=True)
self.vision_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=TOOL_CONFIG["vqa_processor"]["model_path"],
pretrained_model_name_or_path=vqa_config["model_path"],
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True
).to(TOOL_CONFIG["vqa_processor"]["device"]).eval()
).to(vqa_config["device"]).eval()
except Exception as e:
logger.error(e, exc_info=True)
if TOOL_CONFIG["aqa_processor"]["use"]:
aqa_config = get_tool_config("vqa_processor")
if aqa_config["use"]:
try:
from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
import torch
self.audio_tokenizer = AutoTokenizer.from_pretrained(
TOOL_CONFIG["aqa_processor"]["tokenizer_path"],
aqa_config["tokenizer_path"],
trust_remote_code=True
)
self.audio_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=TOOL_CONFIG["aqa_processor"]["model_path"],
pretrained_model_name_or_path=aqa_config["model_path"],
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True).to(
TOOL_CONFIG["aqa_processor"]["device"]
aqa_config["device"]
).eval()
except Exception as e:
logger.error(e, exc_info=True)

View File

@ -1,12 +1,12 @@
from .search_local_knowledgebase import search_local_knowledgebase, SearchKnowledgeInput
from .search_local_knowledgebase import search_local_knowledgebase
from .calculate import calculate
from .weather_check import weather_check, WeatherInput
from .shell import shell, ShellInput
from .search_internet import search_internet, SearchInternetInput
from .wolfram import wolfram, WolframInput
from .search_youtube import search_youtube, YoutubeInput
from .arxiv import arxiv, ArxivInput
from .weather_check import weather_check
from .shell import shell
from .search_internet import search_internet
from .wolfram import wolfram
from .search_youtube import search_youtube
from .arxiv import arxiv
from .text2image import text2images
from .vision_factory import *
from .audio_factory import *
from .vqa_processor import vqa_processor
from .aqa_processor import aqa_processor

View File

@ -1,6 +1,8 @@
import base64
import os
from chatchat.server.pydantic_v1 import BaseModel, Field
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import regist_tool
def save_base64_audio(base64_audio, file_path):
audio_data = base64.b64decode(base64_audio)
@ -14,7 +16,10 @@ def aqa_run(model, tokenizer, query):
return response
def aqa_processor(query: str):
@regist_tool
def aqa_processor(query: str = Field(description="The question of the image in English")):
'''use this tool to get answer for audio question'''
from chatchat.server.agent.container import container
if container.metadata["audios"]:
file_path = "temp_audio.mp3"
@ -26,6 +31,3 @@ def aqa_processor(query: str):
return aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model)
else:
return "No Audio, Please Try Again"
class AQAInput(BaseModel):
query: str = Field(description="The question of the image in English")

View File

@ -1,11 +1,12 @@
# LangChain 的 ArxivQueryRun 工具
from chatchat.server.pydantic_v1 import BaseModel, Field
from langchain.tools.arxiv.tool import ArxivQueryRun
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
def arxiv(query: str):
@regist_tool
def arxiv(query: str = Field(description="The search query title")):
'''A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.'''
from langchain.tools.arxiv.tool import ArxivQueryRun
tool = ArxivQueryRun()
return tool.run(tool_input=query)
class ArxivInput(BaseModel):
query: str = Field(description="The search query title")

View File

@ -1 +0,0 @@
from .aqa import aqa_processor, AQAInput

View File

@ -1,8 +1,9 @@
from langchain.agents import tool
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
@tool
def calculate(text: str) -> float:
@regist_tool
def calculate(text: str = Field(description="a math expression")) -> float:
'''
Useful to answer questions about simple calculations.
translate user question to a math expression that can be evaluated by numexpr.

View File

@ -1,12 +1,15 @@
from chatchat.server.pydantic_v1 import BaseModel, Field
from typing import List, Dict
from langchain.utilities.bing_search import BingSearchAPIWrapper
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from chatchat.configs import TOOL_CONFIG
from langchain.text_splitter import RecursiveCharacterTextSplitter
from typing import List, Dict
from langchain.docstore.document import Document
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from markdownify import markdownify
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from chatchat.server.utils import get_tool_config
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
def bing_search(text, config):
@ -90,10 +93,10 @@ def search_engine(query: str,
for doc in docs:
context += doc + "\n"
return context
def search_internet(query: str):
tool_config = TOOL_CONFIG["search_internet"]
@regist_tool
def search_internet(query: str = Field(description="query for Internet search")):
'''Use this tool to use bing search engine to search the internet and get information.'''
tool_config = get_tool_config("search_internet")
return search_engine(query=query, config=tool_config)
class SearchInternetInput(BaseModel):
query: str = Field(description="query for Internet search")

View File

@ -1,8 +1,14 @@
from urllib.parse import urlencode
from chatchat.server.pydantic_v1 import BaseModel, Field
from chatchat.server.utils import get_tool_config
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
from chatchat.server.knowledge_base.kb_doc_api import search_docs
from chatchat.configs import TOOL_CONFIG
from chatchat.configs import KB_INFO
template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get informationOnly local data on this knowledge use this tool. The 'database' should be one of the above [{key}]."
KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
template_knowledge = template.format(KB_info=KB_info_str, key="samples")
def search_knowledgebase(query: str, database: str, config: dict):
@ -29,11 +35,11 @@ def search_knowledgebase(query: str, database: str, config: dict):
return context
class SearchKnowledgeInput(BaseModel):
database: str = Field(description="Database for Knowledge Search")
query: str = Field(description="Query for Knowledge Search")
def search_local_knowledgebase(database: str, query: str):
tool_config = TOOL_CONFIG["search_local_knowledgebase"]
@regist_tool(description=template_knowledge)
def search_local_knowledgebase(
database: str = Field(description="Database for Knowledge Search"),
query: str = Field(description="Query for Knowledge Search"),
):
''''''
tool_config = get_tool_config("search_local_knowledgebase")
return search_knowledgebase(query=query, database=database, config=tool_config)

View File

@ -1,10 +1,10 @@
from langchain_community.tools import YouTubeSearchTool
from chatchat.server.pydantic_v1 import BaseModel, Field
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
def search_youtube(query: str):
@regist_tool
def search_youtube(query: str = Field(description="Query for Videos search")):
'''use this tools_factory to search youtube videos'''
from langchain_community.tools import YouTubeSearchTool
tool = YouTubeSearchTool()
return tool.run(tool_input=query)
class YoutubeInput(BaseModel):
query: str = Field(description="Query for Videos search")

View File

@ -1,11 +1,12 @@
# LangChain 的 Shell 工具
from chatchat.server.pydantic_v1 import BaseModel, Field
from langchain_community.tools import ShellTool
from langchain.tools.shell import ShellTool
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool
def shell(query: str):
@regist_tool
def shell(query: str = Field(description="The command to execute")):
'''Use Shell to execute system shell commands'''
tool = ShellTool()
return tool.run(tool_input=query)
class ShellInput(BaseModel):
query: str = Field(description="The command to execute")

View File

@ -5,8 +5,9 @@ from PIL import Image
from typing import List
import uuid
from langchain.agents import tool
from chatchat.server.pydantic_v1 import Field, FieldInfo
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import regist_tool
import openai
from chatchat.configs.basic_config import MEDIA_PATH
@ -25,7 +26,7 @@ def get_image_model_config() -> dict:
return config
@tool(return_direct=True)
@regist_tool(return_direct=True)
def text2images(
prompt: str,
n: int = Field(1, description="需生成图片的数量"),
@ -33,13 +34,6 @@ def text2images(
height: int = Field(512, description="生成图片的高度"),
) -> List[str]:
'''根据用户的描述生成图片'''
# workaround before langchain uprading
if isinstance(n, FieldInfo):
n = n.default
if isinstance(width, FieldInfo):
width = width.default
if isinstance(height, FieldInfo):
height = height.default
model_config = get_image_model_config()
assert model_config is not None, "请正确配置文生图模型"

View File

@ -1,69 +1,103 @@
from langchain_core.tools import StructuredTool
from chatchat.server.agent.tools_factory import *
from chatchat.configs import KB_INFO
import re
from typing import Any, Union, Dict, Tuple, Callable, Optional, Type
template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get informationOnly local data on this knowledge use this tool. The 'database' should be one of the above [{key}]."
KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
template_knowledge = template.format(KB_info=KB_info_str, key="samples")
from langchain.agents import tool
from langchain_core.tools import BaseTool
all_tools = [
calculate,
StructuredTool.from_function(
func=arxiv,
name="arxiv",
description="A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.",
args_schema=ArxivInput,
),
StructuredTool.from_function(
func=shell,
name="shell",
description="Use Shell to execute Linux commands",
args_schema=ShellInput,
),
StructuredTool.from_function(
func=wolfram,
name="wolfram",
description="Useful for when you need to calculate difficult formulas",
args_schema=WolframInput,
from chatchat.server.pydantic_v1 import BaseModel
),
StructuredTool.from_function(
func=search_youtube,
name="search_youtube",
description="use this tools_factory to search youtube videos",
args_schema=YoutubeInput,
),
StructuredTool.from_function(
func=weather_check,
name="weather_check",
description="Use this tool to check the weather at a specific location",
args_schema=WeatherInput,
),
StructuredTool.from_function(
func=search_internet,
name="search_internet",
description="Use this tool to use bing search engine to search the internet and get information",
args_schema=SearchInternetInput,
),
StructuredTool.from_function(
func=search_local_knowledgebase,
name="search_local_knowledgebase",
description=template_knowledge,
args_schema=SearchKnowledgeInput,
),
StructuredTool.from_function(
func=vqa_processor,
name="vqa_processor",
description="use this tool to get answer for image question",
args_schema=VQAInput,
),
StructuredTool.from_function(
func=aqa_processor,
name="aqa_processor",
description="use this tool to get answer for audio question",
args_schema=AQAInput,
)
]
_TOOLS_REGISTRY = {}
all_tools.append(text2images)
################################### TODO: workaround to langchain #15855
# patch BaseTool to support tool parameters defined using pydantic Field
def _new_parse_input(
self,
tool_input: Union[str, Dict],
) -> Union[str, Dict[str, Any]]:
"""Convert tool input to pydantic model."""
input_args = self.args_schema
if isinstance(tool_input, str):
if input_args is not None:
key_ = next(iter(input_args.__fields__.keys()))
input_args.validate({key_: tool_input})
return tool_input
else:
if input_args is not None:
result = input_args.parse_obj(tool_input)
return result.dict()
def _new_to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
# For backwards compatibility, if run_input is a string,
# pass as a positional argument.
if isinstance(tool_input, str):
return (tool_input,), {}
else:
# for tool defined with `*args` parameters
# the args_schema has a field named `args`
# it should be expanded to actual *args
# e.g.: test_tools
# .test_named_tool_decorator_return_direct
# .search_api
if "args" in tool_input:
args = tool_input["args"]
if args is None:
tool_input.pop("args")
return (), tool_input
elif isinstance(args, tuple):
tool_input.pop("args")
return args, tool_input
return (), tool_input
BaseTool._parse_input = _new_parse_input
BaseTool._to_args_and_kwargs = _new_to_args_and_kwargs
###############################
def regist_tool(
*args: Any,
description: str = "",
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True,
) -> Union[Callable, BaseTool]:
'''
wrapper of langchain tool decorator
add tool to regstiry automatically
'''
def _parse_tool(t: BaseTool):
nonlocal description
_TOOLS_REGISTRY[t.name] = t
# change default description
if not description:
if t.func is not None:
description = t.func.__doc__
elif t.coroutine is not None:
description = t.coroutine.__doc__
t.description = " ".join(re.split(r"\n+\s*", description))
def wrapper(def_func: Callable) -> BaseTool:
partial_ = tool(*args,
return_direct=return_direct,
args_schema=args_schema,
infer_schema=infer_schema,
)
t = partial_(def_func)
_parse_tool(t)
return t
if len(args) == 0:
return wrapper
else:
t = tool(*args,
return_direct=return_direct,
args_schema=args_schema,
infer_schema=infer_schema,
)
_parse_tool(t)
return t

View File

@ -1 +0,0 @@
from .vqa import vqa_processor,VQAInput

View File

@ -4,8 +4,9 @@ Method Use cogagent to generate response for a given image and query.
import base64
from io import BytesIO
from PIL import Image, ImageDraw
from chatchat.server.pydantic_v1 import BaseModel, Field
from chatchat.configs import TOOL_CONFIG
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import regist_tool
import re
from chatchat.server.agent.container import container
@ -98,8 +99,11 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m
return response
def vqa_processor(query: str):
tool_config = TOOL_CONFIG["vqa_processor"]
@regist_tool
def vqa_processor(query: str = Field(description="The question of the image in English")):
'''use this tool to get answer for image question'''
tool_config = get_tool_config("vqa_processor")
if container.metadata["images"]:
image_base64 = container.metadata["images"][0]
ans = vqa_run(model=container.vision_model,
@ -118,7 +122,3 @@ def vqa_processor(query: str):
return ans
else:
return "No Image, Please Try Again"
class VQAInput(BaseModel):
query: str = Field(description="The question of the image in English")

View File

@ -1,11 +1,19 @@
"""
简单的单参数输入工具实现用于查询现在天气的情况
"""
from chatchat.server.pydantic_v1 import BaseModel, Field
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import regist_tool
import requests
def weather(location: str, api_key: str):
url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
@regist_tool
def weather_check(city: str = Field(description="City name,include city and county,like '厦门'")):
'''Use this tool to check the weather at a specific city'''
tool_config = get_tool_config("weather_check")
api_key = tool_config.get("api_key")
url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={city}&language=zh-Hans&unit=c"
response = requests.get(url)
if response.status_code == 200:
data = response.json()
@ -17,9 +25,3 @@ def weather(location: str, api_key: str):
else:
raise Exception(
f"Failed to retrieve weather: {response.status_code}")
def weather_check(location: str):
return weather(location, "S8vrB4U_-c5mvAMiK")
class WeatherInput(BaseModel):
location: str = Field(description="City name,include city and county,like '厦门'")

View File

@ -1,13 +1,16 @@
# Langchain 自带的 Wolfram Alpha API 封装
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
from chatchat.server.pydantic_v1 import BaseModel, Field
wolfram_alpha_appid = "your key"
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import regist_tool
def wolfram(query: str):
wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=wolfram_alpha_appid)
@regist_tool
def wolfram(query: str = Field(description="The formula to be calculated")):
'''Useful for when you need to calculate difficult formulas'''
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=get_tool_config("wolfram").get("appid"))
ans = wolfram.run(query)
return ans
class WolframInput(BaseModel):
formula: str = Field(description="The formula to be calculated")

View File

@ -5,7 +5,7 @@ from typing import List
from fastapi import APIRouter, Request, Body
from chatchat.configs import logger
from chatchat.server.utils import BaseResponse
from chatchat.server.utils import BaseResponse, get_tool
tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
@ -13,11 +13,8 @@ tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
@tool_router.get("/", response_model=BaseResponse)
async def list_tools():
import importlib
from chatchat.server.agent.tools_factory import tools_registry
importlib.reload(tools_registry)
data = {t.name: {"name": t.name, "description": t.description, "args": t.args} for t in tools_registry.all_tools}
tools = get_tool()
data = {t.name: {"name": t.name, "description": t.description, "args": t.args} for t in tools}
return {"data": data}
@ -26,12 +23,9 @@ async def call_tool(
name: str = Body(examples=["calculate"]),
kwargs: dict = Body({}, examples=[{"a":1,"b":2,"operator":"+"}]),
):
import importlib
from chatchat.server.agent.tools_factory import tools_registry
importlib.reload(tools_registry)
tools = get_tool()
tool_names = {t.name: t for t in tools_registry.all_tools}
if tool := tool_names.get(name):
if tool := tools.get(name):
try:
result = await tool.ainvoke(kwargs)
return {"data": result}

View File

@ -11,10 +11,9 @@ from langchain.chains import LLMChain
from langchain.prompts.chat import ChatPromptTemplate
from langchain.prompts import PromptTemplate
from chatchat.server.agent.agent_factory.agents_registry import agents_registry
from chatchat.server.agent.tools_factory.tools_registry import all_tools
from chatchat.server.agent.container import container
from chatchat.server.utils import wrap_done, get_ChatOpenAI, get_prompt_template, MsgType
from chatchat.server.utils import wrap_done, get_ChatOpenAI, get_prompt_template, MsgType, get_tool
from chatchat.server.chat.utils import History
from chatchat.server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from chatchat.server.db.repository import add_message_to_db
@ -127,6 +126,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
callbacks = [callback]
models, prompts = create_models_from_config(callbacks=callbacks, configs=chat_model_config,
stream=stream)
all_tools = get_tool().values()
tools = [tool for tool in all_tools if tool.name in tool_config]
tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
full_chain = create_models_chains(prompts=prompts,

View File

@ -7,6 +7,7 @@ import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from langchain_core.embeddings import Embeddings
from langchain.tools import BaseTool
from langchain_openai.chat_models import ChatOpenAI
from langchain_openai.llms import OpenAI
import httpx
@ -277,7 +278,7 @@ class BaseResponse(BaseModel):
data: Any = Field(None, description="API data")
class Config:
schema_extra = {
json_schema_extra = {
"example": {
"code": 200,
"msg": "success",
@ -289,7 +290,7 @@ class ListResponse(BaseResponse):
data: List[str] = Field(..., description="List of names")
class Config:
schema_extra = {
json_schema_extra = {
"example": {
"code": 200,
"msg": "success",
@ -307,7 +308,7 @@ class ChatMessage(BaseModel):
)
class Config:
schema_extra = {
json_schema_extra = {
"example": {
"question": "工伤保险如何办理?",
"response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n"
@ -696,3 +697,25 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]:
path = os.path.join(BASE_TEMP_DIR, id)
os.mkdir(path)
return path, id
def get_tool(name: str = None) -> Union[BaseTool, Dict[str, BaseTool]]:
import importlib
from chatchat.server.agent import tools_factory
importlib.reload(tools_factory)
if name is None:
return tools_factory.tools_registry._TOOLS_REGISTRY
else:
return tools_factory.tools_registry._TOOLS_REGISTRY.get(name)
def get_tool_config(name: str = None) -> Dict:
import importlib
from chatchat.configs import model_config
importlib.reload(model_config)
if name is None:
return model_config.TOOL_CONFIG
else:
return model_config.TOOL_CONFIG.get(name, {})

View File

@ -1,5 +1,6 @@
import base64
from chatchat.server.utils import get_tool_config
import streamlit as st
from streamlit_antd_components.utils import ParseItems
@ -13,7 +14,7 @@ from datetime import datetime
import os
import re
import time
from chatchat.configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS, TOOL_CONFIG)
from chatchat.configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS)
from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
from chatchat.server.utils import MsgType, get_config_models
import uuid
@ -157,12 +158,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
import importlib
importlib.reload(model_config_py)
tools = list(TOOL_CONFIG.keys())
tools = get_tool_config()
with st.expander("工具栏"):
for tool in tools:
is_selected = st.checkbox(tool, value=TOOL_CONFIG[tool]["use"], key=tool)
is_selected = st.checkbox(tool, value=tools[tool]["use"], key=tool)
if is_selected:
selected_tool_configs[tool] = TOOL_CONFIG[tool]
selected_tool_configs[tool] = tools[tool]
if llm_model is not None:
chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})