diff --git a/server/agent/agent_factory/glm3_agent.py b/server/agent/agent_factory/glm3_agent.py
index b3f3148e..9deaa71f 100644
--- a/server/agent/agent_factory/glm3_agent.py
+++ b/server/agent/agent_factory/glm3_agent.py
@@ -3,14 +3,6 @@ This file is a modified version for ChatGLM3-6B the original glm3_agent.py file
 """
 from __future__ import annotations
 
-import yaml
-from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
-from langchain.memory import ConversationBufferWindowMemory
-from typing import Any, List, Sequence, Tuple, Optional, Union
-import os
-from langchain.agents.agent import Agent
-from langchain.chains.llm import LLMChain
-from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
 import json
 import logging
 from typing import Any, List, Sequence, Tuple, Optional, Union
@@ -30,7 +22,8 @@ from langchain.agents.agent import AgentExecutor
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools.base import BaseTool
-from pydantic.schema import model_schema
+from langchain_core.callbacks import Callbacks
+
 
 HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
 logger = logging.getLogger(__name__)
@@ -195,6 +188,7 @@ class StructuredGLM3ChatAgent(Agent):
             llm=llm,
             prompt=prompt,
             callback_manager=callback_manager,
+            verbose=True
         )
         tool_names = [tool.name for tool in tools]
         _output_parser = output_parser or cls._get_default_output_parser(llm=llm)
diff --git a/server/agent/tools_factory/search_local_knowledgebase.py b/server/agent/tools_factory/search_local_knowledgebase.py
index 0bc7803c..68061815 100644
--- a/server/agent/tools_factory/search_local_knowledgebase.py
+++ b/server/agent/tools_factory/search_local_knowledgebase.py
@@ -1,5 +1,4 @@
 from urllib.parse import urlencode
-
 from pydantic import BaseModel, Field
 
 from server.knowledge_base.kb_doc_api import search_docs
@@ -9,7 +8,7 @@ from configs import TOOL_CONFIG
 def search_knowledgebase(query: str, database: str, config: dict):
     docs = search_docs(
         query=query,
-        knowledge_base_name=database,
+        knowledge_base_name="samples",
         top_k=config["top_k"],
         score_threshold=config["score_threshold"])
     context = ""
@@ -22,7 +21,7 @@ def search_knowledgebase(query: str, database: str, config: dict):
         source_documents.append(text)
 
     if len(source_documents) == 0:
-        context= "没有找到相关文档,请更换关键词重试"
+        context = "没有找到相关文档,请更换关键词重试"
     else:
         for doc in source_documents:
             context += doc + "\n"
@@ -37,4 +36,4 @@ class SearchKnowledgeInput(BaseModel):
 
 def search_local_knowledgebase(database: str, query: str):
     tool_config = TOOL_CONFIG["search_local_knowledgebase"]
-    return search_knowledgebase(query=query, database=database, config=tool_config)
+    return search_knowledgebase(query=query, database=database, config=tool_config)
\ No newline at end of file
diff --git a/server/agent/tools_factory/tools_registry.py b/server/agent/tools_factory/tools_registry.py
index 6e9e8e0c..f84a7c82 100644
--- a/server/agent/tools_factory/tools_registry.py
+++ b/server/agent/tools_factory/tools_registry.py
@@ -2,9 +2,9 @@ from langchain_core.tools import StructuredTool
 from server.agent.tools_factory import *
 from configs import KB_INFO
 
-template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool."
+template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]."
 KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
-template_knowledge = template.format(KB_info=KB_info_str)
+template_knowledge = template.format(KB_info=KB_info_str, key="samples")
 
 all_tools = [
     StructuredTool.from_function(
@@ -24,6 +24,7 @@ all_tools = [
         name="shell",
         description="Use Shell to execute Linux commands",
         args_schema=ShellInput,
+        # return_direct=True, #是否直接返回,不做大模型处理
     ),
     StructuredTool.from_function(
         func=wolfram,
diff --git a/server/callback_handler/agent_callback_handler.py b/server/callback_handler/agent_callback_handler.py
index dad07a79..655bf6e5 100644
--- a/server/callback_handler/agent_callback_handler.py
+++ b/server/callback_handler/agent_callback_handler.py
@@ -1,11 +1,13 @@
 from __future__ import annotations
 from uuid import UUID
 import json
-from langchain.schema import AgentFinish, AgentAction
 import asyncio
-from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast, Optional
+from typing import Any, Dict, List, Optional
+
+from langchain.callbacks import AsyncIteratorCallbackHandler
+from langchain.schema import AgentFinish, AgentAction
 from langchain_core.outputs import LLMResult
-from langchain.callbacks.base import AsyncCallbackHandler
+
 
 def dumps(obj: Dict) -> str:
     return json.dumps(obj, ensure_ascii=False)
@@ -21,7 +23,7 @@ class Status:
     tool_finish: int = 7
 
 
-class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
+class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
     def __init__(self):
         super().__init__()
         self.queue = asyncio.Queue()
@@ -29,29 +31,31 @@
         self.cur_tool = {}
         self.out = True
 
-    async def on_tool_start(
-            self,
-            serialized: Dict[str, Any],
-            input_str: str,
-            *,
-            run_id: UUID,
-            parent_run_id: Optional[UUID] = None,
-            tags: Optional[List[str]] = None,
-            metadata: Optional[Dict[str, Any]] = None,
-            **kwargs: Any,
-    ) -> None:
-        print("on_tool_start")
+    async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
+                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
+                            metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
+        self.cur_tool = {
+            "tool_name": serialized["name"],
+            "input_str": input_str,
+            "output_str": "",
+            "status": Status.agent_action,
+            "run_id": run_id.hex,
+            "llm_token": "",
+            "final_answer": "",
+            "error": "",
+        }
+        self.queue.put_nowait(dumps(self.cur_tool))
+
+    async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
+                          tags: List[str] | None = None, **kwargs: Any) -> None:
+
+        self.out = True
+        self.cur_tool.update(
+            status=Status.tool_finish,
+            output_str=output.replace("Answer:", ""),
+        )
+        self.queue.put_nowait(dumps(self.cur_tool))
 
-    async def on_tool_end(
-            self,
-            output: str,
-            *,
-            run_id: UUID,
-            parent_run_id: Optional[UUID] = None,
-            tags: Optional[List[str]] = None,
-            **kwargs: Any,
-    ) -> None:
-        print("on_tool_end")
 
     async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID, parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
         self.cur_tool.update(
@@ -134,29 +138,17 @@
                 tool_input=action.tool_input,
             )
             self.queue.put_nowait(dumps(self.cur_tool))
+
     async def on_agent_finish(
             self, finish: AgentFinish, *, run_id: UUID,
             parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None,
            **kwargs: Any,
     ) -> None:
+        if "Thought:" in finish.return_values["output"]:
+            finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "")
+
         self.cur_tool.update(
             status=Status.agent_finish,
             agent_finish=finish.return_values["output"],
         )
         self.queue.put_nowait(dumps(self.cur_tool))
-
-    async def aiter(self) -> AsyncIterator[str]:
-        while not self.queue.empty() or not self.done.is_set():
-            done, other = await asyncio.wait(
-                [
-                    asyncio.ensure_future(self.queue.get()),
-                    asyncio.ensure_future(self.done.wait()),
-                ],
-                return_when=asyncio.FIRST_COMPLETED,
-            )
-            if other:
-                other.pop().cancel()
-            token_or_done = cast(Union[str, Literal[True]], done.pop().result())
-            if token_or_done is True:
-                break
-            yield token_or_done
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 7afdeacb..75757d57 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -6,12 +6,11 @@ from fastapi import Body
 from fastapi.responses import StreamingResponse
 
 from langchain.agents import initialize_agent, AgentType
-from langchain_core.callbacks import BaseCallbackManager
 from langchain_core.output_parsers import StrOutputParser
-from langchain_core.runnables import RunnableBranch
 from langchain.chains import LLMChain
 from langchain.prompts.chat import ChatPromptTemplate
 from langchain.prompts import PromptTemplate
+from langchain_core.runnables import RunnableBranch
 
 from server.agent.agent_factory import initialize_glm3_agent
 from server.agent.tools_factory.tools_registry import all_tools
@@ -57,48 +56,55 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])
     elif conversation_id and history_len > 0:
-        memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=models["llm_model"],
-                                            message_limit=history_len)
+        memory = ConversationBufferDBMemory(
+            conversation_id=conversation_id,
+            llm=models["llm_model"],
+            message_limit=history_len
+        )
     else:
         input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages([input_msg])
 
-    chain = LLMChain(prompt=chat_prompt, llm=models["llm_model"], memory=memory)
+    chain = LLMChain(
+        prompt=chat_prompt,
+        llm=models["llm_model"],
+        memory=memory
+    )
 
     classifier_chain = (
             PromptTemplate.from_template(prompts["preprocess_model"])
             | models["preprocess_model"]
             | StrOutputParser()
     )
 
-    if "chatglm3" in models["action_model"].model_name.lower():
-        agent_executor = initialize_glm3_agent(
-            llm=models["action_model"],
-            tools=tools,
-            prompt=prompts["action_model"],
-            input_variables=["input", "intermediate_steps", "history"],
-            memory=memory,
-            callback_manager=BaseCallbackManager(handlers=callbacks),
-            verbose=True,
-        )
-    else:
-        agent_executor = initialize_agent(
-            llm=models["action_model"],
-            tools=tools,
-            callbacks=callbacks,
-            agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
-            memory=memory,
-            verbose=True,
-        )
-    agent_use = False
-    if agent_use:
-        branch = RunnableBranch(
-            (lambda x: "1" in x["topic"].lower(), agent_executor),
-            chain
-        )
-        full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
-    else:
-        # full_chain = ({"input": lambda x: x["input"]} | chain)
+    if "action_model" in models and tools:
+        if "chatglm3" in models["action_model"].model_name.lower():
+            agent_executor = initialize_glm3_agent(
+                llm=models["action_model"],
+                tools=tools,
+                prompt=prompts["action_model"],
+                input_variables=["input", "intermediate_steps", "history"],
+                memory=memory,
+                # callback_manager=BaseCallbackManager(handlers=callbacks),
+                verbose=True,
+            )
+        else:
+            agent_executor = initialize_agent(
+                llm=models["action_model"],
+                tools=tools,
+                callbacks=callbacks,
+                agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
+                memory=memory,
+                verbose=True,
+            )
+
+        # branch = RunnableBranch(
+        #     (lambda x: "1" in x["topic"].lower(), agent_executor),
+        #     chain
+        # )
+        # full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
         full_chain = ({"input": lambda x: x["input"]} | agent_executor)
+    else:
+        full_chain = ({"input": lambda x: x["input"]} | chain)
 
     return full_chain
@@ -149,7 +155,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
 
         # Execute Chain
-        task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
+        task = asyncio.create_task(
+            wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
 
         if stream:
             async for chunk in callback.aiter():
                 data = json.loads(chunk)
diff --git a/server/utils.py b/server/utils.py
index fa1954da..fa3ec0ae 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -508,30 +508,33 @@ def set_httpx_config(
     urllib.request.getproxies = _get_proxies
 
 
-def detect_device() -> Literal["cuda", "mps", "cpu"]:
+def detect_device() -> Literal["cuda", "mps", "cpu", "xpu"]:
     try:
         import torch
         if torch.cuda.is_available():
             return "cuda"
         if torch.backends.mps.is_available():
             return "mps"
+        import intel_extension_for_pytorch as ipex
+        if torch.xpu.get_device_properties(0):
+            return "xpu"
     except:
         pass
     return "cpu"
 
 
-def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
+def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
     device = device or LLM_DEVICE
     # if device.isdigit():
     #     return "cuda:" + device
-    if device not in ["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu", "xpu"]:
         device = detect_device()
     return device
 
 
-def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
+def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
     device = device or EMBEDDING_DEVICE
-    if device not in ["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu", "xpu"]:
         device = detect_device()
     return device
 
diff --git a/startup.py b/startup.py
index 6a8856ec..380160df 100644
--- a/startup.py
+++ b/startup.py
@@ -219,7 +219,6 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
             wbits=args.awq_wbits,
             groupsize=args.awq_groupsize,
         )
-
     worker = ModelWorker(
         controller_addr=args.controller_address,
         worker_addr=args.worker_address,
@@ -391,7 +390,6 @@ def run_model_worker(
     kwargs["worker_address"] = fschat_model_worker_address(model_name)
     model_path = kwargs.get("model_path", "")
     kwargs["model_path"] = model_path
-
     app = create_model_worker_app(log_level=log_level, **kwargs)
     _set_app_event(app, started_event)
     if log_level == "ERROR":
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 34178e66..fe1cfe18 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -10,8 +10,7 @@ from datetime import datetime
 import os
 import re
 import time
-from configs import (TOOL_CONFIG, LLM_MODEL_CONFIG)
-from server.knowledge_base.utils import LOADER_DICT
+from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS)
 import uuid
 from typing import List, Dict
 
@@ -124,15 +123,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
     conv_names = list(st.session_state["conversation_ids"].keys())
     index = 0
 
-    tools = list(TOOL_CONFIG.keys())
-    selected_tool_configs = {}
-
-    with st.expander("工具栏"):
-        for tool in tools:
-            is_selected = st.checkbox(tool, value=TOOL_CONFIG[tool]["use"], key=tool)
-            if is_selected:
-                selected_tool_configs[tool] = TOOL_CONFIG[tool]
-
     if st.session_state.get("cur_conv_name") in conv_names:
         index = conv_names.index(st.session_state.get("cur_conv_name"))
     conversation_name = st.selectbox("当前会话", conv_names, index=index)
@@ -177,7 +167,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             for k, v in config_models.get("online", {}).items():
                 if not v.get("provider") and k not in running_models and k in LLM_MODELS:
                     available_models.append(k)
-            llm_models = running_models + available_models + ["openai-api"]
+            llm_models = running_models + available_models  # + ["openai-api"]
             cur_llm_model = st.session_state.get("cur_llm_model", default_model)
             if cur_llm_model in llm_models:
                 index = llm_models.index(cur_llm_model)
@@ -193,22 +183,39 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
 
             # 传入后端的内容
             model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
-
+            tool_use = True
             for key in LLM_MODEL_CONFIG:
                 if key == 'llm_model':
                     continue
+                if key == 'action_model':
+                    first_key = next(iter(LLM_MODEL_CONFIG[key]))
+                    if first_key not in SUPPORT_AGENT_MODELS:
+                        st.warning("不支持Agent的模型,无法执行任何工具调用")
+                        tool_use = False
+                        continue
                 if LLM_MODEL_CONFIG[key]:
                     first_key = next(iter(LLM_MODEL_CONFIG[key]))
                     model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
 
+            # 选择工具
+            selected_tool_configs = {}
+            if tool_use:
+                from configs import prompt_config
+                import importlib
+                importlib.reload(prompt_config)
+
+                tools = list(prompt_config.TOOL_CONFIG.keys())
+                with st.expander("工具栏"):
+                    for tool in tools:
+                        is_selected = st.checkbox(tool, value=prompt_config.TOOL_CONFIG[tool]["use"], key=tool)
+                        if is_selected:
+                            selected_tool_configs[tool] = prompt_config.TOOL_CONFIG[tool]
+
             if llm_model is not None:
                 model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]
 
-            # files = st.file_uploader("上传附件",accept_multiple_files=False)
-            #                          type=[i for ls in LOADER_DICT.values() for i in ls],)
             uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
             files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
-            # print(len(files_upload["audios"])) if files_upload else None
 
             # if dialogue_mode == "文件对话":
 
@@ -355,6 +362,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             chat_box.reset_history()
             st.rerun()
 
+    warning_placeholder = st.empty()
+    with warning_placeholder.container():
+        st.warning('Running in 8 x A100')
+
     export_btn.download_button(
         "导出记录",
         "".join(chat_box.export2md()),