From 598eb298dfd8e8bf0eb9f09a08d5671b08418d2c Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Sun, 17 Sep 2023 11:19:16 +0800
Subject: [PATCH] First preliminary agent implementation (#1503)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* First preliminary agent implementation

* Add streaming parameter

* Update weather.py

---------

Co-authored-by: zR
---
 docs/自定义Agent.md                 |  73 ++++++
 server/agent/__init__.py           |   0
 server/agent/custom_template.py    | 104 +++++++++
 server/agent/math.py               |  70 ++++++
 server/agent/tools.py              |  28 +++
 server/agent/translator.py         |  41 ++++
 server/agent/weather.py            | 355 +++++++++++++++++++++++++++++
 server/api.py                      |  32 +--
 server/chat/__init__.py            |   1 +
 server/chat/agent_chat.py          |  73 ++++++
 server/chat/utils.py               |   3 +-
 tests/agent/test_agent_function.py |  40 ++++
 webui_pages/dialogue/dialogue.py   |  14 ++
 webui_pages/utils.py               |  33 +++
 14 files changed, 852 insertions(+), 15 deletions(-)
 create mode 100644 docs/自定义Agent.md
 create mode 100644 server/agent/__init__.py
 create mode 100644 server/agent/custom_template.py
 create mode 100644 server/agent/math.py
 create mode 100644 server/agent/tools.py
 create mode 100644 server/agent/translator.py
 create mode 100644 server/agent/weather.py
 create mode 100644 server/chat/agent_chat.py
 create mode 100644 tests/agent/test_agent_function.py

diff --git a/docs/自定义Agent.md b/docs/自定义Agent.md
new file mode 100644
index 00000000..68bf54da
--- /dev/null
+++ b/docs/自定义Agent.md
@@ -0,0 +1,73 @@
+## Building Your Own Custom Agent
+### 1. Create a Python file for your agent
+Create your own file under the ```server/agent``` directory, then register it in ```tools.py```.
+
+For example, if you create a ```custom_agent.py``` file containing a ```work``` function, add the following to ```tools.py```:
+```python
+from server.agent.custom_agent import work
+
+Tool.from_function(
+    func=work,
+    name="name for this tool",
+    description=""
+)
+```
+
+### 2. Modify custom_template.py
+Define an Agent Prompt and a custom response format suited to the LLM you have chosen.
+Our code ships with two defaults. One is a prompt tailored to GPT:
+```python
+template = """Answer the following questions as best you can. You have access to the following tools:
+{tools}
+Use the following format:
+
+Question: the input question you must answer
+Thought: you should always think about what to do
+Action: the action to take, should be one of [{tool_names}]
+Action Input: the input to the action
+Observation: the result of the action
+... (this Thought/Action/Action Input/Observation can repeat N times)
+Thought: I now know the final answer
+Final Answer: the final answer to the original input question
+
+Begin!
+
+Previous conversation history:
+{history}
+
+New question: {input}
+{agent_scratchpad}"""
+```
+The other is tailored to GLM-130B (kept in Chinese, since the prompt text is what the model actually sees):
+```python
+template = """
+尽可能地回答以下问题。你可以使用以下工具:{tools}
+请按照以下格式进行:
+Question: 需要你回答的输入问题
+Thought: 你应该总是思考该做什么
+Action: 需要使用的工具,应该是[{tool_names}]中的一个
+Action Input: 传入工具的内容
+Observation: 行动的结果
+... (这个Thought/Action/Action Input/Observation可以重复N次)
+Thought: 我现在知道最后的答案
+Final Answer: 对原始输入问题的最终答案
+
+现在开始!
+
+之前的对话:
+{history}
+
+New question: {input}
+Thought: {agent_scratchpad}"""
+```
+
+### 3. Limitations
+1. In our experiments, models below roughly the 70B scale rarely perform well without fine-tuning, so we recommend fine-tuning a model of 70B or larger for better results.
+2. Agents are brittle, and the temperature setting has a large effect on their behavior. Whatever the model, we recommend setting temperature below 0.1 when using a custom agent.
+3. Even with a model above the 70B scale, you should still optimize the prompt heavily so the model can reliably choose a tool and complete the task.
+
+
+### 4. Agents we already support
+We provide three LLM-driven agent tools:
+1. A translation tool that translates input from any language.
+2. A math tool that performs calculations via LLMMathChain.
+3. A weather tool that queries the QWeather (和风天气) API through a custom LLMWeatherChain.
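+
+As a concrete illustration of step 1, here is a minimal sketch of such a ```custom_agent.py```. The ```work``` function and its behavior are hypothetical placeholders, not code shipped with this patch:
+```python
+# server/agent/custom_agent.py -- hypothetical example
+def work(query: str) -> str:
+    """Placeholder tool: replace the body with your own logic."""
+    return f"work received: {query}"
+```
+Once registered through ```Tool.from_function```, the agent picks this tool by its ```name```; a meaningful ```description``` helps, because ```CustomPromptTemplate``` renders "name: description" pairs into the prompt the model uses to choose tools.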
\ No newline at end of file
diff --git a/server/agent/__init__.py b/server/agent/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/server/agent/custom_template.py b/server/agent/custom_template.py
new file mode 100644
index 00000000..9f3532fb
--- /dev/null
+++ b/server/agent/custom_template.py
@@ -0,0 +1,104 @@
+template = """
+尽可能地回答以下问题。你可以使用以下工具:{tools}
+请按照以下格式进行:
+Question: 需要你回答的输入问题
+Thought: 你应该总是思考该做什么
+Action: 需要使用的工具,应该是[{tool_names}]中的一个
+Action Input: 传入工具的内容
+Observation: 行动的结果
+... (这个Thought/Action/Action Input/Observation可以重复N次)
+Thought: 我现在知道最后的答案
+Final Answer: 对原始输入问题的最终答案
+
+现在开始!
+
+之前的对话:
+{history}
+
+New question: {input}
+Thought: {agent_scratchpad}"""
+
+
+# ChatGPT prompt template
+# template = """Answer the following questions as best you can. You have access to the following tools:
+# {tools}
+# Use the following format:
+#
+# Question: the input question you must answer
+# Thought: you should always think about what to do
+# Action: the action to take, should be one of [{tool_names}]
+# Action Input: the input to the action
+# Observation: the result of the action
+# ... (this Thought/Action/Action Input/Observation can repeat N times)
+# Thought: I now know the final answer
+# Final Answer: the final answer to the original input question
+#
+# Begin!
+#
+# Previous conversation history:
+# {history}
+#
+# New question: {input}
+# {agent_scratchpad}"""
+
+
+import re
+from typing import List, Union
+
+from langchain.agents import Tool, AgentOutputParser
+from langchain.prompts import StringPromptTemplate
+from langchain.schema import AgentAction, AgentFinish
+from server.agent.tools import tools
+
+
+class CustomPromptTemplate(StringPromptTemplate):
+    # The template to use
+    template: str
+    # The list of tools available
+    tools: List[Tool]
+
+    def format(self, **kwargs) -> str:
+        # Get the intermediate steps (AgentAction, Observation tuples)
+        # and format them into the agent scratchpad
+        intermediate_steps = kwargs.pop("intermediate_steps")
+        thoughts = ""
+        for action, observation in intermediate_steps:
+            thoughts += action.log
+            thoughts += f"\nObservation: {observation}\nThought: "
+        # Set the agent_scratchpad variable to that value
+        kwargs["agent_scratchpad"] = thoughts
+        # Create a tools variable from the list of tools provided
+        kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
+        # Create a list of tool names for the tools provided
+        kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
+        return self.template.format(**kwargs)
+
+
+prompt = CustomPromptTemplate(
+    template=template,
+    tools=tools,
+    input_variables=["input", "intermediate_steps", "history"]
+)
+
+
+class CustomOutputParser(AgentOutputParser):
+
+    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
+        # Check if the agent should finish
+        if "Final Answer:" in llm_output:
+            return AgentFinish(
+                # Return values is generally always a dictionary with a single `output` key
+                # It is not recommended to try anything else at the moment :)
+                return_values={"output": llm_output.split("Final Answer:")[-1].strip()},
+                log=llm_output,
+            )
+        # Parse out the action and action input
+        regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
+        match = re.search(regex, llm_output, re.DOTALL)
+        if not match:
+            # Instead of raising OutputParserException, surface the failure to the user
+            return AgentFinish(
+                return_values={"output": f"调用agent失败: `{llm_output}`"},
+                log=llm_output,
+            )
+        action = match.group(1).strip()
+        action_input = match.group(2)
+        # Return the action and action input
+        return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)
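For reviewers who want to sanity-check the parser in isolation, a quick sketch (the sample strings are made up; this snippet is not part of the patch):

```python
# Hypothetical smoke test for CustomOutputParser
from server.agent.custom_template import CustomOutputParser

parser = CustomOutputParser()

step = parser.parse("Thought: 需要查询天气\nAction: 天气查询工具\nAction Input: 浦东 上海 24")
print(step.tool, step.tool_input)     # -> 天气查询工具 浦东 上海 24

done = parser.parse("Thought: 我现在知道最后的答案\nFinal Answer: 多云,24°C")
print(done.return_values["output"])   # -> 多云,24°C
```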
diff --git a/server/agent/math.py b/server/agent/math.py
new file mode 100644
index 00000000..c865125c
--- /dev/null
+++ b/server/agent/math.py
@@ -0,0 +1,70 @@
+from langchain import PromptTemplate
+from langchain.chains import LLMMathChain
+from server.chat.utils import get_ChatOpenAI
+from configs.model_config import LLM_MODEL, TEMPERATURE
+
+# Few-shot prompt that asks the model to translate a math question into a
+# single-line numexpr expression; the chain executes it and answers with the output.
+_PROMPT_TEMPLATE = """将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。
+问题: ${{包含数学问题的问题。}}
+```text
+${{解决问题的单行数学表达式}}
+```
+...numexpr.evaluate(query)...
+```output
+${{运行代码的输出}}
+```
+答案: ${{答案}}
+
+这是三个例子:
+
+问题: 37593 * 67是多少?
+```text
+37593 * 67
+```
+...numexpr.evaluate("37593 * 67")...
+```output
+2518731
+
+答案: 2518731
+
+问题: 37593的五次方根是多少?
+```text
+37593**(1/5)
+```
+...numexpr.evaluate("37593**(1/5)")...
+```output
+8.222831614237718
+
+答案: 8.222831614237718
+
+
+问题: 2的平方是多少?
+```text
+2 ** 2
+```
+...numexpr.evaluate("2 ** 2")...
+```output
+4
+
+答案: 4
+
+
+现在,这是我的问题:
+问题: {question}
+"""
+
+PROMPT = PromptTemplate(
+    input_variables=["question"],
+    template=_PROMPT_TEMPLATE,
+)
+
+
+def calculate(query: str):
+    # Tool chains run non-streaming; only the final answer is needed
+    model = get_ChatOpenAI(
+        streaming=False,
+        model_name=LLM_MODEL,
+        temperature=TEMPERATURE,
+    )
+    llm_math = LLMMathChain.from_llm(model, verbose=True, prompt=PROMPT)
+    ans = llm_math.run(query)
+    return ans
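Under the hood, whatever single-line expression the model emits in the ```text``` block is evaluated with numexpr, a dependency of LLMMathChain. For reference, a sketch of that final step:

```python
# What LLMMathChain ultimately executes for a model-emitted expression
import numexpr

print(numexpr.evaluate("37593 * 67"))         # -> 2518731
print(numexpr.evaluate("(2 + 2312312)/4"))    # -> 578078.5
```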
diff --git a/server/agent/tools.py b/server/agent/tools.py
new file mode 100644
index 00000000..2d4476c1
--- /dev/null
+++ b/server/agent/tools.py
@@ -0,0 +1,28 @@
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+
+from server.agent.math import calculate
+from server.agent.translator import translate
+from server.agent.weather import weathercheck
+from langchain.agents import Tool
+
+# NOTE: these descriptions are rendered into the agent prompt as
+# "name: description" pairs; leaving them empty gives the model only the
+# tool names to choose from.
+tools = [
+    Tool.from_function(
+        func=calculate,
+        name="计算器工具",
+        description=""
+    ),
+    Tool.from_function(
+        func=translate,
+        name="翻译工具",
+        description=""
+    ),
+    Tool.from_function(
+        func=weathercheck,
+        name="天气查询工具",
+        description="",
+    )
+]
+tool_names = [tool.name for tool in tools]
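A sketch of the same registration with a description filled in; the wording is illustrative, not part of the patch:

```python
# Illustrative tool registration with a non-empty description
from langchain.agents import Tool
from server.agent.math import calculate

tools = [
    Tool.from_function(
        func=calculate,
        name="计算器工具",
        description="输入一个数学问题或表达式,返回计算结果",
    ),
]
```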
diff --git a/server/agent/translator.py b/server/agent/translator.py
new file mode 100644
index 00000000..5740034d
--- /dev/null
+++ b/server/agent/translator.py
@@ -0,0 +1,41 @@
+from langchain import PromptTemplate, LLMChain
+import sys
+import os
+
+from server.chat.utils import get_ChatOpenAI
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+from configs.model_config import LLM_MODEL, TEMPERATURE
+
+_PROMPT_TEMPLATE = '''
+# 指令
+接下来,作为一个专业的翻译专家,当我给出英文句子或段落时,你将提供通顺且具有可读性的对应语言的翻译。注意:
+1. 确保翻译结果流畅且易于理解
+2. 无论提供的是陈述句或疑问句,只进行翻译
+3. 不添加与原文无关的内容
+
+原文: ${{用户需要翻译的原文和目标语言}}
+{question}
+```output
+${{翻译结果}}
+```
+答案: ${{答案}}
+'''
+
+PROMPT = PromptTemplate(
+    input_variables=["question"],
+    template=_PROMPT_TEMPLATE,
+)
+
+
+def translate(query: str):
+    # Non-streaming: the tool only needs the final translation
+    model = get_ChatOpenAI(
+        streaming=False,
+        model_name=LLM_MODEL,
+        temperature=TEMPERATURE,
+    )
+    llm_translate = LLMChain(llm=model, prompt=PROMPT)
+    ans = llm_translate.run(query)
+
+    return ans
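Assuming the model service from this repository is up and configured, each tool function can also be exercised directly, outside the agent loop; a sketch:

```python
# Hypothetical direct invocation; requires a running, configured model worker
from server.agent.translator import translate

print(translate("把这句话翻译成英文:今天天气不错"))
```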
diff --git a/server/agent/weather.py b/server/agent/weather.py
new file mode 100644
index 00000000..9a4464b6
--- /dev/null
+++ b/server/agent/weather.py
@@ -0,0 +1,355 @@
+## Weather lookup via the QWeather (和风天气) API
+
+from __future__ import annotations
+import sys
+import os
+
+from server.chat.utils import get_ChatOpenAI
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+
+import re
+import warnings
+import requests
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+
+from langchain import PromptTemplate
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForChainRun,
+    CallbackManagerForChainRun,
+)
+from langchain.chains.base import Chain
+from langchain.chains.llm import LLMChain
+from langchain.pydantic_v1 import Extra, root_validator
+from langchain.schema import BasePromptTemplate
+from langchain.schema.language_model import BaseLanguageModel
+from configs.model_config import LLM_MODEL, TEMPERATURE
+
+
+def get_city_info(location, adm, key):
+    base_url = 'https://geoapi.qweather.com/v2/city/lookup?'
+    params = {'location': location, 'adm': adm, 'key': key}
+    response = requests.get(base_url, params=params)
+    data = response.json()
+    return data
+
+
+def format_weather_data(data):
+    hourly_forecast = data['hourly']
+    formatted_data = ''
+    for forecast in hourly_forecast:
+        # Parse the forecast time into a timezone-aware datetime
+        forecast_time = datetime.strptime(forecast['fxTime'], '%Y-%m-%dT%H:%M%z')
+        # Timezone of the forecast time
+        forecast_tz = forecast_time.tzinfo
+        # Current time, in the forecast's timezone
+        now = datetime.now(forecast_tz)
+        # Day difference between the forecast date and today
+        days_diff = (forecast_time.date() - now.date()).days
+        if days_diff == 0:
+            forecast_date_str = '今天'
+        elif days_diff == 1:
+            forecast_date_str = '明天'
+        elif days_diff == 2:
+            forecast_date_str = '后天'
+        else:
+            forecast_date_str = str(days_diff) + '天后'
+        forecast_time_str = forecast_date_str + ' ' + forecast_time.strftime('%H:%M')
+        # Time difference between the forecast time and now
+        time_diff = forecast_time - now
+        # Convert the difference to hours
+        hours_diff = time_diff.total_seconds() // 3600
+        if hours_diff < 1:
+            hours_diff_str = '1小时后'
+        elif hours_diff >= 24:
+            # Beyond 24 hours, express the difference in days
+            days_diff = hours_diff // 24
+            hours_diff_str = str(int(days_diff)) + '天后'
+        else:
+            hours_diff_str = str(int(hours_diff)) + '小时后'
+        # Append the relative and absolute forecast times to the output
+        formatted_data += '预报时间: ' + hours_diff_str + '\n'
+        formatted_data += '具体时间: ' + forecast_time_str + '\n'
+        formatted_data += '温度: ' + forecast['temp'] + '°C\n'
+        formatted_data += '天气: ' + forecast['text'] + '\n'
+        formatted_data += '风向: ' + forecast['windDir'] + '\n'
+        formatted_data += '风速: ' + forecast['windSpeed'] + '级\n'
+        formatted_data += '湿度: ' + forecast['humidity'] + '%\n'
+        formatted_data += '降水概率: ' + forecast['pop'] + '%\n'
+        # formatted_data += '降水量: ' + forecast['precip'] + 'mm\n'
+        formatted_data += '\n\n'
+    return formatted_data
+
+
+def get_weather(key, location_id, time: str = "24"):
+    if time:
+        url = "https://devapi.qweather.com/v7/weather/" + time + "h?"
+    else:
+        time = "3"  # The free tier only provides a 3-day daily forecast
+        url = "https://devapi.qweather.com/v7/weather/" + time + "d?"
+    params = {
+        'location': location_id,
+        'key': key,
+    }
+    response = requests.get(url, params=params)
+    data = response.json()
+    return format_weather_data(data)
+
+
+def split_query(query):
+    parts = query.split()
+    location = parts[0] if parts[0] != 'None' else parts[1]
+    adm = parts[1]
+    time = parts[2]
+    return location, adm, time
+
+
+def weather(query):
+    location, adm, time = split_query(query)
+    if time != "None" and int(time) > 24:
+        return "只能查看24小时内的天气,无法回答"
+    if time == "None":
+        time = "24"  # The free tier only covers the next 24 hours
+    key = ""  # QWeather API Key
+    if key == "":
+        return "请先在代码中填入和风天气API Key"
+    city_info = get_city_info(location=location, adm=adm, key=key)
+    location_id = city_info['location'][0]['id']
+    weather_data = get_weather(key=key, location_id=location_id, time=time)
+    return weather_data
+
+
+class LLMWeatherChain(Chain):
+    llm_chain: LLMChain
+    llm: Optional[BaseLanguageModel] = None
+    """[Deprecated] LLM wrapper to use."""
+    prompt: BasePromptTemplate
+    """[Deprecated] Prompt to use."""
+    input_key: str = "question"  #: :meta private:
+    output_key: str = "answer"  #: :meta private:
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    @root_validator(pre=True)
+    def raise_deprecation(cls, values: Dict) -> Dict:
+        if "llm" in values:
+            warnings.warn(
+                "Directly instantiating an LLMWeatherChain with an llm is deprecated. "
+                "Please instantiate with llm_chain argument or using the from_llm "
+                "class method."
+            )
+            if "llm_chain" not in values and values["llm"] is not None:
+                prompt = values.get("prompt", PROMPT)
+                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
+        return values
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Expect input key.
+
+        :meta private:
+        """
+        return [self.input_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Expect output key.
+
+        :meta private:
+        """
+        return [self.output_key]
+
+    def _evaluate_expression(self, expression: str) -> str:
+        try:
+            output = weather(expression)
+        except Exception:
+            output = "输入的信息有误,请再次尝试"
+            # raise ValueError(f"错误: {expression},输入的信息不对")
+        return output
+
+    def _process_llm_result(
+            self, llm_output: str, run_manager: CallbackManagerForChainRun
+    ) -> Dict[str, str]:
+
+        run_manager.on_text(llm_output, color="green", verbose=self.verbose)
+
+        llm_output = llm_output.strip()
+        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
+        if text_match:
+            expression = text_match.group(1)
+            output = self._evaluate_expression(expression)
+            run_manager.on_text("\nAnswer: ", verbose=self.verbose)
+            run_manager.on_text(output, color="yellow", verbose=self.verbose)
+            answer = "Answer: " + output
+        elif llm_output.startswith("Answer:"):
+            answer = llm_output
+        elif "Answer:" in llm_output:
+            answer = "Answer: " + llm_output.split("Answer:")[-1]
+        else:
+            raise ValueError(f"unknown format from LLM: {llm_output}")
+        return {self.output_key: answer}
+
+    async def _aprocess_llm_result(
+            self,
+            llm_output: str,
+            run_manager: AsyncCallbackManagerForChainRun,
+    ) -> Dict[str, str]:
+        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
+        llm_output = llm_output.strip()
+        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
+        if text_match:
+            expression = text_match.group(1)
+            output = self._evaluate_expression(expression)
+            await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
+            await run_manager.on_text(output, color="yellow", verbose=self.verbose)
+            answer = "Answer: " + output
+        elif llm_output.startswith("Answer:"):
+            answer = llm_output
+        elif "Answer:" in llm_output:
+            answer = "Answer: " + llm_output.split("Answer:")[-1]
+        else:
+            raise ValueError(f"unknown format from LLM: {llm_output}")
+        return {self.output_key: answer}
+
+    def _call(
+            self,
+            inputs: Dict[str, str],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
+        _run_manager.on_text(inputs[self.input_key])
+        llm_output = self.llm_chain.predict(
+            question=inputs[self.input_key],
+            stop=["```output"],
+            callbacks=_run_manager.get_child(),
+        )
+        return self._process_llm_result(llm_output, _run_manager)
+
+    async def _acall(
+            self,
+            inputs: Dict[str, str],
+            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, str]:
+        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
+        await _run_manager.on_text(inputs[self.input_key])
+        llm_output = await self.llm_chain.apredict(
+            question=inputs[self.input_key],
+            stop=["```output"],
+            callbacks=_run_manager.get_child(),
+        )
+        return await self._aprocess_llm_result(llm_output, _run_manager)
+
+    @property
+    def _chain_type(self) -> str:
+        return "llm_weather_chain"
+
+    @classmethod
+    def from_llm(
+            cls,
+            llm: BaseLanguageModel,
+            prompt: BasePromptTemplate,
+            **kwargs: Any,
+    ) -> LLMWeatherChain:
+        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        return cls(llm_chain=llm_chain, **kwargs)
+
+
+_PROMPT_TEMPLATE = """用户将会向您咨询天气问题,您不需要自己回答天气问题,而是将用户提问的信息提取出来区,市和时间三个元素后使用我为你编写好的工具进行查询并返回结果,格式为 区+市+时间 每个元素用空格隔开。如果缺少信息,则用 None 代替。
+问题: ${{用户的问题}}
+
+```text
+
+${{拆分的区,市和时间}}
+```
+
+... weather(query)...
+```output
+
+${{提取后的答案}}
+```
+答案: ${{答案}}
+
+这是两个例子:
+问题: 上海浦东未来1小时天气情况?
+
+```text
+
+浦东 上海 1
+```
+...weather(浦东 上海 1)...
+
+```output
+
+预报时间: 1小时后
+具体时间: 今天 18:00
+温度: 24°C
+天气: 多云
+风向: 西南风
+风速: 7级
+湿度: 88%
+降水概率: 16%
+
+Answer:
+预报时间: 1小时后
+具体时间: 今天 18:00
+温度: 24°C
+天气: 多云
+风向: 西南风
+风速: 7级
+湿度: 88%
+降水概率: 16%
+
+问题: 北京市朝阳区未来24小时天气如何?
+```text
+
+朝阳 北京 24
+```
+...weather(朝阳 北京 24)...
+```output
+预报时间: 23小时后
+具体时间: 明天 17:00
+温度: 26°C
+天气: 霾
+风向: 西南风
+风速: 11级
+湿度: 65%
+降水概率: 20%
+Answer:
+预报时间: 23小时后
+具体时间: 明天 17:00
+温度: 26°C
+天气: 霾
+风向: 西南风
+风速: 11级
+湿度: 65%
+降水概率: 20%
+
+现在,这是我的问题:
+问题: {question}
+"""
+PROMPT = PromptTemplate(
+    input_variables=["question"],
+    template=_PROMPT_TEMPLATE,
+)
+
+
+def weathercheck(query: str):
+    model = get_ChatOpenAI(
+        streaming=False,
+        model_name=LLM_MODEL,
+        temperature=TEMPERATURE,
+    )
+    llm_weather = LLMWeatherChain.from_llm(model, verbose=True, prompt=PROMPT)
+    ans = llm_weather.run(query)
+    return ans
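The weather tool expects the model to emit "district city hours" separated by spaces, with None for missing parts. A sketch of what the plumbing does with that format, using made-up values and no network calls:

```python
# Hypothetical walk-through of the query format
from server.agent.weather import split_query

location, adm, time = split_query("浦东 上海 1")
print(location, adm, time)   # -> 浦东 上海 1

# "None 上海 24" falls back to the city itself as the location
location, adm, time = split_query("None 上海 24")
print(location)              # -> 上海
```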
diff --git a/server/api.py b/server/api.py
index 91326ffa..ea098d68 100644
--- a/server/api.py
+++ b/server/api.py
@@ -12,12 +12,12 @@ import uvicorn
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
 from server.chat import (chat, knowledge_base_chat, openai_chat,
-                         search_engine_chat)
+                         search_engine_chat, agent_chat)
 from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
                                               update_docs, download_doc,
                                               recreate_vector_store, search_docs, DocumentWithScore)
-from server.llm_api import list_running_models,list_config_models, change_llm_model, stop_llm_model
+from server.llm_api import list_running_models, list_config_models, change_llm_model, stop_llm_model
 from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
 from typing import List
 
@@ -67,6 +67,10 @@ def create_app():
              tags=["Chat"],
              summary="与搜索引擎对话")(search_engine_chat)
 
+    app.post("/chat/agent_chat",
+             tags=["Chat"],
+             summary="与agent对话")(agent_chat)
+
     # Tag: Knowledge Base Management
     app.get("/knowledge_base/list_knowledge_bases",
             tags=["Knowledge Base Management"],
@@ -126,24 +130,24 @@ def create_app():
 
     # LLM模型相关接口
     app.post("/llm_model/list_running_models",
-            tags=["LLM Model Management"],
-            summary="列出当前已加载的模型",
-            )(list_running_models)
+             tags=["LLM Model Management"],
+             summary="列出当前已加载的模型",
+             )(list_running_models)
 
     app.post("/llm_model/list_config_models",
-            tags=["LLM Model Management"],
-            summary="列出configs已配置的模型",
-            )(list_config_models)
+             tags=["LLM Model Management"],
+             summary="列出configs已配置的模型",
+             )(list_config_models)
 
     app.post("/llm_model/stop",
-            tags=["LLM Model Management"],
-            summary="停止指定的LLM模型(Model Worker)",
-            )(stop_llm_model)
+             tags=["LLM Model Management"],
+             summary="停止指定的LLM模型(Model Worker)",
+             )(stop_llm_model)
 
     app.post("/llm_model/change",
-            tags=["LLM Model Management"],
-            summary="切换指定的LLM模型(Model Worker)",
-            )(change_llm_model)
+             tags=["LLM Model Management"],
+             summary="切换指定的LLM模型(Model Worker)",
+             )(change_llm_model)
 
     return app
diff --git a/server/chat/__init__.py b/server/chat/__init__.py
index 136bad64..62fe430c 100644
--- a/server/chat/__init__.py
+++ b/server/chat/__init__.py
@@ -2,3 +2,4 @@ from .chat import chat
 from .knowledge_base_chat import knowledge_base_chat
 from .openai_chat import openai_chat
 from .search_engine_chat import search_engine_chat
+from .agent_chat import agent_chat
\ No newline at end of file
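With the route registered, the endpoint can be exercised once the API server is running; a sketch (the host and port are assumptions matching the project's usual defaults, not values stated in this patch):

```python
# Hypothetical client call against a locally running api.py
import requests

resp = requests.post(
    "http://127.0.0.1:7861/chat/agent_chat",
    json={"query": "北京市朝阳区未来24小时天气如何?", "stream": True},
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)
```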
diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py
new file mode 100644
index 00000000..8d7cf3b5
--- /dev/null
+++ b/server/chat/agent_chat.py
@@ -0,0 +1,73 @@
+from langchain.memory import ConversationBufferWindowMemory
+from server.agent.tools import tools, tool_names
+from langchain.agents import AgentExecutor, LLMSingleActionAgent
+from server.agent.custom_template import CustomOutputParser, prompt
+from fastapi import Body
+from fastapi.responses import StreamingResponse
+from configs.model_config import LLM_MODEL, TEMPERATURE, HISTORY_LEN
+from server.chat.utils import wrap_done, get_ChatOpenAI, History
+from langchain import LLMChain
+from langchain.callbacks.streaming_aiter_final_only import AsyncFinalIteratorCallbackHandler
+from typing import AsyncIterable, List
+import asyncio
+from langchain.prompts.chat import ChatPromptTemplate
+
+# NOTE: module-level memory is shared across all requests and sessions
+memory = ConversationBufferWindowMemory(k=HISTORY_LEN)
+
+
+async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
+                     history: List[History] = Body([],
+                                                   description="历史对话",
+                                                   examples=[[
+                                                       {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                                       {"role": "assistant", "content": "虎头虎脑"}]]
+                                                   ),
+                     stream: bool = Body(False, description="流式输出"),
+                     model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
+                     temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
+                     # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
+                     ):
+    history = [History.from_data(h) for h in history]
+
+    async def chat_iterator(query: str,
+                            history: List[History] = [],
+                            model_name: str = LLM_MODEL,
+                            ) -> AsyncIterable[str]:
+        # Only the tokens of the final answer are streamed back to the client
+        callback = AsyncFinalIteratorCallbackHandler()
+        model = get_ChatOpenAI(
+            model_name=model_name,
+            temperature=temperature,
+            callbacks=[callback],
+        )
+        output_parser = CustomOutputParser()
+        llm_chain = LLMChain(llm=model, prompt=prompt)
+        agent = LLMSingleActionAgent(
+            llm_chain=llm_chain,
+            output_parser=output_parser,
+            stop=["\nObservation:"],
+            allowed_tools=tool_names
+        )
+        agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=memory,
+                                                            callbacks=[callback])
+        # NOTE: chat_prompt is built from the request history but is not yet
+        # wired into the agent, which currently relies on `memory` instead
+        input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
+        chat_prompt = ChatPromptTemplate.from_messages(
+            [i.to_msg_template() for i in history] + [input_msg])
+        task = asyncio.create_task(wrap_done(
+            agent_executor.acall(query),
+            callback.done),
+        )
+        if stream:
+            async for token in callback.aiter():
+                # Use server-sent-events to stream the response
+                yield token
+        else:
+            answer = ""
+            async for token in callback.aiter():
+                answer += token
+            yield answer
+
+        await task
+
+    return StreamingResponse(chat_iterator(query, history, model_name),
+                             media_type="text/event-stream")
diff --git a/server/chat/utils.py b/server/chat/utils.py
index 0c32969e..0dd17d39 100644
--- a/server/chat/utils.py
+++ b/server/chat/utils.py
@@ -10,11 +10,12 @@ from typing import Awaitable, List, Tuple, Dict, Union, Callable
 def get_ChatOpenAI(
         model_name: str,
         temperature: float,
+        streaming: bool = True,
         callbacks: List[Callable] = [],
 ) -> ChatOpenAI:
     config = get_model_worker_config(model_name)
     model = ChatOpenAI(
-        streaming=True,
+        streaming=streaming,
         verbose=True,
         callbacks=callbacks,
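The new ```streaming``` flag is what lets the tool chains request a single blocking completion while the chat endpoints keep token streaming. A sketch of both call styles (the model name is an assumption; any model configured for a running worker would do):

```python
# Hypothetical usage of the new streaming flag
from server.chat.utils import get_ChatOpenAI

streaming_model = get_ChatOpenAI(model_name="chatglm2-6b", temperature=0.1)                   # streams tokens
blocking_model = get_ChatOpenAI(model_name="chatglm2-6b", temperature=0.1, streaming=False)   # one-shot, as the tools use
```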
diff --git a/tests/agent/test_agent_function.py b/tests/agent/test_agent_function.py
new file mode 100644
index 00000000..91c3a327
--- /dev/null
+++ b/tests/agent/test_agent_function.py
@@ -0,0 +1,40 @@
+import sys
+import os
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+from configs import LLM_MODEL, TEMPERATURE
+from server.chat.utils import get_ChatOpenAI
+from langchain import LLMChain
+from langchain.agents import LLMSingleActionAgent, AgentExecutor
+from server.agent.tools import tools, tool_names
+from server.agent.custom_template import CustomOutputParser, prompt
+from langchain.memory import ConversationBufferWindowMemory
+
+import pytest
+
+memory = ConversationBufferWindowMemory(k=5)
+model = get_ChatOpenAI(
+    model_name=LLM_MODEL,
+    temperature=TEMPERATURE,
+)
+
+output_parser = CustomOutputParser()
+llm_chain = LLMChain(llm=model, prompt=prompt)
+agent = LLMSingleActionAgent(
+    llm_chain=llm_chain,
+    output_parser=output_parser,
+    stop=["\nObservation:"],
+    allowed_tools=tool_names
+)
+
+agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, memory=memory, verbose=True)
+
+
+@pytest.mark.parametrize("text_prompt",
+                         ["北京市朝阳区未来24小时天气如何?",  # weather tool
+                          "计算 (2 + 2312312)/4 是多少?",  # calculator tool
+                          "翻译这句话成中文:Life is the art of drawing sufficient conclusions from insufficient premises."]  # translation tool
+                         )
+def test_different_agent_function(text_prompt):
+    try:
+        text_answer = agent_executor.run(text_prompt)
+        assert text_answer is not None
+    except Exception as e:
+        pytest.fail(f"agent_function failed with {text_prompt}, error: {str(e)}")
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 0d6a2ad0..17555362 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -58,6 +58,7 @@ def dialogue_page(api: ApiRequest):
                                      ["LLM 对话",
                                       "知识库问答",
                                       "搜索引擎问答",
+                                      "自定义Agent问答",
                                       ],
                                      index=1,
                                      on_change=on_mode_change,
@@ -152,6 +153,19 @@ def dialogue_page(api: ApiRequest):
                     text += t
                     chat_box.update_msg(text)
                 chat_box.update_msg(text, streaming=False)  # 更新最终的字符串,去除光标
+
+        elif dialogue_mode == "自定义Agent问答":
+            chat_box.ai_say("正在调用工具回答...")
+            text = ""
+            r = api.agent_chat(prompt, history=history, model=llm_model, temperature=temperature)
+            for t in r:
+                if error_msg := check_error_msg(t):  # check whether an error occurred
+                    st.error(error_msg)
+                    break
+                text += t
+                chat_box.update_msg(text)
+            chat_box.update_msg(text, streaming=False)  # update the final string and remove the cursor
+
         elif dialogue_mode == "知识库问答":
             history = get_messages_history(history_len)
             chat_box.ai_say([
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 5708efce..a94cb995 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -342,6 +342,39 @@ class ApiRequest:
         response = self.post("/chat/chat", json=data, stream=True)
         return self._httpx_stream2generator(response)
 
+    def agent_chat(
+        self,
+        query: str,
+        history: List[Dict] = [],
+        stream: bool = True,
+        model: str = LLM_MODEL,
+        temperature: float = TEMPERATURE,
+        no_remote_api: bool = None,
+    ):
+        '''
+        Corresponds to the /chat/agent_chat endpoint in api.py
+        '''
+        if no_remote_api is None:
+            no_remote_api = self.no_remote_api
+
+        data = {
+            "query": query,
+            "history": history,
+            "stream": stream,
+            "model_name": model,
+            "temperature": temperature,
+        }
+
+        print("received input message:")
+        pprint(data)
+
+        if no_remote_api:
+            from server.chat.agent_chat import agent_chat
+            response = run_async(agent_chat(**data))
+            return self._fastapi_stream2generator(response)
+        else:
+            response = self.post("/chat/agent_chat", json=data, stream=True)
+            return self._httpx_stream2generator(response)
+
     def knowledge_base_chat(
         self,
         query: str,