Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-01-31 03:03:22 +08:00)

Support XPU; modified part of the glm3 agent

parent d44ce6ce21
commit 6d3d99639e
@@ -3,14 +3,6 @@ This file is a modified version for ChatGLM3-6B the original glm3_agent.py file
"""
from __future__ import annotations

import yaml
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from typing import Any, List, Sequence, Tuple, Optional, Union
import os
from langchain.agents.agent import Agent
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
import json
import logging
from typing import Any, List, Sequence, Tuple, Optional, Union
@@ -30,7 +22,8 @@ from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool
from pydantic.schema import model_schema
from langchain_core.callbacks import Callbacks


HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
logger = logging.getLogger(__name__)
@@ -195,6 +188,7 @@ class StructuredGLM3ChatAgent(Agent):
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
verbose=True
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser or cls._get_default_output_parser(llm=llm)
@@ -1,5 +1,4 @@
from urllib.parse import urlencode

from pydantic import BaseModel, Field

from server.knowledge_base.kb_doc_api import search_docs
@@ -9,7 +8,7 @@ from configs import TOOL_CONFIG
def search_knowledgebase(query: str, database: str, config: dict):
docs = search_docs(
query=query,
knowledge_base_name=database,
knowledge_base_name="samples",
top_k=config["top_k"],
score_threshold=config["score_threshold"])
context = ""
@@ -22,7 +21,7 @@ def search_knowledgebase(query: str, database: str, config: dict):
source_documents.append(text)

if len(source_documents) == 0:
context= "没有找到相关文档,请更换关键词重试"
context = "没有找到相关文档,请更换关键词重试"
else:
for doc in source_documents:
context += doc + "\n"
@@ -37,4 +36,4 @@ class SearchKnowledgeInput(BaseModel):

def search_local_knowledgebase(database: str, query: str):
tool_config = TOOL_CONFIG["search_local_knowledgebase"]
return search_knowledgebase(query=query, database=database, config=tool_config)
return search_knowledgebase(query=query, database=database, config=tool_config)
@@ -2,9 +2,9 @@ from langchain_core.tools import StructuredTool
from server.agent.tools_factory import *
from configs import KB_INFO

template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool."
template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]."
KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
template_knowledge = template.format(KB_info=KB_info_str)
template_knowledge = template.format(KB_info=KB_info_str, key="samples")

all_tools = [
StructuredTool.from_function(
@@ -24,6 +24,7 @@ all_tools = [
name="shell",
description="Use Shell to execute Linux commands",
args_schema=ShellInput,
# return_direct=True,  # whether to return the tool output directly, without further LLM processing
),
StructuredTool.from_function(
func=wolfram,
@@ -1,11 +1,13 @@
from __future__ import annotations
from uuid import UUID
import json
from langchain.schema import AgentFinish, AgentAction
import asyncio
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast, Optional
from typing import Any, Dict, List, Optional

from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.schema import AgentFinish, AgentAction
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import AsyncCallbackHandler


def dumps(obj: Dict) -> str:
return json.dumps(obj, ensure_ascii=False)
@@ -21,7 +23,7 @@ class Status:
tool_finish: int = 7


class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
def __init__(self):
super().__init__()
self.queue = asyncio.Queue()
@@ -29,29 +31,31 @@ class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
self.cur_tool = {}
self.out = True

async def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
print("on_tool_start")
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
parent_run_id: UUID | None = None, tags: List[str] | None = None,
metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
self.cur_tool = {
"tool_name": serialized["name"],
"input_str": input_str,
"output_str": "",
"status": Status.agent_action,
"run_id": run_id.hex,
"llm_token": "",
"final_answer": "",
"error": "",
}
self.queue.put_nowait(dumps(self.cur_tool))

async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
tags: List[str] | None = None, **kwargs: Any) -> None:

self.out = True
self.cur_tool.update(
status=Status.tool_finish,
output_str=output.replace("Answer:", ""),
)
self.queue.put_nowait(dumps(self.cur_tool))

async def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
print("on_tool_end")
async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
self.cur_tool.update(
@@ -134,29 +138,17 @@ class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler):
tool_input=action.tool_input,
)
self.queue.put_nowait(dumps(self.cur_tool))

async def on_agent_finish(
self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
if "Thought:" in finish.return_values["output"]:
finish.return_values["output"] = finish.return_values["output"].replace("Thought:", "")

self.cur_tool.update(
status=Status.agent_finish,
agent_finish=finish.return_values["output"],
)
self.queue.put_nowait(dumps(self.cur_tool))

async def aiter(self) -> AsyncIterator[str]:
while not self.queue.empty() or not self.done.is_set():
done, other = await asyncio.wait(
[
asyncio.ensure_future(self.queue.get()),
asyncio.ensure_future(self.done.wait()),
],
return_when=asyncio.FIRST_COMPLETED,
)
if other:
other.pop().cancel()
token_or_done = cast(Union[str, Literal[True]], done.pop().result())
if token_or_done is True:
break
yield token_or_done
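For orientation, here is a minimal sketch of how a streaming handler like the one above is typically consumed; it is an illustration only. The handler instance, the chain input shape, and the event field names come from this diff, while `wrap_done` is modelled on the helper call visible in the chat.py hunk further down and the surrounding function is hypothetical.

```python
# Hedged sketch: run a chain in the background and stream the handler's queued JSON events.
# `chain` is any runnable that accepts {"input": ...} (e.g. the full_chain built in chat.py),
# and `handler` is an instance of the CustomAsyncIteratorCallbackHandler defined above.
import asyncio
import json

async def wrap_done(coro, event: asyncio.Event):
    try:
        await coro
    finally:
        event.set()  # lets aiter() exit once the chain has finished

async def stream_agent_events(chain, handler, query: str) -> None:
    task = asyncio.create_task(
        wrap_done(chain.ainvoke({"input": query}, callbacks=[handler]), handler.done))
    async for chunk in handler.aiter():  # yields the dumps(...) payloads put on the queue
        event = json.loads(chunk)
        print(event["status"], event.get("llm_token") or event.get("output_str", ""))
    await task
```

This is the same pattern the chat endpoint below uses: schedule `wrap_done(full_chain.ainvoke(...), callback.done)` as a task, then iterate `callback.aiter()` to feed the streaming response.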
@@ -6,12 +6,11 @@ from fastapi import Body
from fastapi.responses import StreamingResponse

from langchain.agents import initialize_agent, AgentType
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableBranch
from langchain.chains import LLMChain
from langchain.prompts.chat import ChatPromptTemplate
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnableBranch

from server.agent.agent_factory import initialize_glm3_agent
from server.agent.tools_factory.tools_registry import all_tools
@@ -57,48 +56,55 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
elif conversation_id and history_len > 0:
memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=models["llm_model"],
message_limit=history_len)
memory = ConversationBufferDBMemory(
conversation_id=conversation_id,
llm=models["llm_model"],
message_limit=history_len
)
else:
input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages([input_msg])

chain = LLMChain(prompt=chat_prompt, llm=models["llm_model"], memory=memory)
chain = LLMChain(
prompt=chat_prompt,
llm=models["llm_model"],
memory=memory
)
classifier_chain = (
PromptTemplate.from_template(prompts["preprocess_model"])
| models["preprocess_model"]
| StrOutputParser()
)

if "chatglm3" in models["action_model"].model_name.lower():
agent_executor = initialize_glm3_agent(
llm=models["action_model"],
tools=tools,
prompt=prompts["action_model"],
input_variables=["input", "intermediate_steps", "history"],
memory=memory,
callback_manager=BaseCallbackManager(handlers=callbacks),
verbose=True,
)
else:
agent_executor = initialize_agent(
llm=models["action_model"],
tools=tools,
callbacks=callbacks,
agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
memory=memory,
verbose=True,
)
agent_use = False
if agent_use:
branch = RunnableBranch(
(lambda x: "1" in x["topic"].lower(), agent_executor),
chain
)
full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
else:
# full_chain = ({"input": lambda x: x["input"]} | chain)
if "action_model" in models and tools:
if "chatglm3" in models["action_model"].model_name.lower():
agent_executor = initialize_glm3_agent(
llm=models["action_model"],
tools=tools,
prompt=prompts["action_model"],
input_variables=["input", "intermediate_steps", "history"],
memory=memory,
# callback_manager=BaseCallbackManager(handlers=callbacks),
verbose=True,
)
else:
agent_executor = initialize_agent(
llm=models["action_model"],
tools=tools,
callbacks=callbacks,
agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
memory=memory,
verbose=True,
)

# branch = RunnableBranch(
# (lambda x: "1" in x["topic"].lower(), agent_executor),
# chain
# )
# full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
full_chain = ({"input": lambda x: x["input"]} | agent_executor)
else:
full_chain = ({"input": lambda x: x["input"]} | chain)
return full_chain


@@ -149,7 +155,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼

# Execute Chain

task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
task = asyncio.create_task(
wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
if stream:
async for chunk in callback.aiter():
data = json.loads(chunk)
@@ -508,30 +508,33 @@ def set_httpx_config(
urllib.request.getproxies = _get_proxies


def detect_device() -> Literal["cuda", "mps", "cpu"]:
def detect_device() -> Literal["cuda", "mps", "cpu", "xpu"]:
try:
import torch
if torch.cuda.is_available():
return "cuda"
if torch.backends.mps.is_available():
return "mps"
import intel_extension_for_pytorch as ipex
if torch.xpu.get_device_properties(0):
return "xpu"
except:
pass
return "cpu"


def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
device = device or LLM_DEVICE
# if device.isdigit():
# return "cuda:" + device
if device not in ["cuda", "mps", "cpu"]:
if device not in ["cuda", "mps", "cpu", "xpu"]:
device = detect_device()
return device


def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
device = device or EMBEDDING_DEVICE
if device not in ["cuda", "mps", "cpu"]:
if device not in ["cuda", "mps", "cpu", "xpu"]:
device = detect_device()
return device
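For context, a small self-contained sketch of the device probe this hunk extends. It assumes torch is installed; intel_extension_for_pytorch (IPEX) is optional and only needed for the new "xpu" branch, and `pick_device` is a local stand-in for the repo's `detect_device`, not its exact code.

```python
# Hedged sketch of XPU-aware device selection, mirroring the logic added above.
import torch

def pick_device() -> str:
    try:
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        import intel_extension_for_pytorch  # noqa: F401  # importing IPEX registers torch.xpu
        if torch.xpu.get_device_properties(0):  # raises if no XPU device is present
            return "xpu"
    except Exception:
        pass
    return "cpu"

device = pick_device()
model = torch.nn.Linear(16, 4).to(device)  # modules and tensors move the same way for "xpu"
x = torch.randn(2, 16, device=device)
print(device, model(x).shape)
```

With the checks above, `llm_device` and `embedding_device` simply pass "xpu" through, so downstream loading code can keep calling `.to(device)` unchanged.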
@@ -219,7 +219,6 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)

worker = ModelWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
@@ -391,7 +390,6 @@ def run_model_worker(
kwargs["worker_address"] = fschat_model_worker_address(model_name)
model_path = kwargs.get("model_path", "")
kwargs["model_path"] = model_path

app = create_model_worker_app(log_level=log_level, **kwargs)
_set_app_event(app, started_event)
if log_level == "ERROR":
@@ -10,8 +10,7 @@ from datetime import datetime
import os
import re
import time
from configs import (TOOL_CONFIG, LLM_MODEL_CONFIG)
from server.knowledge_base.utils import LOADER_DICT
from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS)
import uuid
from typing import List, Dict

@@ -124,15 +123,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
conv_names = list(st.session_state["conversation_ids"].keys())
index = 0

tools = list(TOOL_CONFIG.keys())
selected_tool_configs = {}

with st.expander("工具栏"):
for tool in tools:
is_selected = st.checkbox(tool, value=TOOL_CONFIG[tool]["use"], key=tool)
if is_selected:
selected_tool_configs[tool] = TOOL_CONFIG[tool]

if st.session_state.get("cur_conv_name") in conv_names:
index = conv_names.index(st.session_state.get("cur_conv_name"))
conversation_name = st.selectbox("当前会话", conv_names, index=index)
@@ -177,7 +167,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
for k, v in config_models.get("online", {}).items():
if not v.get("provider") and k not in running_models and k in LLM_MODELS:
available_models.append(k)
llm_models = running_models + available_models + ["openai-api"]
llm_models = running_models + available_models # + ["openai-api"]
cur_llm_model = st.session_state.get("cur_llm_model", default_model)
if cur_llm_model in llm_models:
index = llm_models.index(cur_llm_model)
@@ -193,22 +183,39 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):

# content passed to the backend
model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}

tool_use = True
for key in LLM_MODEL_CONFIG:
if key == 'llm_model':
continue
if key == 'action_model':
first_key = next(iter(LLM_MODEL_CONFIG[key]))
if first_key not in SUPPORT_AGENT_MODELS:
st.warning("不支持Agent的模型,无法执行任何工具调用")
tool_use = False
continue
if LLM_MODEL_CONFIG[key]:
first_key = next(iter(LLM_MODEL_CONFIG[key]))
model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]

# tool selection
selected_tool_configs = {}
if tool_use:
from configs import prompt_config
import importlib
importlib.reload(prompt_config)

tools = list(prompt_config.TOOL_CONFIG.keys())
with st.expander("工具栏"):
for tool in tools:
is_selected = st.checkbox(tool, value=prompt_config.TOOL_CONFIG[tool]["use"], key=tool)
if is_selected:
selected_tool_configs[tool] = prompt_config.TOOL_CONFIG[tool]

if llm_model is not None:
model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]

# files = st.file_uploader("上传附件",accept_multiple_files=False)
# type=[i for ls in LOADER_DICT.values() for i in ls],)
uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
files_upload = process_files(files=[uploaded_file]) if uploaded_file else None

# print(len(files_upload["audios"])) if files_upload else None

# if dialogue_mode == "文件对话":
@@ -355,6 +362,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
chat_box.reset_history()
st.rerun()

warning_placeholder = st.empty()
with warning_placeholder.container():
st.warning('Running in 8 x A100')

export_btn.download_button(
"导出记录",
"".join(chat_box.export2md()),