diff --git a/configs/prompt_config.py.example b/configs/prompt_config.py.example
index 2f8dd2aa..4f9c45f5 100644
--- a/configs/prompt_config.py.example
+++ b/configs/prompt_config.py.example
@@ -1,13 +1,11 @@
 PROMPT_TEMPLATES = {
     "preprocess_model": {
         "default":
-            '你只要回复0 和 1 ,代表不需要使用工具。以下几种问题不需要使用工具:'
-            '1. 介绍一下你自己, 回复0\n'
-            '2. 讲一个故事, 回复0\n'
-            '3. 给我开一个玩笑, 回复0\n'
-            '4. 我当前运行的文件夹是, 回复1\n'
-            '5. where is this cat, 回复1\n'
-            '6. 介绍一下像极了我大学, 回复1\n'
+            '你只要回复0 和 1 ,代表不需要使用工具。以下几种问题需要使用工具:'
+            '1. 需要联网查询的内容\n'
+            '2. 需要计算的内容\n'
+            '3. 需要查询实时性的内容\n'
+            '如果我的输入满足这几种情况,返回1。其他输入,请你回复0,你只要返回一个数字\n'
             '这是我的问题:'
     },
     "llm_model": {
diff --git a/requirements.txt b/requirements.txt
index e55f9ff6..7bb097d8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,12 @@
 # API requirements
-langchain>=0.0.350
-langchain-experimental>=0.0.42
+# On Windows system, install the cuda version manually from https://pytorch.org/
+torch>=2.1.2
+torchvision>=0.16.2
+torchaudio>=2.1.2
+
+langchain>=0.0.352
+langchain-experimental>=0.0.47
 pydantic==1.10.13
 fschat==0.2.35
 openai==1.9.0
diff --git a/server/agent/agent_factory/glm3_agent.py b/server/agent/agent_factory/glm3_agent.py
index f9bf3aee..4be21c1b 100644
--- a/server/agent/agent_factory/glm3_agent.py
+++ b/server/agent/agent_factory/glm3_agent.py
@@ -8,7 +8,6 @@
 import logging
 from typing import Any, List, Sequence, Tuple, Optional, Union
 from pydantic.schema import model_schema
-
 from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
 from langchain.memory import ConversationBufferWindowMemory
 from langchain.agents.agent import Agent
@@ -24,7 +23,6 @@
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
 from langchain_core.callbacks import Callbacks
-
 HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
 logger = logging.getLogger(__name__)
@@ -177,6 +175,7 @@ class StructuredGLM3ChatAgent(Agent):
             **kwargs: Any,
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
+        cls._validate_tools(tools)
         prompt = cls.create_prompt(
             tools,
@@ -217,7 +216,7 @@ def initialize_glm3_agent(
 ) -> AgentExecutor:
     tags_ = list(tags) if tags else []
     agent_kwargs = agent_kwargs or {}
-    llm.callbacks=callbacks
+    llm.callbacks = callbacks
     agent_obj = StructuredGLM3ChatAgent.from_llm_and_tools(
         llm=llm,
         tools=tools,
diff --git a/server/agent/agent_factory/qwen_agent.py b/server/agent/agent_factory/qwen_agent.py
index f31f1ae5..ee97d407 100644
--- a/server/agent/agent_factory/qwen_agent.py
+++ b/server/agent/agent_factory/qwen_agent.py
@@ -154,15 +154,11 @@ def initialize_qwen_agent(
         memory: Optional[ConversationBufferWindowMemory] = None,
         agent_kwargs: Optional[dict] = None,
         *,
-        return_direct: Optional[bool] = None,
         tags: Optional[Sequence[str]] = None,
         **kwargs: Any,
 ) -> AgentExecutor:
     tags_ = list(tags) if tags else []
     agent_kwargs = agent_kwargs or {}
-
-    if isinstance(return_direct, bool): # can make all tools return directly
-        tools = [t.copy(update={"return_direct": return_direct}) for t in tools]
     llm.callbacks=callbacks
     agent_obj = QwenChatAgent.from_llm_and_tools(
         llm=llm,
diff --git a/server/agent/tools_factory/tools_registry.py b/server/agent/tools_factory/tools_registry.py
index f84a7c82..7334fdd3 100644
--- a/server/agent/tools_factory/tools_registry.py
+++ b/server/agent/tools_factory/tools_registry.py
@@ -24,13 +24,13 @@ all_tools = [
         name="shell",
         description="Use Shell to execute Linux commands",
         args_schema=ShellInput,
-        # return_direct=True, #是否直接返回,不做大模型处理
     ),
     StructuredTool.from_function(
         func=wolfram,
         name="wolfram",
         description="Useful for when you need to calculate difficult formulas",
         args_schema=WolframInput,
+    ),
     StructuredTool.from_function(
         func=search_youtube,
@@ -67,5 +67,6 @@ all_tools = [
         name="aqa_processor",
         description="use this tool to get answer for audio question",
         args_schema=AQAInput,
+    )
 ]
diff --git a/server/callback_handler/agent_callback_handler.py b/server/callback_handler/agent_callback_handler.py
index 23949b5e..b477ad2e 100644
--- a/server/callback_handler/agent_callback_handler.py
+++ b/server/callback_handler/agent_callback_handler.py
@@ -19,7 +19,9 @@ class AgentStatus:
     llm_end: int = 3
     agent_action: int = 4
     agent_finish: int = 5
-    error: int = 6
+    tool_begin: int = 6
+    tool_end: int = 7
+    error: int = 8


 class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
@@ -31,20 +33,19 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
     async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
         data = {
-            "status" : AgentStatus.llm_start,
-            "text" : "",
+            "status": AgentStatus.llm_start,
+            "text": "",
         }
         self.done.clear()
         self.queue.put_nowait(dumps(data))
-
     async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         special_tokens = ["Action", "<|observation|>"]
         for stoken in special_tokens:
             if stoken in token:
                 before_action = token.split(stoken)[0]
                 data = {
-                    "status" : AgentStatus.llm_new_token,
+                    "status": AgentStatus.llm_new_token,
                     "text": before_action + "\n",
                 }
                 self.queue.put_nowait(dumps(data))
@@ -53,15 +54,29 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         if token is not None and token != "" and self.out:
             data = {
-                "status" : AgentStatus.llm_new_token,
-                "text" : token,
+                "status": AgentStatus.llm_new_token,
+                "text": token,
             }
             self.queue.put_nowait(dumps(data))
-    async def on_chat_model_start(
+    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        data = {
+            "status": AgentStatus.llm_end,
+            "text": "",
+        }
+        self.queue.put_nowait(dumps(data))
+
+    async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
+        data = {
+            "status": AgentStatus.error,
+            "text": str(error),
+        }
+        self.queue.put_nowait(dumps(data))
+
+    async def on_tool_start(
         self,
         serialized: Dict[str, Any],
-        messages: List[List],
+        input_str: str,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
@@ -69,26 +84,40 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
+        print("tool_begin")
+
+    async def on_tool_end(
+        self,
+        output: str,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when tool ends running."""
         data = {
-            "status" : AgentStatus.llm_start,
-            "text" : "",
+            "status": AgentStatus.tool_end,
+            "tool_output": output,
         }
         self.done.clear()
         self.queue.put_nowait(dumps(data))
-    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+    async def on_tool_error(
+        self,
+        error: BaseException,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when tool errors."""
         data = {
-            "status" : AgentStatus.llm_end,
-            "text" : response.generations[0][0].message.content,
-        }
-        self.out = True
-        self.queue.put_nowait(dumps(data))
-
-    async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
-        data = {
-            "status" : AgentStatus.error,
-            "text" : str(error),
+            "status": AgentStatus.tool_end,
+            "text": error,
         }
+        self.done.clear()
         self.queue.put_nowait(dumps(data))

     async def on_agent_action(
@@ -101,9 +130,9 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         **kwargs: Any,
     ) -> None:
         data = {
-            "status" : AgentStatus.agent_action,
-            "tool_name" : action.tool,
-            "tool_input" : action.tool_input,
+            "status": AgentStatus.agent_action,
+            "tool_name": action.tool,
+            "tool_input": action.tool_input,
             "text": action.log,
         }
         self.queue.put_nowait(dumps(data))
@@ -113,12 +142,16 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         tags: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
+
         if "Thought:" in finish.return_values["output"]:
             finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "")
         data = {
-            "status" : AgentStatus.agent_finish,
-            "text" : finish.return_values["output"],
+            "status": AgentStatus.agent_finish,
+            "text": finish.return_values["output"],
         }
+        self.done.set()
         self.queue.put_nowait(dumps(data))
+
+        self.out = True
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 74933ba2..b8e68be8 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -69,6 +69,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
     chain = LLMChain(
         prompt=chat_prompt,
         llm=models["llm_model"],
+        callbacks=callbacks,
         memory=memory
     )
     classifier_chain = (
@@ -76,8 +77,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
         | models["preprocess_model"]
         | StrOutputParser()
     )
-
-    if "action_model" in models and tools:
+    if "action_model" in models and len(tools) > 0:
         if "chatglm3" in models["action_model"].model_name.lower():
             agent_executor = initialize_glm3_agent(
                 llm=models["action_model"],
@@ -151,7 +151,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         # 从配置中选择工具
         tools = [tool for tool in all_tools if tool.name in tool_config]
-
+        tools = [t.copy(update={"callbacks": callbacks}) for t in tools]
         # 构建完整的Chain
         full_chain = create_models_chains(prompts=prompts,
                                           models=models,
@@ -161,12 +161,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                                           history=history,
                                           history_len=history_len,
                                           metadata=metadata)
-
-        # Execute Chain
-
-        task = asyncio.create_task(
-            wrap_done(full_chain.ainvoke({"input": query}), callback.done))
-
+        task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}), callback.done))
         async for chunk in callback.aiter():
             data = json.loads(chunk)
             data["message_id"] = message_id
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 9ed2df7f..9c2f575e 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -272,12 +272,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                 f'',
                 unsafe_allow_html=True)
-            chat_box.ai_say(["正在思考...",
-                             Markdown(title="tool call", in_expander=True, expanded=True,state="running"),
-                             Markdown()])
+            chat_box.ai_say("正在思考...")
             text = ""
-            message_id = ""
-            tool_called = False
+            text_action = ""
+            element_index = 0
             for d in api.chat_chat(query=prompt,
                                    metadata=files_upload,
@@ -290,30 +288,33 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                 metadata = {
                     "message_id": message_id,
                 }
-                if not tool_called: # 避免工具调用之后重复输出,将LLM输出分为工具调用前后分别处理
-                    element_index = 0
-                else:
-                    element_index = -1
                 if d["status"] == AgentStatus.error:
                     st.error(d["text"])
                 elif d["status"] == AgentStatus.agent_action:
                     formatted_data = {
-                        "action": d["tool_name"],
-                        "action_input": d["tool_input"]
+                        "Function": d["tool_name"],
+                        "function_input": d["tool_input"]
                     }
+                    element_index += 1
                    formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
-                    chat_box.update_msg(Markdown(f"\n```\nInput Params:\n" + formatted_json + f"\n```\n"), element_index=1)
-                    tool_called = True
-                    text = ""
+                    chat_box.insert_msg(
+                        Markdown(title="Function call", in_expander=True, expanded=True, state="running"))
+                    text = """\n```{}\n```\n""".format(formatted_json)
+                    chat_box.update_msg(Markdown(text), element_index=element_index)
+                elif d["status"] == AgentStatus.tool_end:
+                    text += """\n```\nObservation:\n{}\n```\n""".format(d["tool_output"])
+                    chat_box.update_msg(Markdown(text), element_index=element_index, expanded=False, state="complete")
                 elif d["status"] == AgentStatus.llm_new_token:
                     text += d["text"]
                     chat_box.update_msg(text, streaming=True, element_index=element_index, metadata=metadata)
                 elif d["status"] == AgentStatus.llm_end:
                     chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
                 elif d["status"] == AgentStatus.agent_finish:
-                    chat_box.update_msg(element_index=1, state="complete", expanded=False)
-                    chat_box.update_msg(Markdown(d["text"]), streaming=False, element_index=-1)
+                    element_index += 1
+                    # print(d["text"])
+                    chat_box.insert_msg(Markdown(d["text"], expanded=True))
+                    chat_box.update_msg(Markdown(d["text"]), element_index=element_index)

             if os.path.exists("tmp/image.jpg"):
                 with open("tmp/image.jpg", "rb") as image_file:
                     encoded_string = base64.b64encode(image_file.read()).decode()
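
Reviewer note (not part of the patch): the hunks above change the streaming protocol — AgentExecutorAsyncIteratorCallbackHandler now queues tool_begin/tool_end statuses, and webui_pages/dialogue/dialogue.py renders them. Below is a minimal sketch of how a client of the /chat stream might consume these events. The numeric status values and payload keys are taken from the AgentStatus constants and data dicts in this diff; the `stream` iterable and the printing behaviour are purely illustrative assumptions, not project code.

import json

# Status codes as defined in server/callback_handler/agent_callback_handler.py after this patch.
AGENT_ACTION = 4   # agent chose a tool; payload carries "tool_name" / "tool_input"
AGENT_FINISH = 5   # final answer; payload carries "text"
TOOL_END = 7       # tool finished; payload carries "tool_output"
ERROR = 8          # payload carries "text" with the error message


def render_agent_stream(stream):
    """Illustrative consumer: `stream` yields the JSON strings queued by the callback handler."""
    answer = ""
    for chunk in stream:
        d = json.loads(chunk)
        if d["status"] == AGENT_ACTION:
            print(f'-> calling {d["tool_name"]} with {d["tool_input"]}')
        elif d["status"] == TOOL_END:
            print(f'<- observation: {d["tool_output"]}')
        elif d["status"] == ERROR:
            print(f'!! error: {d["text"]}')
        elif d["status"] == AGENT_FINISH:
            answer = d["text"]
    return answer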