- Rewrite the tool section (#3553)

- Simplify how tools are defined
    - All tools and tool_config support hot reloading
    - Fix the json_schema_extra warning
liunux4odoo 2024-03-28 13:08:51 +08:00 committed by GitHub
parent f9f9d4b9fb
commit 9818bd2a88
25 changed files with 348 additions and 280 deletions

View File

@@ -12,12 +12,12 @@ langchain.verbose = False
 # 通常情况下不需要更改以下内容
 # 用户数据根目录
-DATA_PATH = (Path(__file__).absolute().parent.parent)  # / "data")
+DATA_PATH = str(Path(__file__).absolute().parent.parent / "data")
 if not os.path.exists(DATA_PATH):
     os.mkdir(DATA_PATH)
 # nltk 模型存储路径
-NLTK_DATA_PATH = os.path.join(DATA_PATH, "data/nltk_data")
+NLTK_DATA_PATH = os.path.join(DATA_PATH, "nltk_data")
 import nltk
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@@ -29,12 +29,12 @@ logging.basicConfig(format=LOG_FORMAT)
 # 日志存储路径
-LOG_PATH = os.path.join(DATA_PATH, "data/logs")
+LOG_PATH = os.path.join(DATA_PATH, "logs")
 if not os.path.exists(LOG_PATH):
     os.mkdir(LOG_PATH)
 # 模型生成内容(图片、视频、音频等)保存位置
-MEDIA_PATH = os.path.join(DATA_PATH, "data/media")
+MEDIA_PATH = os.path.join(DATA_PATH, "media")
 if not os.path.exists(MEDIA_PATH):
     os.mkdir(MEDIA_PATH)
     os.mkdir(os.path.join(MEDIA_PATH, "image"))
@@ -42,6 +42,6 @@ if not os.path.exists(MEDIA_PATH):
     os.mkdir(os.path.join(MEDIA_PATH, "video"))
 # 临时文件目录,主要用于文件对话
-BASE_TEMP_DIR = os.path.join(DATA_PATH, "data/temp")
+BASE_TEMP_DIR = os.path.join(DATA_PATH, "temp")
 if not os.path.exists(BASE_TEMP_DIR):
     os.mkdir(BASE_TEMP_DIR)

View File

@@ -1,6 +1,6 @@
 import os
-from chatchat.configs.basic_config import DATA_PATH
+from .basic_config import DATA_PATH
 
 # 默认使用的知识库
@@ -51,7 +51,7 @@ KB_INFO = {
 # 通常情况下不需要更改以下内容
 # 知识库默认存储路径
-KB_ROOT_PATH = os.path.join(DATA_PATH, "data/knowledge_base")
+KB_ROOT_PATH = os.path.join(DATA_PATH, "knowledge_base")
 if not os.path.exists(KB_ROOT_PATH):
     os.mkdir(KB_ROOT_PATH)

View File

@@ -151,6 +151,7 @@ MODEL_PLATFORMS = [
         "image_models": [],
         "multimodal_models": [],
     },
+
     {
         "platform_name": "ollama",
         "platform_type": "ollama",
@@ -168,6 +169,7 @@ MODEL_PLATFORMS = [
         "image_models": [],
         "multimodal_models": [],
     },
+
     # {
     #     "platform_name": "loom",
     #     "platform_type": "loom",
@@ -184,3 +186,89 @@ MODEL_PLATFORMS = [
 ]
 
 LOOM_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "loom.yaml")
+
+# 工具配置项
+TOOL_CONFIG = {
+    "search_local_knowledgebase": {
+        "use": False,
+        "top_k": 3,
+        "score_threshold": 1,
+        "conclude_prompt": {
+            "with_result":
+                '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题"'
+                '不允许在答案中添加编造成分,答案请使用中文。 </指令>\n'
+                '<已知信息>{{ context }}</已知信息>\n'
+                '<问题>{{ question }}</问题>\n',
+            "without_result":
+                '请你根据我的提问回答我的问题:\n'
+                '{{ question }}\n'
+                '请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n',
+        }
+    },
+    "search_internet": {
+        "use": False,
+        "search_engine_name": "bing",
+        "search_engine_config": {
+            "bing": {
+                "result_len": 3,
+                "bing_search_url": "https://api.bing.microsoft.com/v7.0/search",
+                "bing_key": "",
+            },
+            "metaphor": {
+                "result_len": 3,
+                "metaphor_api_key": "",
+                "split_result": False,
+                "chunk_size": 500,
+                "chunk_overlap": 0,
+            },
+            "duckduckgo": {
+                "result_len": 3
+            }
+        },
+        "top_k": 10,
+        "verbose": "Origin",
+        "conclude_prompt":
+            "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 "
+            "</指令>\n<已知信息>{{ context }}</已知信息>\n"
+            "<问题>\n"
+            "{{ question }}\n"
+            "</问题>\n"
+    },
+    "arxiv": {
+        "use": False,
+    },
+    "shell": {
+        "use": False,
+    },
+    "weather_check": {
+        "use": False,
+        "api_key": "S8vrB4U_-c5mvAMiK",
+    },
+    "search_youtube": {
+        "use": False,
+    },
+    "wolfram": {
+        "use": False,
+        "appid": "",
+    },
+    "calculate": {
+        "use": False,
+    },
+    "vqa_processor": {
+        "use": False,
+        "model_path": "your model path",
+        "tokenizer_path": "your tokenizer path",
+        "device": "cuda:1"
+    },
+    "aqa_processor": {
+        "use": False,
+        "model_path": "your model path",
+        "tokenizer_path": "your tokenizer path",
+        "device": "cuda:2"
+    },
+    "text2images": {
+        "use": False,
+    },
+}
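Each TOOL_CONFIG entry is keyed by a tool's registered name and carries at least a "use" switch plus tool-specific settings. A minimal usage sketch (hypothetical, not part of the diff; it assumes the get_tool_config helper added to server/utils.py further down):

from chatchat.server.utils import get_tool_config

conf = get_tool_config("search_internet")           # re-reads model_config, so edits apply immediately
engine = conf["search_engine_name"]                 # "bing" by default
engine_conf = conf["search_engine_config"][engine]  # per-engine settings such as bing_key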

View File

@@ -83,7 +83,7 @@ PROMPT_TEMPLATES = {
         'Question: {input}\n\n'
         '{agent_scratchpad}\n',
     "qwen":
-        'Answer the following question as best you can. You have access to the following APIs:\n\n'
+        'Answer the following questions as best you can. You have access to the following APIs:\n\n'
         '{tools}\n\n'
         'Use the following format:\n\n'
         'Question: the input question you must answer\n'
@@ -122,88 +122,3 @@ PROMPT_TEMPLATES = {
         "default": "{{input}}",
     }
 }
-
-TOOL_CONFIG = {
-    "search_local_knowledgebase": {
-        "use": False,
-        "top_k": 3,
-        "score_threshold": 1,
-        "conclude_prompt": {
-            "with_result":
-                '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题"'
-                '不允许在答案中添加编造成分,答案请使用中文。 </指令>\n'
-                '<已知信息>{{ context }}</已知信息>\n'
-                '<问题>{{ question }}</问题>\n',
-            "without_result":
-                '请你根据我的提问回答我的问题:\n'
-                '{{ question }}\n'
-                '请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n',
-        }
-    },
-    "search_internet": {
-        "use": False,
-        "search_engine_name": "bing",
-        "search_engine_config": {
-            "bing": {
-                "result_len": 3,
-                "bing_search_url": "https://api.bing.microsoft.com/v7.0/search",
-                "bing_key": "",
-            },
-            "metaphor": {
-                "result_len": 3,
-                "metaphor_api_key": "",
-                "split_result": False,
-                "chunk_size": 500,
-                "chunk_overlap": 0,
-            },
-            "duckduckgo": {
-                "result_len": 3
-            }
-        },
-        "top_k": 10,
-        "verbose": "Origin",
-        "conclude_prompt":
-            "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 "
-            "</指令>\n<已知信息>{{ context }}</已知信息>\n"
-            "<问题>\n"
-            "{{ question }}\n"
-            "</问题>\n"
-    },
-    "arxiv": {
-        "use": False,
-    },
-    "shell": {
-        "use": False,
-    },
-    "weather_check": {
-        "use": False,
-        "api-key": "S8vrB4U_-c5mvAMiK",
-    },
-    "search_youtube": {
-        "use": False,
-    },
-    "wolfram": {
-        "use": False,
-    },
-    "calculate": {
-        "use": False,
-    },
-    "vqa_processor": {
-        "use": False,
-        "model_path": "your model path",
-        "tokenizer_path": "your tokenizer path",
-        "device": "cuda:1"
-    },
-    "aqa_processor": {
-        "use": False,
-        "model_path": "your model path",
-        "tokenizer_path": "yout tokenizer path",
-        "device": "cuda:2"
-    },
-    "text2images": {
-        "use": False,
-    },
-}

View File

@@ -23,4 +23,3 @@ API_SERVER = {
     "host": DEFAULT_BIND_HOST,
     "port": 7861,
 }
-

View File

@@ -1,4 +1,5 @@
-from chatchat.configs import TOOL_CONFIG, logger
+from chatchat.configs import logger
+from chatchat.server.utils import get_tool_config
 
 
 class ModelContainer:
@@ -11,36 +12,38 @@ class ModelContainer:
         self.audio_tokenizer = None
         self.audio_model = None
-        if TOOL_CONFIG["vqa_processor"]["use"]:
+        vqa_config = get_tool_config("vqa_processor")
+        if vqa_config["use"]:
             try:
                 from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
                 import torch
                 self.vision_tokenizer = LlamaTokenizer.from_pretrained(
-                    TOOL_CONFIG["vqa_processor"]["tokenizer_path"],
+                    vqa_config["tokenizer_path"],
                     trust_remote_code=True)
                 self.vision_model = AutoModelForCausalLM.from_pretrained(
-                    pretrained_model_name_or_path=TOOL_CONFIG["vqa_processor"]["model_path"],
+                    pretrained_model_name_or_path=vqa_config["model_path"],
                     torch_dtype=torch.bfloat16,
                     low_cpu_mem_usage=True,
                     trust_remote_code=True
-                ).to(TOOL_CONFIG["vqa_processor"]["device"]).eval()
+                ).to(vqa_config["device"]).eval()
             except Exception as e:
                 logger.error(e, exc_info=True)
-        if TOOL_CONFIG["aqa_processor"]["use"]:
+        aqa_config = get_tool_config("aqa_processor")
+        if aqa_config["use"]:
             try:
                 from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
                 import torch
                 self.audio_tokenizer = AutoTokenizer.from_pretrained(
-                    TOOL_CONFIG["aqa_processor"]["tokenizer_path"],
+                    aqa_config["tokenizer_path"],
                     trust_remote_code=True
                 )
                 self.audio_model = AutoModelForCausalLM.from_pretrained(
-                    pretrained_model_name_or_path=TOOL_CONFIG["aqa_processor"]["model_path"],
+                    pretrained_model_name_or_path=aqa_config["model_path"],
                     torch_dtype=torch.bfloat16,
                     low_cpu_mem_usage=True,
                     trust_remote_code=True).to(
-                        TOOL_CONFIG["aqa_processor"]["device"]
+                        aqa_config["device"]
                     ).eval()
             except Exception as e:
                 logger.error(e, exc_info=True)

View File

@@ -1,12 +1,12 @@
-from .search_local_knowledgebase import search_local_knowledgebase, SearchKnowledgeInput
+from .search_local_knowledgebase import search_local_knowledgebase
 from .calculate import calculate
-from .weather_check import weather_check, WeatherInput
-from .shell import shell, ShellInput
-from .search_internet import search_internet, SearchInternetInput
-from .wolfram import wolfram, WolframInput
-from .search_youtube import search_youtube, YoutubeInput
-from .arxiv import arxiv, ArxivInput
+from .weather_check import weather_check
+from .shell import shell
+from .search_internet import search_internet
+from .wolfram import wolfram
+from .search_youtube import search_youtube
+from .arxiv import arxiv
 from .text2image import text2images
-from .vision_factory import *
-from .audio_factory import *
+from .vqa_processor import vqa_processor
+from .aqa_processor import aqa_processor

View File

@@ -1,6 +1,8 @@
 import base64
-import os
-from chatchat.server.pydantic_v1 import BaseModel, Field
+from chatchat.server.pydantic_v1 import Field
+from chatchat.server.utils import get_tool_config
+from .tools_registry import regist_tool
 
 
 def save_base64_audio(base64_audio, file_path):
     audio_data = base64.b64decode(base64_audio)
@@ -14,7 +16,10 @@ def aqa_run(model, tokenizer, query):
     return response
 
 
-def aqa_processor(query: str):
+@regist_tool
+def aqa_processor(query: str = Field(description="The question of the audio in English")):
+    '''use this tool to get answer for audio question'''
     from chatchat.server.agent.container import container
     if container.metadata["audios"]:
         file_path = "temp_audio.mp3"
@@ -26,6 +31,3 @@ def aqa_processor(query: str):
         return aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model)
     else:
         return "No Audio, Please Try Again"
-
-
-class AQAInput(BaseModel):
-    query: str = Field(description="The question of the image in English")

View File

@@ -1,11 +1,12 @@
 # LangChain 的 ArxivQueryRun 工具
-from chatchat.server.pydantic_v1 import BaseModel, Field
-from langchain.tools.arxiv.tool import ArxivQueryRun
+from chatchat.server.pydantic_v1 import Field
+from .tools_registry import regist_tool
 
 
-def arxiv(query: str):
+@regist_tool
+def arxiv(query: str = Field(description="The search query title")):
+    '''A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.'''
+    from langchain.tools.arxiv.tool import ArxivQueryRun
     tool = ArxivQueryRun()
     return tool.run(tool_input=query)
-
-
-class ArxivInput(BaseModel):
-    query: str = Field(description="The search query title")

View File

@@ -1 +0,0 @@
-from .aqa import aqa_processor, AQAInput

View File

@@ -1,8 +1,9 @@
-from langchain.agents import tool
+from chatchat.server.pydantic_v1 import Field
+from .tools_registry import regist_tool
 
 
-@tool
-def calculate(text: str) -> float:
+@regist_tool
+def calculate(text: str = Field(description="a math expression")) -> float:
     '''
     Useful to answer questions about simple calculations.
     translate user question to a math expression that can be evaluated by numexpr.

View File

@@ -1,12 +1,15 @@
-from chatchat.server.pydantic_v1 import BaseModel, Field
+from typing import List, Dict
+
 from langchain.utilities.bing_search import BingSearchAPIWrapper
 from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
-from chatchat.configs import TOOL_CONFIG
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from typing import List, Dict
 from langchain.docstore.document import Document
-from strsimpy.normalized_levenshtein import NormalizedLevenshtein
 from markdownify import markdownify
+from strsimpy.normalized_levenshtein import NormalizedLevenshtein
+
+from chatchat.server.utils import get_tool_config
+from chatchat.server.pydantic_v1 import Field
+from .tools_registry import regist_tool
 
 
 def bing_search(text, config):
@@ -90,10 +93,10 @@ def search_engine(query: str,
     for doc in docs:
         context += doc + "\n"
     return context
 
-def search_internet(query: str):
-    tool_config = TOOL_CONFIG["search_internet"]
+
+@regist_tool
+def search_internet(query: str = Field(description="query for Internet search")):
+    '''Use this tool to use bing search engine to search the internet and get information.'''
+    tool_config = get_tool_config("search_internet")
     return search_engine(query=query, config=tool_config)
-
-
-class SearchInternetInput(BaseModel):
-    query: str = Field(description="query for Internet search")

View File

@@ -1,8 +1,14 @@
 from urllib.parse import urlencode
-from chatchat.server.pydantic_v1 import BaseModel, Field
+from chatchat.server.utils import get_tool_config
+from chatchat.server.pydantic_v1 import Field
+from .tools_registry import regist_tool
 from chatchat.server.knowledge_base.kb_doc_api import search_docs
-from chatchat.configs import TOOL_CONFIG
+from chatchat.configs import KB_INFO
+
+template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information. Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]."
+KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
+template_knowledge = template.format(KB_info=KB_info_str, key="samples")
 
 
 def search_knowledgebase(query: str, database: str, config: dict):
@@ -29,11 +35,11 @@ def search_knowledgebase(query: str, database: str, config: dict):
     return context
 
 
-class SearchKnowledgeInput(BaseModel):
-    database: str = Field(description="Database for Knowledge Search")
-    query: str = Field(description="Query for Knowledge Search")
-
-
-def search_local_knowledgebase(database: str, query: str):
-    tool_config = TOOL_CONFIG["search_local_knowledgebase"]
+@regist_tool(description=template_knowledge)
+def search_local_knowledgebase(
+        database: str = Field(description="Database for Knowledge Search"),
+        query: str = Field(description="Query for Knowledge Search"),
+):
+    ''''''
+    tool_config = get_tool_config("search_local_knowledgebase")
     return search_knowledgebase(query=query, database=database, config=tool_config)
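This file shows the description-override path of the new decorator: the docstring is left empty and the description is passed to regist_tool because it is built dynamically from KB_INFO. A hypothetical sketch of the same pattern, assuming the tools_registry module introduced later in this commit:

from chatchat.server.pydantic_v1 import Field
from chatchat.server.agent.tools_factory.tools_registry import regist_tool

@regist_tool(description="Search the 'samples' knowledge base.")
def my_kb_search(query: str = Field(description="Query for Knowledge Search")):
    ''''''  # empty docstring; the explicit description above is used instead
    return f"searched: {query}"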

View File

@@ -1,10 +1,10 @@
-from langchain_community.tools import YouTubeSearchTool
-from chatchat.server.pydantic_v1 import BaseModel, Field
+from chatchat.server.pydantic_v1 import Field
+from .tools_registry import regist_tool
 
 
-def search_youtube(query: str):
+@regist_tool
+def search_youtube(query: str = Field(description="Query for Videos search")):
+    '''use this tools_factory to search youtube videos'''
+    from langchain_community.tools import YouTubeSearchTool
     tool = YouTubeSearchTool()
     return tool.run(tool_input=query)
-
-
-class YoutubeInput(BaseModel):
-    query: str = Field(description="Query for Videos search")

View File

@@ -1,11 +1,12 @@
 # LangChain 的 Shell 工具
-from chatchat.server.pydantic_v1 import BaseModel, Field
-from langchain_community.tools import ShellTool
+from langchain.tools.shell import ShellTool
+from chatchat.server.pydantic_v1 import Field
+from .tools_registry import regist_tool
 
 
-def shell(query: str):
+@regist_tool
+def shell(query: str = Field(description="The command to execute")):
+    '''Use Shell to execute system shell commands'''
    tool = ShellTool()
    return tool.run(tool_input=query)
-
-
-class ShellInput(BaseModel):
-    query: str = Field(description="The command to execute")

View File

@@ -5,8 +5,9 @@ from PIL import Image
 from typing import List
 import uuid
-from langchain.agents import tool
-from chatchat.server.pydantic_v1 import Field, FieldInfo
+from chatchat.server.pydantic_v1 import Field
+from chatchat.server.utils import get_tool_config
+from .tools_registry import regist_tool
 import openai
 from chatchat.configs.basic_config import MEDIA_PATH
@@ -25,7 +26,7 @@ def get_image_model_config() -> dict:
     return config
 
 
-@tool(return_direct=True)
+@regist_tool(return_direct=True)
 def text2images(
     prompt: str,
     n: int = Field(1, description="需生成图片的数量"),
@@ -33,13 +34,6 @@ def text2images(
     height: int = Field(512, description="生成图片的高度"),
 ) -> List[str]:
     '''根据用户的描述生成图片'''
-    # workaround before langchain uprading
-    if isinstance(n, FieldInfo):
-        n = n.default
-    if isinstance(width, FieldInfo):
-        width = width.default
-    if isinstance(height, FieldInfo):
-        height = height.default
 
     model_config = get_image_model_config()
     assert model_config is not None, "请正确配置文生图模型"

View File

@@ -1,69 +1,103 @@
-from langchain_core.tools import StructuredTool
-from chatchat.server.agent.tools_factory import *
-from chatchat.configs import KB_INFO
-
-template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information. Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]."
-KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
-template_knowledge = template.format(KB_info=KB_info_str, key="samples")
-
-all_tools = [
-    calculate,
-    StructuredTool.from_function(
-        func=arxiv,
-        name="arxiv",
-        description="A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.",
-        args_schema=ArxivInput,
-    ),
-    StructuredTool.from_function(
-        func=shell,
-        name="shell",
-        description="Use Shell to execute Linux commands",
-        args_schema=ShellInput,
-    ),
-    StructuredTool.from_function(
-        func=wolfram,
-        name="wolfram",
-        description="Useful for when you need to calculate difficult formulas",
-        args_schema=WolframInput,
-    ),
-    StructuredTool.from_function(
-        func=search_youtube,
-        name="search_youtube",
-        description="use this tools_factory to search youtube videos",
-        args_schema=YoutubeInput,
-    ),
-    StructuredTool.from_function(
-        func=weather_check,
-        name="weather_check",
-        description="Use this tool to check the weather at a specific location",
-        args_schema=WeatherInput,
-    ),
-    StructuredTool.from_function(
-        func=search_internet,
-        name="search_internet",
-        description="Use this tool to use bing search engine to search the internet and get information",
-        args_schema=SearchInternetInput,
-    ),
-    StructuredTool.from_function(
-        func=search_local_knowledgebase,
-        name="search_local_knowledgebase",
-        description=template_knowledge,
-        args_schema=SearchKnowledgeInput,
-    ),
-    StructuredTool.from_function(
-        func=vqa_processor,
-        name="vqa_processor",
-        description="use this tool to get answer for image question",
-        args_schema=VQAInput,
-    ),
-    StructuredTool.from_function(
-        func=aqa_processor,
-        name="aqa_processor",
-        description="use this tool to get answer for audio question",
-        args_schema=AQAInput,
-    )
-]
-all_tools.append(text2images)
+import re
+from typing import Any, Union, Dict, Tuple, Callable, Optional, Type
+
+from langchain.agents import tool
+from langchain_core.tools import BaseTool
+
+from chatchat.server.pydantic_v1 import BaseModel
+
+_TOOLS_REGISTRY = {}
+
+
+################################### TODO: workaround to langchain #15855
+# patch BaseTool to support tool parameters defined using pydantic Field
+def _new_parse_input(
+        self,
+        tool_input: Union[str, Dict],
+) -> Union[str, Dict[str, Any]]:
+    """Convert tool input to pydantic model."""
+    input_args = self.args_schema
+    if isinstance(tool_input, str):
+        if input_args is not None:
+            key_ = next(iter(input_args.__fields__.keys()))
+            input_args.validate({key_: tool_input})
+        return tool_input
+    else:
+        if input_args is not None:
+            result = input_args.parse_obj(tool_input)
+            return result.dict()
+
+
+def _new_to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
+    # For backwards compatibility, if run_input is a string,
+    # pass as a positional argument.
+    if isinstance(tool_input, str):
+        return (tool_input,), {}
+    else:
+        # for tool defined with `*args` parameters
+        # the args_schema has a field named `args`
+        # it should be expanded to actual *args
+        # e.g.: test_tools
+        #       .test_named_tool_decorator_return_direct
+        #       .search_api
+        if "args" in tool_input:
+            args = tool_input["args"]
+            if args is None:
+                tool_input.pop("args")
+                return (), tool_input
+            elif isinstance(args, tuple):
+                tool_input.pop("args")
+                return args, tool_input
+        return (), tool_input
+
+
+BaseTool._parse_input = _new_parse_input
+BaseTool._to_args_and_kwargs = _new_to_args_and_kwargs
+###############################
+
+
+def regist_tool(
+        *args: Any,
+        description: str = "",
+        return_direct: bool = False,
+        args_schema: Optional[Type[BaseModel]] = None,
+        infer_schema: bool = True,
+) -> Union[Callable, BaseTool]:
+    '''
+    wrapper of langchain tool decorator
+    add tool to registry automatically
+    '''
+    def _parse_tool(t: BaseTool):
+        nonlocal description
+
+        _TOOLS_REGISTRY[t.name] = t
+        # change default description
+        if not description:
+            if t.func is not None:
+                description = t.func.__doc__
+            elif t.coroutine is not None:
+                description = t.coroutine.__doc__
+        t.description = " ".join(re.split(r"\n+\s*", description))
+
+    def wrapper(def_func: Callable) -> BaseTool:
+        partial_ = tool(*args,
+                        return_direct=return_direct,
+                        args_schema=args_schema,
+                        infer_schema=infer_schema,
+                        )
+        t = partial_(def_func)
+        _parse_tool(t)
+        return t
+
+    if len(args) == 0:
+        return wrapper
+    else:
+        t = tool(*args,
+                 return_direct=return_direct,
+                 args_schema=args_schema,
+                 infer_schema=infer_schema,
+                 )
+        _parse_tool(t)
+        return t
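The registry replaces the hand-built all_tools list: regist_tool wraps langchain's tool decorator, records each tool in _TOOLS_REGISTRY under its name, and falls back to the function docstring when no description is passed. A minimal sketch of declaring a tool this way (hypothetical example; the import path assumes the tools_factory package shown in this commit):

from chatchat.server.pydantic_v1 import Field
from chatchat.server.agent.tools_factory.tools_registry import regist_tool, _TOOLS_REGISTRY

@regist_tool
def echo(text: str = Field(description="text to echo back")):
    '''Echo the input text unchanged.'''
    return text

assert "echo" in _TOOLS_REGISTRY            # registered automatically
print(_TOOLS_REGISTRY["echo"].description)  # taken from the docstring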

View File

@@ -1 +0,0 @@
-from .vqa import vqa_processor,VQAInput

View File

@@ -4,8 +4,9 @@ Method Use cogagent to generate response for a given image and query.
 import base64
 from io import BytesIO
 from PIL import Image, ImageDraw
-from chatchat.server.pydantic_v1 import BaseModel, Field
-from chatchat.configs import TOOL_CONFIG
+from chatchat.server.pydantic_v1 import Field
+from chatchat.server.utils import get_tool_config
+from .tools_registry import regist_tool
 import re
 from chatchat.server.agent.container import container
@@ -98,8 +99,11 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m
     return response
 
 
-def vqa_processor(query: str):
-    tool_config = TOOL_CONFIG["vqa_processor"]
+@regist_tool
+def vqa_processor(query: str = Field(description="The question of the image in English")):
+    '''use this tool to get answer for image question'''
+    tool_config = get_tool_config("vqa_processor")
     if container.metadata["images"]:
         image_base64 = container.metadata["images"][0]
         ans = vqa_run(model=container.vision_model,
@@ -118,7 +122,3 @@ def vqa_processor(query: str):
         return ans
     else:
         return "No Image, Please Try Again"
-
-
-class VQAInput(BaseModel):
-    query: str = Field(description="The question of the image in English")

View File

@@ -1,11 +1,19 @@
 """
 简单的单参数输入工具实现用于查询现在天气的情况
 """
-from chatchat.server.pydantic_v1 import BaseModel, Field
+from chatchat.server.pydantic_v1 import Field
+from chatchat.server.utils import get_tool_config
+from .tools_registry import regist_tool
 import requests
 
 
-def weather(location: str, api_key: str):
-    url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
+@regist_tool
+def weather_check(city: str = Field(description="City name,include city and county,like '厦门'")):
+    '''Use this tool to check the weather at a specific city'''
+    tool_config = get_tool_config("weather_check")
+    api_key = tool_config.get("api_key")
+    url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={city}&language=zh-Hans&unit=c"
     response = requests.get(url)
     if response.status_code == 200:
         data = response.json()
@@ -17,9 +25,3 @@ def weather(location: str, api_key: str):
     else:
         raise Exception(
             f"Failed to retrieve weather: {response.status_code}")
-
-
-def weather_check(location: str):
-    return weather(location, "S8vrB4U_-c5mvAMiK")
-
-
-class WeatherInput(BaseModel):
-    location: str = Field(description="City name,include city and county,like '厦门'")

View File

@@ -1,13 +1,16 @@
 # Langchain 自带的 Wolfram Alpha API 封装
-from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
-from chatchat.server.pydantic_v1 import BaseModel, Field
-
-wolfram_alpha_appid = "your key"
+from chatchat.server.pydantic_v1 import Field
+from chatchat.server.utils import get_tool_config
+from .tools_registry import regist_tool
 
 
-def wolfram(query: str):
-    wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=wolfram_alpha_appid)
+@regist_tool
+def wolfram(query: str = Field(description="The formula to be calculated")):
+    '''Useful for when you need to calculate difficult formulas'''
+    from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
+    wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=get_tool_config("wolfram").get("appid"))
     ans = wolfram.run(query)
     return ans
-
-
-class WolframInput(BaseModel):
-    formula: str = Field(description="The formula to be calculated")

View File

@@ -5,7 +5,7 @@ from typing import List
 from fastapi import APIRouter, Request, Body
 
 from chatchat.configs import logger
-from chatchat.server.utils import BaseResponse
+from chatchat.server.utils import BaseResponse, get_tool
 
 tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
@@ -13,11 +13,8 @@ tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
 @tool_router.get("/", response_model=BaseResponse)
 async def list_tools():
-    import importlib
-    from chatchat.server.agent.tools_factory import tools_registry
-    importlib.reload(tools_registry)
-
-    data = {t.name: {"name": t.name, "description": t.description, "args": t.args} for t in tools_registry.all_tools}
+    tools = get_tool()
+    data = {t.name: {"name": t.name, "description": t.description, "args": t.args} for t in tools.values()}
     return {"data": data}
@@ -26,12 +23,9 @@ async def call_tool(
     name: str = Body(examples=["calculate"]),
     kwargs: dict = Body({}, examples=[{"a":1,"b":2,"operator":"+"}]),
 ):
-    import importlib
-    from chatchat.server.agent.tools_factory import tools_registry
-    importlib.reload(tools_registry)
-
-    tool_names = {t.name: t for t in tools_registry.all_tools}
-    if tool := tool_names.get(name):
+    tools = get_tool()
+    if tool := tools.get(name):
         try:
             result = await tool.ainvoke(kwargs)
             return {"data": result}

View File

@@ -11,10 +11,9 @@ from langchain.chains import LLMChain
 from langchain.prompts.chat import ChatPromptTemplate
 from langchain.prompts import PromptTemplate
 from chatchat.server.agent.agent_factory.agents_registry import agents_registry
-from chatchat.server.agent.tools_factory.tools_registry import all_tools
 from chatchat.server.agent.container import container
-from chatchat.server.utils import wrap_done, get_ChatOpenAI, get_prompt_template, MsgType
+from chatchat.server.utils import wrap_done, get_ChatOpenAI, get_prompt_template, MsgType, get_tool
 from chatchat.server.chat.utils import History
 from chatchat.server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
 from chatchat.server.db.repository import add_message_to_db
@@ -127,6 +126,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
             callbacks = [callback]
             models, prompts = create_models_from_config(callbacks=callbacks, configs=chat_model_config,
                                                         stream=stream)
+            all_tools = get_tool().values()
             tools = [tool for tool in all_tools if tool.name in tool_config]
             tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
             full_chain = create_models_chains(prompts=prompts,

View File

@@ -7,6 +7,7 @@ import multiprocessing as mp
 from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
 
 from langchain_core.embeddings import Embeddings
+from langchain.tools import BaseTool
 from langchain_openai.chat_models import ChatOpenAI
 from langchain_openai.llms import OpenAI
 import httpx
@@ -277,7 +278,7 @@ class BaseResponse(BaseModel):
     data: Any = Field(None, description="API data")
 
     class Config:
-        schema_extra = {
+        json_schema_extra = {
             "example": {
                 "code": 200,
                 "msg": "success",
@@ -289,7 +290,7 @@ class ListResponse(BaseResponse):
     data: List[str] = Field(..., description="List of names")
 
     class Config:
-        schema_extra = {
+        json_schema_extra = {
             "example": {
                 "code": 200,
                 "msg": "success",
@@ -307,7 +308,7 @@ class ChatMessage(BaseModel):
     )
 
     class Config:
-        schema_extra = {
+        json_schema_extra = {
             "example": {
                 "question": "工伤保险如何办理?",
                 "response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n"
@@ -696,3 +697,25 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]:
         path = os.path.join(BASE_TEMP_DIR, id)
         os.mkdir(path)
     return path, id
+
+
+def get_tool(name: str = None) -> Union[BaseTool, Dict[str, BaseTool]]:
+    import importlib
+    from chatchat.server.agent import tools_factory
+    importlib.reload(tools_factory)
+
+    if name is None:
+        return tools_factory.tools_registry._TOOLS_REGISTRY
+    else:
+        return tools_factory.tools_registry._TOOLS_REGISTRY.get(name)
+
+
+def get_tool_config(name: str = None) -> Dict:
+    import importlib
+    from chatchat.configs import model_config
+    importlib.reload(model_config)
+
+    if name is None:
+        return model_config.TOOL_CONFIG
+    else:
+        return model_config.TOOL_CONFIG.get(name, {})
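These two helpers are what make tools and tool configs hot-reloadable: each call re-imports its module before reading the registry or TOOL_CONFIG, so edits to tool code or configuration take effect without restarting the server. A usage sketch under that assumption:

from chatchat.server.utils import get_tool, get_tool_config

tools = get_tool()                       # dict of name -> BaseTool, freshly reloaded
weather = get_tool("weather_check")      # a single BaseTool, or None if unknown
conf = get_tool_config("weather_check")  # that tool's config dict, {} if missing
if conf.get("use"):
    print(weather.run("厦门"))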

View File

@@ -1,5 +1,6 @@
 import base64
+from chatchat.server.utils import get_tool_config
 
 import streamlit as st
 from streamlit_antd_components.utils import ParseItems
@@ -13,7 +14,7 @@ from datetime import datetime
 import os
 import re
 import time
-from chatchat.configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS, TOOL_CONFIG)
+from chatchat.configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS)
 from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
 from chatchat.server.utils import MsgType, get_config_models
 import uuid
@@ -157,12 +158,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             import importlib
             importlib.reload(model_config_py)
 
-        tools = list(TOOL_CONFIG.keys())
+        tools = get_tool_config()
         with st.expander("工具栏"):
             for tool in tools:
-                is_selected = st.checkbox(tool, value=TOOL_CONFIG[tool]["use"], key=tool)
+                is_selected = st.checkbox(tool, value=tools[tool]["use"], key=tool)
                 if is_selected:
-                    selected_tool_configs[tool] = TOOL_CONFIG[tool]
+                    selected_tool_configs[tool] = tools[tool]
 
         if llm_model is not None:
             chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})