Support XPU; modified parts of the glm3 agent

zR 2023-12-13 18:19:51 +08:00 committed by liunux4odoo
parent d44ce6ce21
commit 6d3d99639e
8 changed files with 119 additions and 114 deletions

View File

@@ -3,14 +3,6 @@ This file is a modified version for ChatGLM3-6B the original glm3_agent.py file
"""
from __future__ import annotations
import yaml
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from typing import Any, List, Sequence, Tuple, Optional, Union
import os
from langchain.agents.agent import Agent
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
import json
import logging
from typing import Any, List, Sequence, Tuple, Optional, Union
@@ -30,7 +22,8 @@ from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool
from pydantic.schema import model_schema
from langchain_core.callbacks import Callbacks
HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
logger = logging.getLogger(__name__)
@@ -195,6 +188,7 @@ class StructuredGLM3ChatAgent(Agent):
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
verbose=True
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser or cls._get_default_output_parser(llm=llm)

View File

@@ -1,5 +1,4 @@
from urllib.parse import urlencode
from pydantic import BaseModel, Field
from server.knowledge_base.kb_doc_api import search_docs
@@ -9,7 +8,7 @@ from configs import TOOL_CONFIG
def search_knowledgebase(query: str, database: str, config: dict):
docs = search_docs(
query=query,
knowledge_base_name=database,
knowledge_base_name="samples",
top_k=config["top_k"],
score_threshold=config["score_threshold"])
context = ""
@@ -22,7 +21,7 @@ def search_knowledgebase(query: str, database: str, config: dict):
source_documents.append(text)
if len(source_documents) == 0:
context= "没有找到相关文档,请更换关键词重试"
context = "没有找到相关文档,请更换关键词重试"
else:
for doc in source_documents:
context += doc + "\n"
@@ -37,4 +36,4 @@ class SearchKnowledgeInput(BaseModel):
def search_local_knowledgebase(database: str, query: str):
tool_config = TOOL_CONFIG["search_local_knowledgebase"]
return search_knowledgebase(query=query, database=database, config=tool_config)
return search_knowledgebase(query=query, database=database, config=tool_config)

View File

@@ -2,9 +2,9 @@ from langchain_core.tools import StructuredTool
from server.agent.tools_factory import *
from configs import KB_INFO
template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get informationOnly local data on this knowledge use this tool."
template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get informationOnly local data on this knowledge use this tool. The 'database' should be one of the above [{key}]."
KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
template_knowledge = template.format(KB_info=KB_info_str)
template_knowledge = template.format(KB_info=KB_info_str, key="samples")
all_tools = [
StructuredTool.from_function(
@@ -24,6 +24,7 @@ all_tools = [
name="shell",
description="Use Shell to execute Linux commands",
args_schema=ShellInput,
# return_direct=True,  # whether to return the tool result directly, without further LLM processing
),
StructuredTool.from_function(
func=wolfram,
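
For reference, every entry in all_tools follows the same StructuredTool.from_function shape shown in this hunk. A minimal sketch with a hypothetical tool, not part of this commit; the echo name and schema are illustrative:

from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field

class EchoInput(BaseModel):
    text: str = Field(description="Text to echo back")

def echo(text: str) -> str:
    return text

echo_tool = StructuredTool.from_function(
    func=echo,
    name="echo",
    description="Echo the given text",
    args_schema=EchoInput,
    # return_direct=True would hand the tool output straight back to the caller
    # instead of routing it through the LLM again, as the comment above notes.
)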

View File

@@ -1,11 +1,13 @@
from __future__ import annotations
from uuid import UUID
import json
from langchain.schema import AgentFinish, AgentAction
import asyncio
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast, Optional
from typing import Any, Dict, List, Optional
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.schema import AgentFinish, AgentAction
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import AsyncCallbackHandler
def dumps(obj: Dict) -> str:
return json.dumps(obj, ensure_ascii=False)
@@ -21,7 +23,7 @@ class Status:
tool_finish: int = 7
class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
def __init__(self):
super().__init__()
self.queue = asyncio.Queue()
@@ -29,29 +31,31 @@ class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
self.cur_tool = {}
self.out = True
async def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
print("on_tool_start")
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
parent_run_id: UUID | None = None, tags: List[str] | None = None,
metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
self.cur_tool = {
"tool_name": serialized["name"],
"input_str": input_str,
"output_str": "",
"status": Status.agent_action,
"run_id": run_id.hex,
"llm_token": "",
"final_answer": "",
"error": "",
}
self.queue.put_nowait(dumps(self.cur_tool))
async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
tags: List[str] | None = None, **kwargs: Any) -> None:
self.out = True
self.cur_tool.update(
status=Status.tool_finish,
output_str=output.replace("Answer:", ""),
)
self.queue.put_nowait(dumps(self.cur_tool))
async def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
print("on_tool_end")
async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
self.cur_tool.update(
@@ -134,29 +138,17 @@ class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
tool_input=action.tool_input,
)
self.queue.put_nowait(dumps(self.cur_tool))
async def on_agent_finish(
self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
if "Thought:" in finish.return_values["output"]:
finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "")
self.cur_tool.update(
status=Status.agent_finish,
agent_finish=finish.return_values["output"],
)
self.queue.put_nowait(dumps(self.cur_tool))
async def aiter(self) -> AsyncIterator[str]:
while not self.queue.empty() or not self.done.is_set():
done, other = await asyncio.wait(
[
asyncio.ensure_future(self.queue.get()),
asyncio.ensure_future(self.done.wait()),
],
return_when=asyncio.FIRST_COMPLETED,
)
if other:
other.pop().cancel()
token_or_done = cast(Union[str, Literal[True]], done.pop().result())
if token_or_done is True:
break
yield token_or_done
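
For context, the handler above is consumed the way the chat endpoint further down does it: a background task runs the chain with the handler attached while the caller drains aiter(). A minimal sketch assuming the CustomAsyncIteratorCallbackHandler defined above; full_chain, query and the local wrap_done stand-in (mirroring the project's helper of the same name) are illustrative:

import asyncio
import json

async def wrap_done(awaitable, event: asyncio.Event):
    # Await the chain, then signal the handler that no more chunks will come.
    try:
        await awaitable
    finally:
        event.set()

async def stream_agent(full_chain, query: str):
    callback = CustomAsyncIteratorCallbackHandler()
    task = asyncio.create_task(
        wrap_done(full_chain.ainvoke({"input": query}, callbacks=[callback]),
                  callback.done))
    async for chunk in callback.aiter():  # each chunk is a dumps()'d status dict
        yield json.loads(chunk)
    await task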

View File

@@ -6,12 +6,11 @@ from fastapi import Body
from fastapi.responses import StreamingResponse
from langchain.agents import initialize_agent, AgentType
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableBranch
from langchain.chains import LLMChain
from langchain.prompts.chat import ChatPromptTemplate
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnableBranch
from server.agent.agent_factory import initialize_glm3_agent
from server.agent.tools_factory.tools_registry import all_tools
@@ -57,48 +56,55 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
elif conversation_id and history_len > 0:
memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=models["llm_model"],
message_limit=history_len)
memory = ConversationBufferDBMemory(
conversation_id=conversation_id,
llm=models["llm_model"],
message_limit=history_len
)
else:
input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages([input_msg])
chain = LLMChain(prompt=chat_prompt, llm=models["llm_model"], memory=memory)
chain = LLMChain(
prompt=chat_prompt,
llm=models["llm_model"],
memory=memory
)
classifier_chain = (
PromptTemplate.from_template(prompts["preprocess_model"])
| models["preprocess_model"]
| StrOutputParser()
)
if "chatglm3" in models["action_model"].model_name.lower():
agent_executor = initialize_glm3_agent(
llm=models["action_model"],
tools=tools,
prompt=prompts["action_model"],
input_variables=["input", "intermediate_steps", "history"],
memory=memory,
callback_manager=BaseCallbackManager(handlers=callbacks),
verbose=True,
)
else:
agent_executor = initialize_agent(
llm=models["action_model"],
tools=tools,
callbacks=callbacks,
agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
memory=memory,
verbose=True,
)
agent_use = False
if agent_use:
branch = RunnableBranch(
(lambda x: "1" in x["topic"].lower(), agent_executor),
chain
)
full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
else:
# full_chain = ({"input": lambda x: x["input"]} | chain)
if "action_model" in models and tools:
if "chatglm3" in models["action_model"].model_name.lower():
agent_executor = initialize_glm3_agent(
llm=models["action_model"],
tools=tools,
prompt=prompts["action_model"],
input_variables=["input", "intermediate_steps", "history"],
memory=memory,
# callback_manager=BaseCallbackManager(handlers=callbacks),
verbose=True,
)
else:
agent_executor = initialize_agent(
llm=models["action_model"],
tools=tools,
callbacks=callbacks,
agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
memory=memory,
verbose=True,
)
# branch = RunnableBranch(
# (lambda x: "1" in x["topic"].lower(), agent_executor),
# chain
# )
# full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
full_chain = ({"input": lambda x: x["input"]} | agent_executor)
else:
full_chain = ({"input": lambda x: x["input"]} | chain)
return full_chain
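
The {"input": lambda x: x["input"]} | agent_executor (or | chain) expression is plain LCEL composition: the dict of lambdas is coerced into a runnable mapping that rebuilds the input dict before handing it to the next runnable. A minimal sketch with a stand-in runnable; echo is illustrative:

from langchain_core.runnables import RunnableLambda

echo = RunnableLambda(lambda x: f"agent got: {x['input']}")  # stand-in for agent_executor

full_chain = {"input": lambda x: x["input"]} | echo
print(full_chain.invoke({"input": "hello"}))  # -> "agent got: hello"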
@@ -149,7 +155,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
# Execute Chain
task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
task = asyncio.create_task(
wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
if stream:
async for chunk in callback.aiter():
data = json.loads(chunk)

View File

@@ -508,30 +508,33 @@ def set_httpx_config(
urllib.request.getproxies = _get_proxies
def detect_device() -> Literal["cuda", "mps", "cpu"]:
def detect_device() -> Literal["cuda", "mps", "cpu", "xpu"]:
try:
import torch
if torch.cuda.is_available():
return "cuda"
if torch.backends.mps.is_available():
return "mps"
import intel_extension_for_pytorch as ipex
if torch.xpu.get_device_properties(0):
return "xpu"
except:
pass
return "cpu"
def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
device = device or LLM_DEVICE
# if device.isdigit():
# return "cuda:" + device
if device not in ["cuda", "mps", "cpu"]:
if device not in ["cuda", "mps", "cpu", "xpu"]:
device = detect_device()
return device
def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
device = device or EMBEDDING_DEVICE
if device not in ["cuda", "mps", "cpu"]:
if device not in ["cuda", "mps", "cpu", "xpu"]:
device = detect_device()
return device
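
The xpu branch relies on intel_extension_for_pytorch registering torch.xpu, and on torch.xpu.get_device_properties(0) raising when no device exists, which the bare except then swallows. A slightly more explicit sketch of the same check, assuming an IPEX build that exposes torch.xpu.is_available():

def detect_device() -> str:
    try:
        import torch
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        import intel_extension_for_pytorch  # noqa: F401  (registers torch.xpu)
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return "xpu"
    except Exception:
        pass
    return "cpu"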

View File

@@ -219,7 +219,6 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
worker = ModelWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
@@ -391,7 +390,6 @@ def run_model_worker(
kwargs["worker_address"] = fschat_model_worker_address(model_name)
model_path = kwargs.get("model_path", "")
kwargs["model_path"] = model_path
app = create_model_worker_app(log_level=log_level, **kwargs)
_set_app_event(app, started_event)
if log_level == "ERROR":

View File

@@ -10,8 +10,7 @@ from datetime import datetime
import os
import re
import time
from configs import (TOOL_CONFIG, LLM_MODEL_CONFIG)
from server.knowledge_base.utils import LOADER_DICT
from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS)
import uuid
from typing import List, Dict
@@ -124,15 +123,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
conv_names = list(st.session_state["conversation_ids"].keys())
index = 0
tools = list(TOOL_CONFIG.keys())
selected_tool_configs = {}
with st.expander("工具栏"):
for tool in tools:
is_selected = st.checkbox(tool, value=TOOL_CONFIG[tool]["use"], key=tool)
if is_selected:
selected_tool_configs[tool] = TOOL_CONFIG[tool]
if st.session_state.get("cur_conv_name") in conv_names:
index = conv_names.index(st.session_state.get("cur_conv_name"))
conversation_name = st.selectbox("当前会话", conv_names, index=index)
@@ -177,7 +167,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
for k, v in config_models.get("online", {}).items():
if not v.get("provider") and k not in running_models and k in LLM_MODELS:
available_models.append(k)
llm_models = running_models + available_models + ["openai-api"]
llm_models = running_models + available_models # + ["openai-api"]
cur_llm_model = st.session_state.get("cur_llm_model", default_model)
if cur_llm_model in llm_models:
index = llm_models.index(cur_llm_model)
@@ -193,22 +183,39 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
# Content passed to the backend
model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
tool_use = True
for key in LLM_MODEL_CONFIG:
if key == 'llm_model':
continue
if key == 'action_model':
first_key = next(iter(LLM_MODEL_CONFIG[key]))
if first_key not in SUPPORT_AGENT_MODELS:
st.warning("不支持Agent的模型无法执行任何工具调用")
tool_use = False
continue
if LLM_MODEL_CONFIG[key]:
first_key = next(iter(LLM_MODEL_CONFIG[key]))
model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
# Tool selection
selected_tool_configs = {}
if tool_use:
from configs import prompt_config
import importlib
importlib.reload(prompt_config)
tools = list(prompt_config.TOOL_CONFIG.keys())
with st.expander("工具栏"):
for tool in tools:
is_selected = st.checkbox(tool, value=prompt_config.TOOL_CONFIG[tool]["use"], key=tool)
if is_selected:
selected_tool_configs[tool] = prompt_config.TOOL_CONFIG[tool]
if llm_model is not None:
model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]
# files = st.file_uploader("上传附件",accept_multiple_files=False)
# type=[i for ls in LOADER_DICT.values() for i in ls],)
uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
# print(len(files_upload["audios"])) if files_upload else None
# if dialogue_mode == "文件对话":
@@ -355,6 +362,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
chat_box.reset_history()
st.rerun()
warning_placeholder = st.empty()
with warning_placeholder.container():
st.warning('Running in 8 x A100')
export_btn.download_button(
"导出记录",
"".join(chat_box.export2md()),