From 9818bd2a88ba8414f3dc1812b6a0cca619e4c0d7 Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Thu, 28 Mar 2024 13:08:51 +0800 Subject: [PATCH] =?UTF-8?q?-=20=E9=87=8D=E5=86=99=20tool=20=E9=83=A8?= =?UTF-8?q?=E5=88=86=EF=BC=9A=20(#3553)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 简化 tool 的定义方式 - 所有 tool 和 tool_config 支持热加载 - 修复:json_schema_extra warning --- .../chatchat/configs/basic_config.py.example | 10 +- .../chatchat/configs/kb_config.py.example | 4 +- .../chatchat/configs/model_config.py.example | 88 ++++++++++ .../chatchat/configs/prompt_config.py.example | 87 +--------- .../chatchat/configs/server_config.py.example | 1 - .../chatchat/server/agent/container.py | 21 ++- .../server/agent/tools_factory/__init__.py | 18 +- .../aqa.py => aqa_processor.py} | 14 +- .../server/agent/tools_factory/arxiv.py | 13 +- .../tools_factory/audio_factory/__init__.py | 1 - .../server/agent/tools_factory/calculate.py | 7 +- .../agent/tools_factory/search_internet.py | 23 +-- .../search_local_knowledgebase.py | 26 +-- .../agent/tools_factory/search_youtube.py | 12 +- .../server/agent/tools_factory/shell.py | 13 +- .../server/agent/tools_factory/text2image.py | 14 +- .../agent/tools_factory/tools_registry.py | 162 +++++++++++------- .../tools_factory/vision_factory/__init__.py | 1 - .../vqa.py => vqa_processor.py} | 16 +- .../agent/tools_factory/weather_check.py | 20 ++- .../server/agent/tools_factory/wolfram.py | 19 +- .../chatchat/server/api_server/tool_routes.py | 16 +- chatchat-server/chatchat/server/chat/chat.py | 4 +- chatchat-server/chatchat/server/utils.py | 29 +++- .../chatchat/webui_pages/dialogue/dialogue.py | 9 +- 25 files changed, 348 insertions(+), 280 deletions(-) rename chatchat-server/chatchat/server/agent/tools_factory/{audio_factory/aqa.py => aqa_processor.py} (74%) delete mode 100644 chatchat-server/chatchat/server/agent/tools_factory/audio_factory/__init__.py 
delete mode 100644 chatchat-server/chatchat/server/agent/tools_factory/vision_factory/__init__.py rename chatchat-server/chatchat/server/agent/tools_factory/{vision_factory/vqa.py => vqa_processor.py} (92%) diff --git a/chatchat-server/chatchat/configs/basic_config.py.example b/chatchat-server/chatchat/configs/basic_config.py.example index d18f0d57..f8cd6bbd 100644 --- a/chatchat-server/chatchat/configs/basic_config.py.example +++ b/chatchat-server/chatchat/configs/basic_config.py.example @@ -12,12 +12,12 @@ langchain.verbose = False # 通常情况下不需要更改以下内容 # 用户数据根目录 -DATA_PATH = (Path(__file__).absolute().parent.parent) # / "data") +DATA_PATH = str(Path(__file__).absolute().parent.parent / "data") if not os.path.exists(DATA_PATH): os.mkdir(DATA_PATH) # nltk 模型存储路径 -NLTK_DATA_PATH = os.path.join(DATA_PATH, "data/nltk_data") +NLTK_DATA_PATH = os.path.join(DATA_PATH, "nltk_data") import nltk nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path @@ -29,12 +29,12 @@ logging.basicConfig(format=LOG_FORMAT) # 日志存储路径 -LOG_PATH = os.path.join(DATA_PATH, "data/logs") +LOG_PATH = os.path.join(DATA_PATH, "logs") if not os.path.exists(LOG_PATH): os.mkdir(LOG_PATH) # 模型生成内容(图片、视频、音频等)保存位置 -MEDIA_PATH = os.path.join(DATA_PATH, "data/media") +MEDIA_PATH = os.path.join(DATA_PATH, "media") if not os.path.exists(MEDIA_PATH): os.mkdir(MEDIA_PATH) os.mkdir(os.path.join(MEDIA_PATH, "image")) @@ -42,6 +42,6 @@ if not os.path.exists(MEDIA_PATH): os.mkdir(os.path.join(MEDIA_PATH, "video")) # 临时文件目录,主要用于文件对话 -BASE_TEMP_DIR = os.path.join(DATA_PATH, "data/temp") +BASE_TEMP_DIR = os.path.join(DATA_PATH, "temp") if not os.path.exists(BASE_TEMP_DIR): os.mkdir(BASE_TEMP_DIR) diff --git a/chatchat-server/chatchat/configs/kb_config.py.example b/chatchat-server/chatchat/configs/kb_config.py.example index 2d8d7f03..661407d3 100644 --- a/chatchat-server/chatchat/configs/kb_config.py.example +++ b/chatchat-server/chatchat/configs/kb_config.py.example @@ -1,6 +1,6 @@ import os -from chatchat.configs.basic_config 
import DATA_PATH +from .basic_config import DATA_PATH # 默认使用的知识库 @@ -51,7 +51,7 @@ KB_INFO = { # 通常情况下不需要更改以下内容 # 知识库默认存储路径 -KB_ROOT_PATH = os.path.join(DATA_PATH, "data/knowledge_base") +KB_ROOT_PATH = os.path.join(DATA_PATH, "knowledge_base") if not os.path.exists(KB_ROOT_PATH): os.mkdir(KB_ROOT_PATH) diff --git a/chatchat-server/chatchat/configs/model_config.py.example b/chatchat-server/chatchat/configs/model_config.py.example index c176e8c3..a38f1419 100644 --- a/chatchat-server/chatchat/configs/model_config.py.example +++ b/chatchat-server/chatchat/configs/model_config.py.example @@ -151,6 +151,7 @@ MODEL_PLATFORMS = [ "image_models": [], "multimodal_models": [], }, + { "platform_name": "ollama", "platform_type": "ollama", @@ -168,6 +169,7 @@ MODEL_PLATFORMS = [ "image_models": [], "multimodal_models": [], }, + # { # "platform_name": "loom", # "platform_type": "loom", @@ -184,3 +186,89 @@ MODEL_PLATFORMS = [ ] LOOM_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "loom.yaml") + +# 工具配置项 +TOOL_CONFIG = { + "search_local_knowledgebase": { + "use": False, + "top_k": 3, + "score_threshold": 1, + "conclude_prompt": { + "with_result": + '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题",' + '不允许在答案中添加编造成分,答案请使用中文。 \n' + '<已知信息>{{ context }}\n' + '<问题>{{ question }}\n', + "without_result": + '请你根据我的提问回答我的问题:\n' + '{{ question }}\n' + '请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n', + } + }, + "search_internet": { + "use": False, + "search_engine_name": "bing", + "search_engine_config": + { + "bing": { + "result_len": 3, + "bing_search_url": "https://api.bing.microsoft.com/v7.0/search", + "bing_key": "", + }, + "metaphor": { + "result_len": 3, + "metaphor_api_key": "", + "split_result": False, + "chunk_size": 500, + "chunk_overlap": 0, + }, + "duckduckgo": { + "result_len": 3 + } + }, + "top_k": 10, + "verbose": "Origin", + "conclude_prompt": + "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 " + "\n<已知信息>{{ context }}\n" + 
"<问题>\n" + "{{ question }}\n" + "\n" + }, + "arxiv": { + "use": False, + }, + "shell": { + "use": False, + }, + "weather_check": { + "use": False, + "api_key": "S8vrB4U_-c5mvAMiK", + }, + "search_youtube": { + "use": False, + }, + "wolfram": { + "use": False, + "appid": "", + }, + "calculate": { + "use": False, + }, + "vqa_processor": { + "use": False, + "model_path": "your model path", + "tokenizer_path": "your tokenizer path", + "device": "cuda:1" + }, + "aqa_processor": { + "use": False, + "model_path": "your model path", + "tokenizer_path": "your tokenizer path", + "device": "cuda:2" + }, + "text2images": { + "use": False, + }, + +} diff --git a/chatchat-server/chatchat/configs/prompt_config.py.example b/chatchat-server/chatchat/configs/prompt_config.py.example index 2b2967c4..c9446583 100644 --- a/chatchat-server/chatchat/configs/prompt_config.py.example +++ b/chatchat-server/chatchat/configs/prompt_config.py.example @@ -83,7 +83,7 @@ PROMPT_TEMPLATES = { 'Question: {input}\n\n' '{agent_scratchpad}\n', "qwen": - 'Answer the following question as best you can. You have access to the following APIs:\n\n' + 'Answer the following questions as best you can. 
You have access to the following APIs:\n\n' '{tools}\n\n' 'Use the following format:\n\n' 'Question: the input question you must answer\n' @@ -122,88 +122,3 @@ PROMPT_TEMPLATES = { "default": "{{input}}", } } - -TOOL_CONFIG = { - "search_local_knowledgebase": { - "use": False, - "top_k": 3, - "score_threshold": 1, - "conclude_prompt": { - "with_result": - '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题",' - '不允许在答案中添加编造成分,答案请使用中文。 \n' - '<已知信息>{{ context }}\n' - '<问题>{{ question }}\n', - "without_result": - '请你根据我的提问回答我的问题:\n' - '{{ question }}\n' - '请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n', - } - }, - "search_internet": { - "use": False, - "search_engine_name": "bing", - "search_engine_config": - { - "bing": { - "result_len": 3, - "bing_search_url": "https://api.bing.microsoft.com/v7.0/search", - "bing_key": "", - }, - "metaphor": { - "result_len": 3, - "metaphor_api_key": "", - "split_result": False, - "chunk_size": 500, - "chunk_overlap": 0, - }, - "duckduckgo": { - "result_len": 3 - } - }, - "top_k": 10, - "verbose": "Origin", - "conclude_prompt": - "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 " - "\n<已知信息>{{ context }}\n" - "<问题>\n" - "{{ question }}\n" - "\n" - }, - "arxiv": { - "use": False, - }, - "shell": { - "use": False, - }, - "weather_check": { - "use": False, - "api-key": "S8vrB4U_-c5mvAMiK", - }, - "search_youtube": { - "use": False, - }, - "wolfram": { - "use": False, - }, - "calculate": { - "use": False, - }, - "vqa_processor": { - "use": False, - "model_path": "your model path", - "tokenizer_path": "your tokenizer path", - "device": "cuda:1" - }, - "aqa_processor": { - "use": False, - "model_path": "your model path", - "tokenizer_path": "yout tokenizer path", - "device": "cuda:2" - }, - - "text2images": { - "use": False, - }, - -} diff --git a/chatchat-server/chatchat/configs/server_config.py.example b/chatchat-server/chatchat/configs/server_config.py.example index 90182036..40485250 100644 --- 
a/chatchat-server/chatchat/configs/server_config.py.example +++ b/chatchat-server/chatchat/configs/server_config.py.example @@ -23,4 +23,3 @@ API_SERVER = { "host": DEFAULT_BIND_HOST, "port": 7861, } - diff --git a/chatchat-server/chatchat/server/agent/container.py b/chatchat-server/chatchat/server/agent/container.py index 369e5b2e..80217e93 100644 --- a/chatchat-server/chatchat/server/agent/container.py +++ b/chatchat-server/chatchat/server/agent/container.py @@ -1,4 +1,5 @@ -from chatchat.configs import TOOL_CONFIG, logger +from chatchat.configs import logger +from chatchat.server.utils import get_tool_config class ModelContainer: @@ -11,36 +12,38 @@ class ModelContainer: self.audio_tokenizer = None self.audio_model = None - if TOOL_CONFIG["vqa_processor"]["use"]: + vqa_config = get_tool_config("vqa_processor") + if vqa_config["use"]: try: from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer import torch self.vision_tokenizer = LlamaTokenizer.from_pretrained( - TOOL_CONFIG["vqa_processor"]["tokenizer_path"], + vqa_config["tokenizer_path"], trust_remote_code=True) self.vision_model = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path=TOOL_CONFIG["vqa_processor"]["model_path"], + pretrained_model_name_or_path=vqa_config["model_path"], torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True - ).to(TOOL_CONFIG["vqa_processor"]["device"]).eval() + ).to(vqa_config["device"]).eval() except Exception as e: logger.error(e, exc_info=True) - if TOOL_CONFIG["aqa_processor"]["use"]: + aqa_config = get_tool_config("aqa_processor") + if aqa_config["use"]: try: from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer import torch self.audio_tokenizer = AutoTokenizer.from_pretrained( - TOOL_CONFIG["aqa_processor"]["tokenizer_path"], + aqa_config["tokenizer_path"], trust_remote_code=True ) self.audio_model = AutoModelForCausalLM.from_pretrained( -
pretrained_model_name_or_path=TOOL_CONFIG["aqa_processor"]["model_path"], + pretrained_model_name_or_path=aqa_config["model_path"], torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True).to( - TOOL_CONFIG["aqa_processor"]["device"] + aqa_config["device"] ).eval() except Exception as e: logger.error(e, exc_info=True) diff --git a/chatchat-server/chatchat/server/agent/tools_factory/__init__.py b/chatchat-server/chatchat/server/agent/tools_factory/__init__.py index 6821511c..8faaedb7 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/__init__.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/__init__.py @@ -1,12 +1,12 @@ -from .search_local_knowledgebase import search_local_knowledgebase, SearchKnowledgeInput +from .search_local_knowledgebase import search_local_knowledgebase from .calculate import calculate -from .weather_check import weather_check, WeatherInput -from .shell import shell, ShellInput -from .search_internet import search_internet, SearchInternetInput -from .wolfram import wolfram, WolframInput -from .search_youtube import search_youtube, YoutubeInput -from .arxiv import arxiv, ArxivInput +from .weather_check import weather_check +from .shell import shell +from .search_internet import search_internet +from .wolfram import wolfram +from .search_youtube import search_youtube +from .arxiv import arxiv from .text2image import text2images -from .vision_factory import * -from .audio_factory import * +from .vqa_processor import vqa_processor +from .aqa_processor import aqa_processor diff --git a/chatchat-server/chatchat/server/agent/tools_factory/audio_factory/aqa.py b/chatchat-server/chatchat/server/agent/tools_factory/aqa_processor.py similarity index 74% rename from chatchat-server/chatchat/server/agent/tools_factory/audio_factory/aqa.py rename to chatchat-server/chatchat/server/agent/tools_factory/aqa_processor.py index dee295fc..e5fec534 100644 --- 
a/chatchat-server/chatchat/server/agent/tools_factory/audio_factory/aqa.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/aqa_processor.py @@ -1,6 +1,8 @@ import base64 -import os -from chatchat.server.pydantic_v1 import BaseModel, Field +from chatchat.server.pydantic_v1 import Field +from chatchat.server.utils import get_tool_config +from .tools_registry import regist_tool + def save_base64_audio(base64_audio, file_path): audio_data = base64.b64decode(base64_audio) @@ -14,7 +16,10 @@ def aqa_run(model, tokenizer, query): return response -def aqa_processor(query: str): +@regist_tool +def aqa_processor(query: str = Field(description="The question of the audio in English")): + '''use this tool to get answer for audio question''' + from chatchat.server.agent.container import container if container.metadata["audios"]: file_path = "temp_audio.mp3" @@ -26,6 +31,3 @@ def aqa_processor(query: str): return aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model) else: return "No Audio, Please Try Again" - -class AQAInput(BaseModel): - query: str = Field(description="The question of the image in English") diff --git a/chatchat-server/chatchat/server/agent/tools_factory/arxiv.py b/chatchat-server/chatchat/server/agent/tools_factory/arxiv.py index 4d26f4b8..8c0bc4f4 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/arxiv.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/arxiv.py @@ -1,11 +1,12 @@ # LangChain 的 ArxivQueryRun 工具 -from chatchat.server.pydantic_v1 import BaseModel, Field -from langchain.tools.arxiv.tool import ArxivQueryRun +from chatchat.server.pydantic_v1 import Field +from .tools_registry import regist_tool -def arxiv(query: str): +@regist_tool +def arxiv(query: str = Field(description="The search query title")): + '''A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.''' + from langchain.tools.arxiv.tool import ArxivQueryRun + tool = 
ArxivQueryRun() return tool.run(tool_input=query) - -class ArxivInput(BaseModel): - query: str = Field(description="The search query title") diff --git a/chatchat-server/chatchat/server/agent/tools_factory/audio_factory/__init__.py b/chatchat-server/chatchat/server/agent/tools_factory/audio_factory/__init__.py deleted file mode 100644 index 9f2332e2..00000000 --- a/chatchat-server/chatchat/server/agent/tools_factory/audio_factory/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .aqa import aqa_processor, AQAInput diff --git a/chatchat-server/chatchat/server/agent/tools_factory/calculate.py b/chatchat-server/chatchat/server/agent/tools_factory/calculate.py index 880aa5c8..caec1536 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/calculate.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/calculate.py @@ -1,8 +1,9 @@ -from langchain.agents import tool +from chatchat.server.pydantic_v1 import Field +from .tools_registry import regist_tool -@tool -def calculate(text: str) -> float: +@regist_tool +def calculate(text: str = Field(description="a math expression")) -> float: ''' Useful to answer questions about simple calculations. translate user question to a math expression that can be evaluated by numexpr. 
diff --git a/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py b/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py index 951a3c9b..58624960 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py @@ -1,12 +1,15 @@ -from chatchat.server.pydantic_v1 import BaseModel, Field +from typing import List, Dict + from langchain.utilities.bing_search import BingSearchAPIWrapper from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper -from chatchat.configs import TOOL_CONFIG from langchain.text_splitter import RecursiveCharacterTextSplitter -from typing import List, Dict from langchain.docstore.document import Document -from strsimpy.normalized_levenshtein import NormalizedLevenshtein from markdownify import markdownify +from strsimpy.normalized_levenshtein import NormalizedLevenshtein + +from chatchat.server.utils import get_tool_config +from chatchat.server.pydantic_v1 import Field +from .tools_registry import regist_tool def bing_search(text, config): @@ -90,10 +93,10 @@ def search_engine(query: str, for doc in docs: context += doc + "\n" return context -def search_internet(query: str): - tool_config = TOOL_CONFIG["search_internet"] + + +@regist_tool +def search_internet(query: str = Field(description="query for Internet search")): + '''Use this tool to use bing search engine to search the internet and get information.''' + tool_config = get_tool_config("search_internet") return search_engine(query=query, config=tool_config) - -class SearchInternetInput(BaseModel): - query: str = Field(description="query for Internet search") - diff --git a/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py b/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py index 4e619791..22fdb35e 100644 --- 
a/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py @@ -1,8 +1,14 @@ from urllib.parse import urlencode -from chatchat.server.pydantic_v1 import BaseModel, Field - +from chatchat.server.utils import get_tool_config +from chatchat.server.pydantic_v1 import Field +from .tools_registry import regist_tool from chatchat.server.knowledge_base.kb_doc_api import search_docs -from chatchat.configs import TOOL_CONFIG +from chatchat.configs import KB_INFO + + +template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]." +KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()]) +template_knowledge = template.format(KB_info=KB_info_str, key="samples") def search_knowledgebase(query: str, database: str, config: dict): @@ -29,11 +35,11 @@ def search_knowledgebase(query: str, database: str, config: dict): return context -class SearchKnowledgeInput(BaseModel): - database: str = Field(description="Database for Knowledge Search") - query: str = Field(description="Query for Knowledge Search") - - -def search_local_knowledgebase(database: str, query: str): - tool_config = TOOL_CONFIG["search_local_knowledgebase"] +@regist_tool(description=template_knowledge) +def search_local_knowledgebase( + database: str = Field(description="Database for Knowledge Search"), + query: str = Field(description="Query for Knowledge Search"), +): + '''''' + tool_config = get_tool_config("search_local_knowledgebase") return search_knowledgebase(query=query, database=database, config=tool_config) diff --git a/chatchat-server/chatchat/server/agent/tools_factory/search_youtube.py b/chatchat-server/chatchat/server/agent/tools_factory/search_youtube.py index 50745f38..a706bef3 100644 --- 
a/chatchat-server/chatchat/server/agent/tools_factory/search_youtube.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/search_youtube.py @@ -1,10 +1,10 @@ -from langchain_community.tools import YouTubeSearchTool -from chatchat.server.pydantic_v1 import BaseModel, Field +from chatchat.server.pydantic_v1 import Field +from .tools_registry import regist_tool -def search_youtube(query: str): +@regist_tool +def search_youtube(query: str = Field(description="Query for Videos search")): + '''use this tools_factory to search youtube videos''' + from langchain_community.tools import YouTubeSearchTool tool = YouTubeSearchTool() return tool.run(tool_input=query) - -class YoutubeInput(BaseModel): - query: str = Field(description="Query for Videos search") diff --git a/chatchat-server/chatchat/server/agent/tools_factory/shell.py b/chatchat-server/chatchat/server/agent/tools_factory/shell.py index 6a5b8ced..efae7cd0 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/shell.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/shell.py @@ -1,11 +1,12 @@ # LangChain 的 Shell 工具 -from chatchat.server.pydantic_v1 import BaseModel, Field -from langchain_community.tools import ShellTool +from langchain.tools.shell import ShellTool + +from chatchat.server.pydantic_v1 import Field +from .tools_registry import regist_tool -def shell(query: str): +@regist_tool +def shell(query: str = Field(description="The command to execute")): + '''Use Shell to execute system shell commands''' tool = ShellTool() return tool.run(tool_input=query) - -class ShellInput(BaseModel): - query: str = Field(description="The command to execute") diff --git a/chatchat-server/chatchat/server/agent/tools_factory/text2image.py b/chatchat-server/chatchat/server/agent/tools_factory/text2image.py index 64802393..29afff24 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/text2image.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/text2image.py @@ -5,8 +5,9 @@ from 
PIL import Image from typing import List import uuid -from langchain.agents import tool -from chatchat.server.pydantic_v1 import Field, FieldInfo +from chatchat.server.pydantic_v1 import Field +from chatchat.server.utils import get_tool_config +from .tools_registry import regist_tool import openai from chatchat.configs.basic_config import MEDIA_PATH @@ -25,7 +26,7 @@ def get_image_model_config() -> dict: return config -@tool(return_direct=True) +@regist_tool(return_direct=True) def text2images( prompt: str, n: int = Field(1, description="需生成图片的数量"), @@ -33,13 +34,6 @@ def text2images( height: int = Field(512, description="生成图片的高度"), ) -> List[str]: '''根据用户的描述生成图片''' - # workaround before langchain uprading - if isinstance(n, FieldInfo): - n = n.default - if isinstance(width, FieldInfo): - width = width.default - if isinstance(height, FieldInfo): - height = height.default model_config = get_image_model_config() assert model_config is not None, "请正确配置文生图模型" diff --git a/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py b/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py index 3c552b6d..969e091d 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py @@ -1,69 +1,103 @@ -from langchain_core.tools import StructuredTool -from chatchat.server.agent.tools_factory import * -from chatchat.configs import KB_INFO +import re +from typing import Any, Union, Dict, Tuple, Callable, Optional, Type -template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]." 
-KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()]) -template_knowledge = template.format(KB_info=KB_info_str, key="samples") +from langchain.agents import tool +from langchain_core.tools import BaseTool -all_tools = [ - calculate, - StructuredTool.from_function( - func=arxiv, - name="arxiv", - description="A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.", - args_schema=ArxivInput, - ), - StructuredTool.from_function( - func=shell, - name="shell", - description="Use Shell to execute Linux commands", - args_schema=ShellInput, - ), - StructuredTool.from_function( - func=wolfram, - name="wolfram", - description="Useful for when you need to calculate difficult formulas", - args_schema=WolframInput, +from chatchat.server.pydantic_v1 import BaseModel - ), - StructuredTool.from_function( - func=search_youtube, - name="search_youtube", - description="use this tools_factory to search youtube videos", - args_schema=YoutubeInput, - ), - StructuredTool.from_function( - func=weather_check, - name="weather_check", - description="Use this tool to check the weather at a specific location", - args_schema=WeatherInput, - ), - StructuredTool.from_function( - func=search_internet, - name="search_internet", - description="Use this tool to use bing search engine to search the internet and get information", - args_schema=SearchInternetInput, - ), - StructuredTool.from_function( - func=search_local_knowledgebase, - name="search_local_knowledgebase", - description=template_knowledge, - args_schema=SearchKnowledgeInput, - ), - StructuredTool.from_function( - func=vqa_processor, - name="vqa_processor", - description="use this tool to get answer for image question", - args_schema=VQAInput, - ), - StructuredTool.from_function( - func=aqa_processor, - name="aqa_processor", - description="use this tool to get answer for audio question", - args_schema=AQAInput, - ) -] +_TOOLS_REGISTRY = {} -all_tools.append(text2images) + 
+################################### TODO: workaround to langchain #15855 +# patch BaseTool to support tool parameters defined using pydantic Field + +def _new_parse_input( + self, + tool_input: Union[str, Dict], +) -> Union[str, Dict[str, Any]]: + """Convert tool input to pydantic model.""" + input_args = self.args_schema + if isinstance(tool_input, str): + if input_args is not None: + key_ = next(iter(input_args.__fields__.keys())) + input_args.validate({key_: tool_input}) + return tool_input + else: + if input_args is not None: + result = input_args.parse_obj(tool_input) + return result.dict() + + +def _new_to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + # For backwards compatibility, if run_input is a string, + # pass as a positional argument. + if isinstance(tool_input, str): + return (tool_input,), {} + else: + # for tool defined with `*args` parameters + # the args_schema has a field named `args` + # it should be expanded to actual *args + # e.g.: test_tools + # .test_named_tool_decorator_return_direct + # .search_api + if "args" in tool_input: + args = tool_input["args"] + if args is None: + tool_input.pop("args") + return (), tool_input + elif isinstance(args, tuple): + tool_input.pop("args") + return args, tool_input + return (), tool_input + + +BaseTool._parse_input = _new_parse_input +BaseTool._to_args_and_kwargs = _new_to_args_and_kwargs +############################### + + +def regist_tool( + *args: Any, + description: str = "", + return_direct: bool = False, + args_schema: Optional[Type[BaseModel]] = None, + infer_schema: bool = True, +) -> Union[Callable, BaseTool]: + ''' + wrapper of langchain tool decorator + add tool to registry automatically + ''' + def _parse_tool(t: BaseTool): + nonlocal description + + _TOOLS_REGISTRY[t.name] = t + # change default description + if not description: + if t.func is not None: + description = t.func.__doc__ + elif t.coroutine is not None: + description = t.coroutine.__doc__ + 
t.description = " ".join(re.split(r"\n+\s*", description)) + + def wrapper(def_func: Callable) -> BaseTool: + partial_ = tool(*args, + return_direct=return_direct, + args_schema=args_schema, + infer_schema=infer_schema, + ) + t = partial_(def_func) + _parse_tool(t) + return t + + if len(args) == 0: + return wrapper + else: + t = tool(*args, + return_direct=return_direct, + args_schema=args_schema, + infer_schema=infer_schema, + ) + _parse_tool(t) + return t diff --git a/chatchat-server/chatchat/server/agent/tools_factory/vision_factory/__init__.py b/chatchat-server/chatchat/server/agent/tools_factory/vision_factory/__init__.py deleted file mode 100644 index ba5f1b34..00000000 --- a/chatchat-server/chatchat/server/agent/tools_factory/vision_factory/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .vqa import vqa_processor,VQAInput \ No newline at end of file diff --git a/chatchat-server/chatchat/server/agent/tools_factory/vision_factory/vqa.py b/chatchat-server/chatchat/server/agent/tools_factory/vqa_processor.py similarity index 92% rename from chatchat-server/chatchat/server/agent/tools_factory/vision_factory/vqa.py rename to chatchat-server/chatchat/server/agent/tools_factory/vqa_processor.py index fb184f8b..13b965ea 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/vision_factory/vqa.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/vqa_processor.py @@ -4,8 +4,9 @@ Method Use cogagent to generate response for a given image and query. 
import base64 from io import BytesIO from PIL import Image, ImageDraw -from chatchat.server.pydantic_v1 import BaseModel, Field -from chatchat.configs import TOOL_CONFIG +from chatchat.server.pydantic_v1 import Field +from chatchat.server.utils import get_tool_config +from .tools_registry import regist_tool import re from chatchat.server.agent.container import container @@ -98,8 +99,11 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m return response -def vqa_processor(query: str): - tool_config = TOOL_CONFIG["vqa_processor"] +@regist_tool +def vqa_processor(query: str = Field(description="The question of the image in English")): + '''use this tool to get answer for image question''' + + tool_config = get_tool_config("vqa_processor") if container.metadata["images"]: image_base64 = container.metadata["images"][0] ans = vqa_run(model=container.vision_model, @@ -118,7 +122,3 @@ def vqa_processor(query: str): return ans else: return "No Image, Please Try Again" - - -class VQAInput(BaseModel): - query: str = Field(description="The question of the image in English") diff --git a/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py b/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py index a0785a5e..967264aa 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py @@ -1,11 +1,19 @@ """ 简单的单参数输入工具实现,用于查询现在天气的情况 """ -from chatchat.server.pydantic_v1 import BaseModel, Field +from chatchat.server.pydantic_v1 import Field +from chatchat.server.utils import get_tool_config +from .tools_registry import regist_tool import requests -def weather(location: str, api_key: str): - url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c" + +@regist_tool +def weather_check(city: str = Field(description="City name,include city and county,like '厦门'")): + '''Use this 
tool to check the weather at a specific city''' + + tool_config = get_tool_config("weather_check") + api_key = tool_config.get("api_key") + url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={city}&language=zh-Hans&unit=c" response = requests.get(url) if response.status_code == 200: data = response.json() @@ -17,9 +25,3 @@ def weather(location: str, api_key: str): else: raise Exception( f"Failed to retrieve weather: {response.status_code}") - - -def weather_check(location: str): - return weather(location, "S8vrB4U_-c5mvAMiK") -class WeatherInput(BaseModel): - location: str = Field(description="City name,include city and county,like '厦门'") diff --git a/chatchat-server/chatchat/server/agent/tools_factory/wolfram.py b/chatchat-server/chatchat/server/agent/tools_factory/wolfram.py index acd35f3e..8015ef76 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/wolfram.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/wolfram.py @@ -1,13 +1,16 @@ # Langchain 自带的 Wolfram Alpha API 封装 -from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper -from chatchat.server.pydantic_v1 import BaseModel, Field -wolfram_alpha_appid = "your key" + +from chatchat.server.pydantic_v1 import Field +from chatchat.server.utils import get_tool_config +from .tools_registry import regist_tool -def wolfram(query: str): - wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=wolfram_alpha_appid) +@regist_tool +def wolfram(query: str = Field(description="The formula to be calculated")): + '''Useful for when you need to calculate difficult formulas''' + + from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper + + wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=get_tool_config("wolfram").get("appid")) ans = wolfram.run(query) return ans - -class WolframInput(BaseModel): - formula: str = Field(description="The formula to be calculated") diff --git a/chatchat-server/chatchat/server/api_server/tool_routes.py 
b/chatchat-server/chatchat/server/api_server/tool_routes.py index c11aee2d..3fa68d86 100644 --- a/chatchat-server/chatchat/server/api_server/tool_routes.py +++ b/chatchat-server/chatchat/server/api_server/tool_routes.py @@ -5,7 +5,7 @@ from typing import List from fastapi import APIRouter, Request, Body from chatchat.configs import logger -from chatchat.server.utils import BaseResponse +from chatchat.server.utils import BaseResponse, get_tool tool_router = APIRouter(prefix="/tools", tags=["Toolkits"]) @@ -13,11 +13,8 @@ tool_router = APIRouter(prefix="/tools", tags=["Toolkits"]) @tool_router.get("/", response_model=BaseResponse) async def list_tools(): - import importlib - from chatchat.server.agent.tools_factory import tools_registry - importlib.reload(tools_registry) - - data = {t.name: {"name": t.name, "description": t.description, "args": t.args} for t in tools_registry.all_tools} + tools = get_tool() + data = {t.name: {"name": t.name, "description": t.description, "args": t.args} for t in tools} return {"data": data} @@ -26,12 +23,9 @@ async def call_tool( name: str = Body(examples=["calculate"]), kwargs: dict = Body({}, examples=[{"a":1,"b":2,"operator":"+"}]), ): - import importlib - from chatchat.server.agent.tools_factory import tools_registry - importlib.reload(tools_registry) + tools = get_tool() - tool_names = {t.name: t for t in tools_registry.all_tools} - if tool := tool_names.get(name): + if tool := tools.get(name): try: result = await tool.ainvoke(kwargs) return {"data": result} diff --git a/chatchat-server/chatchat/server/chat/chat.py b/chatchat-server/chatchat/server/chat/chat.py index e506d621..e7a4dbad 100644 --- a/chatchat-server/chatchat/server/chat/chat.py +++ b/chatchat-server/chatchat/server/chat/chat.py @@ -11,10 +11,9 @@ from langchain.chains import LLMChain from langchain.prompts.chat import ChatPromptTemplate from langchain.prompts import PromptTemplate from chatchat.server.agent.agent_factory.agents_registry import agents_registry 
-from chatchat.server.agent.tools_factory.tools_registry import all_tools from chatchat.server.agent.container import container -from chatchat.server.utils import wrap_done, get_ChatOpenAI, get_prompt_template, MsgType +from chatchat.server.utils import wrap_done, get_ChatOpenAI, get_prompt_template, MsgType, get_tool from chatchat.server.chat.utils import History from chatchat.server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory from chatchat.server.db.repository import add_message_to_db @@ -127,6 +126,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 callbacks = [callback] models, prompts = create_models_from_config(callbacks=callbacks, configs=chat_model_config, stream=stream) + all_tools = get_tool().values() tools = [tool for tool in all_tools if tool.name in tool_config] tools = [t.copy(update={"callbacks": callbacks}) for t in tools] full_chain = create_models_chains(prompts=prompts, diff --git a/chatchat-server/chatchat/server/utils.py b/chatchat-server/chatchat/server/utils.py index 24e83281..9db43145 100644 --- a/chatchat-server/chatchat/server/utils.py +++ b/chatchat-server/chatchat/server/utils.py @@ -7,6 +7,7 @@ import multiprocessing as mp from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed from langchain_core.embeddings import Embeddings +from langchain.tools import BaseTool from langchain_openai.chat_models import ChatOpenAI from langchain_openai.llms import OpenAI import httpx @@ -277,7 +278,7 @@ class BaseResponse(BaseModel): data: Any = Field(None, description="API data") class Config: - schema_extra = { + json_schema_extra = { "example": { "code": 200, "msg": "success", @@ -289,7 +290,7 @@ class ListResponse(BaseResponse): data: List[str] = Field(..., description="List of names") class Config: - schema_extra = { + json_schema_extra = { "example": { "code": 200, "msg": "success", @@ -307,7 +308,7 @@ class ChatMessage(BaseModel): ) class Config: - schema_extra = { 
+ json_schema_extra = { "example": { "question": "工伤保险如何办理?", "response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n" @@ -696,3 +697,25 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]: path = os.path.join(BASE_TEMP_DIR, id) os.mkdir(path) return path, id + + +def get_tool(name: str = None) -> Union[BaseTool, Dict[str, BaseTool]]: + import importlib + from chatchat.server.agent import tools_factory + importlib.reload(tools_factory) + + if name is None: + return tools_factory.tools_registry._TOOLS_REGISTRY + else: + return tools_factory.tools_registry._TOOLS_REGISTRY.get(name) + + +def get_tool_config(name: str = None) -> Dict: + import importlib + from chatchat.configs import model_config + importlib.reload(model_config) + + if name is None: + return model_config.TOOL_CONFIG + else: + return model_config.TOOL_CONFIG.get(name, {}) diff --git a/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py b/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py index 6bd9ecd0..52a5bf46 100644 --- a/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py +++ b/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py @@ -1,5 +1,6 @@ import base64 +from chatchat.server.utils import get_tool_config import streamlit as st from streamlit_antd_components.utils import ParseItems @@ -13,7 +14,7 @@ from datetime import datetime import os import re import time -from chatchat.configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS, TOOL_CONFIG) +from chatchat.configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS) from chatchat.server.callback_handler.agent_callback_handler import AgentStatus from chatchat.server.utils import MsgType, get_config_models import uuid @@ -157,12 +158,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): import importlib importlib.reload(model_config_py) - tools = list(TOOL_CONFIG.keys()) + tools = get_tool_config() with st.expander("工具栏"): for tool in tools: - is_selected = 
st.checkbox(tool, value=TOOL_CONFIG[tool]["use"], key=tool) + is_selected = st.checkbox(tool, value=tools[tool]["use"], key=tool) if is_selected: - selected_tool_configs[tool] = TOOL_CONFIG[tool] + selected_tool_configs[tool] = tools[tool] if llm_model is not None: chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})