Support ChatGLM3-6B (#2021)

* Update dependencies and the Agent models in the config files

* Support a basic glm3_agent
zR 2023-11-12 16:45:50 +08:00 committed by GitHub
parent 3462d06759
commit 91ff0574df
30 changed files with 473 additions and 175 deletions

View File

@ -5,4 +5,4 @@ from .server_config import *
from .prompt_config import *
VERSION = "v0.2.7-preview"
VERSION = "v0.2.8-preview"

View File

@ -259,15 +259,12 @@ VLLM_MODEL_DICT = {
}
# Models you believe support Agent capabilities can be added here; once added, the WebUI warning will no longer appear.
# In our testing, only the following models natively support Agent calls:
SUPPORT_AGENT_MODEL = [
"azure-api",
"openai-api",
"claude-api",
"zhipu-api",
"qwen-api",
"Qwen",
"baichuan-api",
"agentlm",
"chatglm3",
"xinghuo-api",
]
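For context, a minimal sketch (the helper `is_agent_supported` is hypothetical, not repo code) of how this list can gate the WebUI warning; one plausible check is substring matching, so that "chatglm3" covers "chatglm3-6b":

```python
from configs import SUPPORT_AGENT_MODEL

def is_agent_supported(model_name: str) -> bool:
    # True when the selected model is known to support Agent calls.
    return any(m in model_name for m in SUPPORT_AGENT_MODEL)

print(is_agent_supported("chatglm3-6b"))  # True  -> no warning shown
print(is_agent_supported("vicuna-13b"))   # False -> WebUI would warn
```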

View File

@ -100,59 +100,13 @@ PROMPT_TEMPLATES = {
Question: {input}
Thought: {agent_scratchpad}
""",
"AgentLM":
"""
<<SYS>>\n
You are a helpful, respectful and honest assistant.
<</SYS>>\n
Answer the following questions as best you can. When appropriate, you can use some tools. You have access to the following tools:
{tools}.
Use the following steps and think step by step!:
Question: the input question you must answer
Thought: you should always think about what to do and what tools to use.
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin! Let's think step by step!
# ChatGLM3 must use the official prompt; there is no room for modification, and none of these parameters are actually passed in.
"ChatGLM3":
"""
history:
{history}
Question: {input}
Thought: {agent_scratchpad}
""",
"中文版本":
"""
你的知识不一定正确,所以你一定要用提供的工具来思考,并给出用户答案。
你有以下工具可以使用:
{tools}
请请严格按照提供的思维方式来思考所有的关键词都要输出例如ActionAction InputObservation等
```
Question: 用户的提问或者观察到的信息,
Thought: 你应该思考该做什么,是根据工具的结果来回答问题,还是决定使用什么工具。
Action: 需要使用的工具,应该是在[{tool_names}]中的一个。
Action Input: 传入工具的内容
Observation: 工具给出的答案(不是你生成的)
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: 通过工具给出的答案你是否能回答Question。
Final Answer是你的答案
现在,我们开始!
你和用户的历史记录:
History:
{history}
用户开始以提问:
Question: {input}
Thought: {agent_scratchpad}
""",
},
}
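For illustration only, a quick sketch (hypothetical values) of how the minimal ChatGLM3 template above gets filled; since the official ChatGLM3 prompt leaves no room for customization, only these three placeholders survive:

```python
template = (
    "history:\n"
    "{history}\n"
    "Question: {input}\n"
    "Thought: {agent_scratchpad}\n"
)
print(template.format(
    history="Human: hi\nAI: hello",
    input="What is 2 + 2?",
    agent_scratchpad="",
))
```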

View File

@ -1,6 +1,6 @@
# API requirements
-langchain>=0.0.329  # the latest Langchain is recommended
+langchain>=0.0.334  # the latest Langchain is recommended
langchain-experimental>=0.0.30
fschat[model_worker]==0.2.32
xformers>=0.0.22.post4
@ -53,7 +53,7 @@ vllm>=0.2.0; sys_platform == "linux"
# WebUI requirements
-streamlit~=1.27.0
+streamlit~=1.28.1
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox>=1.1.11

View File

@ -1,6 +1,6 @@
# API requirements
-langchain>=0.0.329  # the latest Langchain is recommended
+langchain>=0.0.334  # the latest Langchain is recommended
langchain-experimental>=0.0.30
fschat[model_worker]==0.2.32
xformers>=0.0.22.post4

View File

@ -1,4 +1,4 @@
-langchain>=0.0.329  # the latest Langchain is recommended
+langchain>=0.0.334  # the latest Langchain is recommended
fschat>=0.2.32
openai
# sentence_transformers
@ -41,7 +41,7 @@ dashscope>=1.10.0 # qwen
numpy~=1.24.4
pandas~=2.0.3
-streamlit~=1.27.0
+streamlit~=1.28.1
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox==1.1.11

View File

@ -1,6 +1,6 @@
# WebUI requirements
-streamlit~=1.27.0
+streamlit~=1.28.1
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox>=1.1.11

View File

@ -73,21 +73,40 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
)
self.queue.put_nowait(dumps(self.cur_tool))
# async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
# if "Action" in token: ## 减少重复输出
# before_action = token.split("Action")[0]
# self.cur_tool.update(
# status=Status.running,
# llm_token=before_action + "\n",
# )
# self.queue.put_nowait(dumps(self.cur_tool))
#
# self.out = False
#
# if token and self.out:
# self.cur_tool.update(
# status=Status.running,
# llm_token=token,
# )
# self.queue.put_nowait(dumps(self.cur_tool))
    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        if "Action" in token:  ## reduce repeated output
-            before_action = token.split("Action")[0]
-            self.cur_tool.update(
-                status=Status.running,
-                llm_token=before_action + "\n",
-            )
-            self.queue.put_nowait(dumps(self.cur_tool))
-            self.out = False
+        special_tokens = ["Action", "<|observation|>"]
+        for stoken in special_tokens:
+            if stoken in token:
+                before_action = token.split(stoken)[0]
+                self.cur_tool.update(
+                    status=Status.running,
+                    llm_token=before_action + "\n",
+                )
+                self.queue.put_nowait(dumps(self.cur_tool))
+                self.out = False
+                break
        if token and self.out:
            self.cur_tool.update(
                status=Status.running,
                llm_token=token,
            )
            self.queue.put_nowait(dumps(self.cur_tool))
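The truncation above can be shown in isolation; a standalone sketch (illustrative inputs, not part of the commit) of the same cut-at-special-token logic:

```python
# Everything from the first special token onward is dropped, and streaming pauses.
special_tokens = ["Action", "<|observation|>"]

def truncate_stream(token: str):
    for stoken in special_tokens:
        if stoken in token:
            # Emit only the text before the special token; stop streaming after it.
            return token.split(stoken)[0] + "\n", False
    return token, True

print(truncate_stream("The answer is 4"))     # ('The answer is 4', True)
print(truncate_stream("thinking...Action:"))  # ('thinking...\n', False)
```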

View File

@ -0,0 +1,280 @@
"""
This file is a modified version, adapted for ChatGLM3-6B, of the original structured chat agent file from the langchain repo.
"""
from __future__ import annotations
import yaml
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from typing import Any, List, Sequence, Tuple, Optional, Union
import os
from langchain.agents.agent import Agent
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
import json
import logging
from langchain.agents.agent import AgentOutputParser
from langchain.output_parsers import OutputFixingParser
from langchain.pydantic_v1 import Field
from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate
from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool
PREFIX = """
You can answer using the tools, or answer directly using your knowledge without using the tools.
Respond to the human as helpfully and accurately as possible.
You have access to the following tools:
"""
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
"action": "Final Answer",
"action_input": "Final response to human"
}}}}
```"""
SUFFIX = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Thought:"""
HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"
logger = logging.getLogger(__name__)
class StructuredChatOutputParserWithRetries(AgentOutputParser):
"""Output parser with retries for the structured chat agent."""
base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
"""The base parser to use."""
output_fixing_parser: Optional[OutputFixingParser] = None
"""The output fixing parser to use."""
def get_format_instructions(self) -> str:
return FORMAT_INSTRUCTIONS
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
special_tokens = ["Action:", "<|observation|>"]
first_index = min([text.find(token) if token in text else len(text) for token in special_tokens])
text = text[:first_index]
if "tool_call" in text:
tool_name_end = text.find("```")
tool_name = text[:tool_name_end].strip()
input_para = text.split("='")[-1].split("'")[0]
action_json = {
"action": tool_name,
"action_input": input_para
}
else:
action_json = {
"action": "Final Answer",
"action_input": text
}
action_str = f"""
Action:
```
{json.dumps(action_json, ensure_ascii=False)}
```"""
try:
if self.output_fixing_parser is not None:
parsed_obj: Union[
AgentAction, AgentFinish
] = self.output_fixing_parser.parse(action_str)
else:
parsed_obj = self.base_parser.parse(action_str)
return parsed_obj
except Exception as e:
raise OutputParserException(f"Could not parse LLM output: {text}") from e
@property
def _type(self) -> str:
return "structured_chat_ChatGLM3_6b_with_retries"
class StructuredGLM3ChatAgent(Agent):
"""Structured Chat Agent."""
output_parser: AgentOutputParser = Field(
default_factory=StructuredChatOutputParserWithRetries
)
"""Output parser for the agent."""
@property
def observation_prefix(self) -> str:
"""Prefix to append the ChatGLM3-6B observation with."""
return "Observation:"
@property
def llm_prefix(self) -> str:
"""Prefix to append the llm call with."""
return "Thought:"
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> str:
agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
if not isinstance(agent_scratchpad, str):
raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad:
return (
f"This was your previous work "
f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}"
)
else:
return agent_scratchpad
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
pass
@classmethod
def _get_default_output_parser(
cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
) -> AgentOutputParser:
return StructuredChatOutputParserWithRetries(llm=llm)
@property
def _stop(self) -> List[str]:
return ["```<observation>"]
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
def tool_config_from_file(tool_name, directory="server/agent/tools/"):
"""search tool yaml and return json format"""
file_path = os.path.join(directory, f"{tool_name.lower()}.yaml")
try:
with open(file_path, 'r', encoding='utf-8') as file:
return yaml.safe_load(file)
except FileNotFoundError:
print(f"File not found: {file_path}")
return None
except Exception as e:
print(f"An error occurred while reading {file_path}: {e}")
return None
tools_json = []
tool_names = ""
for tool in tools:
tool_config = tool_config_from_file(tool.name)
if tool_config:
tools_json.append(tool_config)
                tool_names += tool.name + ", "
formatted_tools = "\n".join([
json.dumps(tool, ensure_ascii=False).replace("\"", "\\\"").replace("{", "{{").replace("}", "}}")
for tool in tools_json
])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
_memory_prompts = memory_prompts or []
messages = [
SystemMessagePromptTemplate.from_template(template),
*_memory_prompts,
HumanMessagePromptTemplate.from_template(human_message_template),
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
prompt = cls.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser or cls._get_default_output_parser(llm=llm)
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
**kwargs,
)
@property
def _agent_type(self) -> str:
raise ValueError
def initialize_glm3_agent(
tools: Sequence[BaseTool],
llm: BaseLanguageModel,
callback_manager: Optional[BaseCallbackManager] = None,
agent_kwargs: Optional[dict] = None,
*,
tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> AgentExecutor:
tags_ = list(tags) if tags else []
agent_kwargs = agent_kwargs or {}
agent_obj = StructuredGLM3ChatAgent.from_llm_and_tools(
llm, tools, callback_manager=callback_manager, **agent_kwargs
)
return AgentExecutor.from_agent_and_tools(
agent=agent_obj,
tools=tools,
callback_manager=callback_manager,
tags=tags_,
**kwargs,
)
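A hedged usage sketch of the new entry point, wired roughly the way server/chat/agent_chat.py below does it; `get_ChatOpenAI` comes from server.utils in this repo, and the concrete parameter values here are illustrative:

```python
from langchain.memory import ConversationBufferWindowMemory

from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools
from server.utils import get_ChatOpenAI

# Build the LLM wrapper for a local chatglm3 worker (values are examples).
model = get_ChatOpenAI(model_name="chatglm3-6b", temperature=0.7)

agent_executor = initialize_glm3_agent(
    llm=model,
    tools=tools,
    callback_manager=None,
    verbose=True,
    memory=ConversationBufferWindowMemory(k=10),  # forwarded to AgentExecutor via **kwargs
)
# agent_executor.run("What is the weather in 厦门市思明区 today?")
```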

View File

@ -54,21 +54,6 @@ class CustomOutputParser(AgentOutputParser):
action = parts[1].split("Action Input:")[0].strip()
action_input = parts[1].split("Action Input:")[1].strip()
# The original regex-based check: stricter, but with a lower success rate
# regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
# print("llm_output",llm_output)
# match = re.search(regex, llm_output, re.DOTALL)
# print("match",match)
# if not match:
# return AgentFinish(
# return_values={"output": f"Agent call failed: `{llm_output}`"},
# log=llm_output,
# )
# action = match.group(1).strip()
# action_input = match.group(2)
# Return the action and action input
try:
ans = AgentAction(
tool=action,

View File

@ -1,11 +1,11 @@
## Import all tool classes
-from .search_knowledge_simple import knowledge_search_simple
-from .search_all_knowledge_once import knowledge_search_once, KnowledgeSearchInput
-from .search_all_knowledge_more import knowledge_search_more, KnowledgeSearchInput
+from .search_knowledgebase_simple import search_knowledgebase_simple
+from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput
+from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput
from .calculate import calculate, CalculatorInput
from .translator import translate, TranslateInput
-from .weather import weathercheck, WhetherSchema
+from .weather_check import weathercheck, WhetherSchema
from .shell import shell, ShellInput
from .search_internet import search_internet, SearchInternetInput
from .wolfram import wolfram, WolframInput
-from .youtube import youtube_search, YoutubeInput
+from .search_youtube import search_youtube, YoutubeInput
from .arxiv import arxiv, ArxivInput

View File

@ -0,0 +1,9 @@
# Wrapper around LangChain's ArxivQueryRun tool
from pydantic import BaseModel, Field
from langchain.tools.arxiv.tool import ArxivQueryRun
def arxiv(query: str):
tool = ArxivQueryRun()
return tool.run(tool_input=query)
class ArxivInput(BaseModel):
query: str = Field(description="The search query title")

View File

@ -0,0 +1,10 @@
name: arxiv
description: A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.
parameters:
type: object
properties:
query:
type: string
description: The search query title
required:
- query
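These YAML specs are what tool_config_from_file() in ChatGLM3Agent.py reads; below is a small sketch of how arxiv.yaml ends up inlined into the system prompt (the path is the default used there):

```python
import json

import yaml

# Load the tool spec the same way tool_config_from_file() does.
with open("server/agent/tools/arxiv.yaml", "r", encoding="utf-8") as f:
    tool_config = yaml.safe_load(f)

# Escape quotes and braces so the JSON survives LangChain's prompt templating,
# mirroring the replace() chain in StructuredGLM3ChatAgent.create_prompt().
formatted = (
    json.dumps(tool_config, ensure_ascii=False)
    .replace('"', '\\"')
    .replace("{", "{{")
    .replace("}", "}}")
)
print(formatted)
```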

View File

@ -0,0 +1,10 @@
name: calculate
description: Useful for when you need to answer questions about simple calculations
parameters:
type: object
properties:
query:
type: string
description: The formula to be calculated
required:
- query

View File

@ -29,7 +29,7 @@ def search_internet(query: str):
return asyncio.run(search_engine_iter(query))
class SearchInternetInput(BaseModel):
location: str = Field(description="需要查询的内容")
location: str = Field(description="Query for Internet search")
if __name__ == "__main__":

View File

@ -0,0 +1,10 @@
name: search_internet
description: Use this tool to surf internet and get information
parameters:
type: object
properties:
query:
type: string
description: Query for Internet search
required:
- query

View File

@ -266,17 +266,17 @@ class LLMKnowledgeChain(LLMChain):
return cls(llm_chain=llm_chain, **kwargs)
-def knowledge_search_more(query: str):
+def search_knowledgebase_complex(query: str):
model = model_container.MODEL
llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT)
ans = llm_knowledge.run(query)
return ans
class KnowledgeSearchInput(BaseModel):
location: str = Field(description="知识库查询的内容")
location: str = Field(description="The query to be searched")
if __name__ == "__main__":
-    result = knowledge_search_more("机器人和大数据在代码教学上有什么区别")
+    result = search_knowledgebase_complex("机器人和大数据在代码教学上有什么区别")
print(result)
# This is a normal split

View File

@ -0,0 +1,10 @@
name: search_knowledgebase_complex
description: Use this tool to search local knowledgebase and get information
parameters:
type: object
properties:
query:
type: string
description: The query to be searched
required:
- query

View File

@ -218,7 +218,7 @@ class LLMKnowledgeChain(LLMChain):
return cls(llm_chain=llm_chain, **kwargs)
-def knowledge_search_once(query: str):
+def search_knowledgebase_once(query: str):
model = model_container.MODEL
llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT)
ans = llm_knowledge.run(query)
@ -226,9 +226,9 @@ def knowledge_search_once(query: str):
class KnowledgeSearchInput(BaseModel):
location: str = Field(description="知识库查询的内容")
location: str = Field(description="The query to be searched")
if __name__ == "__main__":
-    result = knowledge_search_once("大数据的男女比例")
+    result = search_knowledgebase_once("大数据的男女比例")
print(result)

View File

@ -23,10 +23,10 @@ async def search_knowledge_base_iter(database: str, query: str) -> str:
docs = data["docs"]
return contents
-def knowledge_search_simple(query: str):
+def search_knowledgebase_simple(query: str):
return asyncio.run(search_knowledge_base_iter(query))
if __name__ == "__main__":
-    result = knowledge_search_simple("大数据男女比例")
+    result = search_knowledgebase_simple("大数据男女比例")
    print("Answer:", result)

View File

@ -1,9 +1,9 @@
# Wrapper around Langchain's built-in YouTube search tool
from langchain.tools import YouTubeSearchTool
from pydantic import BaseModel, Field
-def youtube_search(query: str):
+def search_youtube(query: str):
tool = YouTubeSearchTool()
return tool.run(tool_input=query)
class YoutubeInput(BaseModel):
location: str = Field(description="要搜索视频关键字")
location: str = Field(description="Query for Videos search")

View File

@ -0,0 +1,10 @@
name: search_youtube
description: Use this tool to search YouTube videos
parameters:
type: object
properties:
query:
type: string
description: Query for Videos search
required:
- query

View File

@ -0,0 +1,10 @@
name: shell
description: Use the Linux shell to execute commands
parameters:
type: object
properties:
query:
type: string
description: The command to execute
required:
- query

View File

@ -1,38 +0,0 @@
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from server.agent import model_container
from pydantic import BaseModel, Field
_PROMPT_TEMPLATE = '''
# Instruction
From now on, act as a professional translation expert: when I give you a sentence or paragraph, provide a fluent and readable translation in the corresponding language. Note:
1. Make sure the translation is fluent and easy to understand
2. Whether a statement or a question is given, only translate it
3. Do not add content unrelated to the original text
Question: ${{the original text to translate and the target language}}
Answer: your translation result
Now, here is my question:
Question: {question}
'''
PROMPT = PromptTemplate(
input_variables=["question"],
template=_PROMPT_TEMPLATE,
)
def translate(query: str):
model = model_container.MODEL
llm_translate = LLMChain(llm=model, prompt=PROMPT)
ans = llm_translate.run(query)
return ans
class TranslateInput(BaseModel):
location: str = Field(description="需要被翻译的内容")
if __name__ == "__main__":
result = translate("Can Love remember the question and the answer? 这句话如何诗意的翻译成中文")
print("答案:",result)

View File

@ -0,0 +1,10 @@
name: weather_check
description: Use Weather API to get weather information
parameters:
type: object
properties:
query:
type: string
description: City name, including city and county, like "厦门市思明区"
required:
- query

View File

@ -0,0 +1,10 @@
name: wolfram
description: Useful for when you need to calculate difficult math formulas
parameters:
type: object
properties:
query:
type: string
description: The formula to be calculated
required:
- query

View File

@ -6,50 +6,50 @@ from server.agent.tools import *
tools = [
    Tool.from_function(
        func=calculate,
-       name="计算器工具",
-       description="进行简单的数学运算, 只是简单的, 使用Wolfram数学工具进行更复杂的运算",
+       name="calculate",
+       description="Useful for when you need to answer questions about simple calculations",
        args_schema=CalculatorInput,
    ),
    Tool.from_function(
-       func=translate,
-       name="翻译工具",
-       description="如果你无法访问互联网,并且需要翻译各种语言,应该使用这个工具",
-       args_schema=TranslateInput,
+       func=arxiv,
+       name="arxiv",
+       description="A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.",
+       args_schema=ArxivInput,
    ),
    Tool.from_function(
        func=weathercheck,
-       name="天气查询工具",
-       description="无需访问互联网,使用这个工具查询中国各地未来24小时的天气",
+       name="weather_check",
+       description="",
        args_schema=WhetherSchema,
    ),
    Tool.from_function(
        func=shell,
-       name="shell工具",
-       description="使用命令行工具输出",
+       name="shell",
+       description="Use Shell to execute Linux commands",
        args_schema=ShellInput,
    ),
    Tool.from_function(
-       func=knowledge_search_more,
-       name="知识库查询工具",
-       description="优先访问知识库来获取答案",
+       func=search_knowledgebase_complex,
+       name="search_knowledgebase_complex",
+       description="Use this tool to search local knowledgebase and get information",
        args_schema=KnowledgeSearchInput,
    ),
    Tool.from_function(
        func=search_internet,
-       name="互联网查询工具",
-       description="如果你无法访问互联网,这个工具可以帮助你访问Bing互联网来解答问题",
+       name="search_internet",
+       description="Use this tool to use bing search engine to search the internet",
        args_schema=SearchInternetInput,
    ),
    Tool.from_function(
        func=wolfram,
-       name="Wolfram数学工具",
-       description="高级的数学运算工具,能够完成非常复杂的数学问题",
+       name="Wolfram",
+       description="Useful for when you need to calculate difficult formulas",
        args_schema=WolframInput,
    ),
    Tool.from_function(
-       func=youtube_search,
-       name="Youtube搜索工具",
-       description="使用这个工具在Youtube上搜索视频",
+       func=search_youtube,
+       name="search_youtube",
+       description="Use this tool to search YouTube videos",
        args_schema=YoutubeInput,
    ),
]
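Downstream, agent_chat.py imports `tools` and `tool_names` from this module; with the renames above, the registered names now line up with the YAML specs and the function names. A small sketch of what the agent sees as its allowed tools:

```python
# Derived from the list above (tools_select exposes this as tool_names).
tool_names = [tool.name for tool in tools]
print(tool_names)
# ['calculate', 'arxiv', 'weather_check', 'shell',
#  'search_knowledgebase_complex', 'search_internet', 'Wolfram', 'search_youtube']
```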

View File

@ -1,14 +1,16 @@
from langchain.memory import ConversationBufferWindowMemory
+from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
-from langchain.agents import AgentExecutor, LLMSingleActionAgent, initialize_agent, BaseMultiActionAgent
+from langchain.agents import LLMSingleActionAgent, AgentExecutor
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from langchain.chains import LLMChain
-from typing import AsyncIterable, Optional, Dict
+from typing import AsyncIterable, Optional
import asyncio
from typing import List
from server.chat.utils import History
@ -73,12 +75,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
)
output_parser = CustomOutputParser()
llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)
-        agent = LLMSingleActionAgent(
-            llm_chain=llm_chain,
-            output_parser=output_parser,
-            stop=["\nObservation:", "Observation:", "<|im_end|>", "<|observation|>"],  # used with the Qwen model
-            allowed_tools=tool_names,
-        )
# Convert history into the agent's memory
memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
for message in history:
@ -89,11 +85,27 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
else:
# Add the AI message
memory.chat_memory.add_ai_message(message.content)
-        agent_executor = AgentExecutor.from_agent_and_tools(agent=agent,
-                                                            tools=tools,
-                                                            verbose=True,
-                                                            memory=memory,
-                                                            )
+        if "chatglm3" in model_container.MODEL.model_name:
+            agent_executor = initialize_glm3_agent(
+                llm=model,
+                tools=tools,
+                callback_manager=None,
+                verbose=True,
+                memory=memory,
+            )
+        else:
+            agent = LLMSingleActionAgent(
+                llm_chain=llm_chain,
+                output_parser=output_parser,
+                stop=["\nObservation:", "Observation"],
+                allowed_tools=tool_names,
+            )
+            agent_executor = AgentExecutor.from_agent_and_tools(agent=agent,
+                                                                tools=tools,
+                                                                verbose=True,
+                                                                memory=memory,
+                                                                )
while True:
try:
task = asyncio.create_task(wrap_done(

View File

@ -63,7 +63,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
]
dialogue_mode = st.selectbox("请选择对话模式:",
dialogue_modes,
-        index=0,
+        index=3,
on_change=on_mode_change,
key="dialogue_mode",
)