2024-03-06 13:32:44 +08:00

173 lines
8.2 KiB
Python

from fastapi import Body
from fastapi.responses import StreamingResponse
from langchain.agents import initialize_agent, AgentType
from langchain_core.callbacks import AsyncCallbackManager, BaseCallbackManager
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableBranch
from server.agent.agent_factory import initialize_glm3_agent
from server.agent.tools_factory.tools_registry import all_tools
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain
from typing import AsyncIterable, Dict
import asyncio
import json
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Union
from server.chat.utils import History
from langchain.prompts import PromptTemplate
from server.utils import get_prompt_template
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
def create_models_from_config(configs, callbacks):
    """Instantiate a chat model and fetch a prompt template for each model type.

    Args:
        configs: Nested mapping ``{model_type: {model_name: params}}``. ``None``
            is treated as an empty mapping. The ``params`` keys read here are
            ``callbacks`` (default ``False``), ``temperature`` (default ``0.5``),
            ``max_tokens`` (default ``1000``) and ``prompt_name``
            (default ``'default'``).
        callbacks: Forwarded to the model as its ``callbacks`` argument when the
            params' ``callbacks`` flag is truthy; otherwise ``None`` is forwarded.

    Returns:
        A ``(models, prompts)`` tuple of dicts keyed by model type. If a model
        type lists several model names, the last one iterated wins.
    """
    settings = {} if configs is None else configs
    models, prompts = {}, {}

    # Lazily flatten the two-level mapping into (type, name, params) triples.
    flattened = (
        (model_type, model_name, params)
        for model_type, per_model in settings.items()
        for model_name, params in per_model.items()
    )

    for model_type, model_name, params in flattened:
        wants_callbacks = params.get('callbacks', False)
        models[model_type] = get_ChatOpenAI(
            model_name=model_name,
            temperature=params.get('temperature', 0.5),
            max_tokens=params.get('max_tokens', 1000),
            callbacks=callbacks if wants_callbacks else None,
        )
        prompts[model_type] = get_prompt_template(
            type=model_type,
            name=params.get('prompt_name', 'default'),
        )

    return models, prompts
# Chain construction logic goes here
def create_models_chains(history, history_len, prompts, models, tools, callbacks, conversation_id, metadata,
                         agent_use=False):
    """Assemble the runnable that answers a single chat request.

    Args:
        history: Explicit prior messages. Each item is converted with
            ``History.from_data``. When non-empty it takes precedence over
            database-backed memory.
        history_len: Message limit handed to ``ConversationBufferDBMemory``; only
            used when ``history`` is empty, ``conversation_id`` is truthy and
            this value is positive.
        prompts: Mapping of model type to prompt template text. ``"llm_model"``
            is always read; ``"preprocess_model"`` and ``"action_model"`` are
            read only when ``agent_use`` is true.
        models: Mapping of model type to model instance, using the same keys as
            ``prompts``.
        tools: Tools passed to the agent executor (agent path only).
        callbacks: Callback handlers passed to the agent executor (agent path
            only).
        conversation_id: Conversation whose stored messages back the memory.
        metadata: Accepted for interface compatibility; not used in this function.
        agent_use: When true, a classifier chain decides per request whether the
            agent executor or the plain LLM chain handles the input. Defaults to
            ``False``, which always uses the plain LLM chain.

    Returns:
        A runnable that expects ``{"input": <query>}``.
    """
    # Every branch needs the user-turn template. Previously the database-memory
    # branch left the prompt as None, which LLMChain cannot accept.
    input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)

    memory = None
    prior_msgs = []
    if history:
        prior_msgs = [History.from_data(h).to_msg_template() for h in history]
    elif conversation_id and history_len > 0:
        # NOTE(review): stored messages presumably reach the model only if the
        # configured template references the memory's variable - confirm against
        # ConversationBufferDBMemory.
        memory = ConversationBufferDBMemory(conversation_id=conversation_id, llm=models["llm_model"],
                                            message_limit=history_len)

    chat_prompt = ChatPromptTemplate.from_messages(prior_msgs + [input_msg])
    chain = LLMChain(prompt=chat_prompt, llm=models["llm_model"], memory=memory)

    if not agent_use:
        # Plain path: no classifier or agent is needed, so none is built and the
        # "preprocess_model" / "action_model" entries are not required.
        return {"input": lambda x: x["input"]} | chain

    classifier_chain = (
        PromptTemplate.from_template(prompts["preprocess_model"])
        | models["preprocess_model"]
        | StrOutputParser()
    )

    action_model = models["action_model"]
    if "chatglm3" in action_model.model_name.lower():
        agent_executor = initialize_glm3_agent(
            llm=action_model,
            tools=tools,
            prompt=prompts["action_model"],
            input_variables=["input", "intermediate_steps", "history"],
            memory=memory,
            callback_manager=BaseCallbackManager(handlers=callbacks),
            verbose=True,
        )
    else:
        agent_executor = initialize_agent(
            llm=action_model,
            tools=tools,
            callbacks=callbacks,
            agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
            memory=memory,
            verbose=True,
        )

    # Route to the agent when the classifier output contains "1"; otherwise
    # fall back to the plain LLM chain.
    branch = RunnableBranch(
        (lambda x: "1" in x["topic"].lower(), agent_executor),
        chain
    )
    return {"topic": classifier_chain, "input": lambda x: x["input"]} | branch
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
               metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]),
               conversation_id: str = Body("", description="对话框ID"),
               history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
               history: Union[int, List[History]] = Body([],
                                                         description="历史对话,设为一个整数可以从数据库中读取历史消息",
                                                         examples=[[
                                                             {"role": "user",
                                                              "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                                             {"role": "assistant", "content": "虎头虎脑"}]]
                                                         ),
               stream: bool = Body(False, description="流式输出"),
               model_config: Dict = Body({}, description="LLM 模型配置。"),
               tool_config: Dict = Body({}, description="工具配置"),
               ):
    """Answer a chat query, optionally streaming the result as JSON events.

    Args:
        query: The user's input text.
        metadata: Attachments or other extras; forwarded to ``create_models_chains``.
        conversation_id: Conversation identifier. When truthy, the query is
            recorded with ``add_message_to_db`` and the returned id is included
            in every emitted event as ``message_id``.
        history_len: Number of stored messages to use as history.
        history: Explicit prior messages for this request.
        stream: When true, emit one JSON event per callback chunk; otherwise
            emit a single aggregated JSON event.
        model_config: Per-model-type configuration passed to
            ``create_models_from_config``.
        tool_config: Tools whose ``name`` appears in this mapping are enabled.

    Returns:
        An ``EventSourceResponse`` wrapping the event generator. Each event is a
        JSON string carrying ``message_id`` plus ``text``, ``agent_action``
        and/or ``agent_finish``.
    """

    async def chat_iterator() -> AsyncIterable[str]:
        message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id) if conversation_id else None
        callback = CustomAsyncIteratorCallbackHandler()
        callbacks = [callback]
        models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
        tools = [tool for tool in all_tools if tool.name in tool_config]

        # Build the complete chain
        full_chain = create_models_chains(prompts=prompts,
                                          models=models,
                                          conversation_id=conversation_id,
                                          tools=tools,
                                          callbacks=callbacks,
                                          history=history,
                                          history_len=history_len,
                                          metadata=metadata)

        # Run the complete chain in the background; the callback handler feeds
        # the chunks consumed below.
        task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))

        if stream:
            async for chunk in callback.aiter():
                data = json.loads(chunk)
                if data["status"] == Status.start:
                    continue
                elif data["status"] == Status.agent_action:
                    tool_info = {
                        "tool_name": data["tool_name"],
                        "tool_input": data["tool_input"]
                    }
                    yield json.dumps({"agent_action": tool_info, "message_id": message_id}, ensure_ascii=False)
                elif data["status"] == Status.agent_finish:
                    yield json.dumps({"agent_finish": data["agent_finish"], "message_id": message_id},
                                     ensure_ascii=False)
                else:
                    yield json.dumps({"text": data["llm_token"], "message_id": message_id}, ensure_ascii=False)
        else:
            text = ""
            agent_finish = ""
            tool_info = None
            async for chunk in callback.aiter():
                data = json.loads(chunk)
                status = data["status"]
                # Dispatch exclusively, mirroring the streaming branch. The
                # previous `if` / `if` / `else` sent start and agent_action
                # events into the token accumulator as well.
                if status == Status.start:
                    continue
                if status == Status.agent_action:
                    # Only the most recent tool call is reported.
                    tool_info = {
                        "tool_name": data["tool_name"],
                        "tool_input": data["tool_input"]
                    }
                elif status == Status.agent_finish:
                    agent_finish = data["agent_finish"]
                else:
                    text += data["llm_token"]
            if tool_info:
                yield json.dumps(
                    {"text": text, "agent_action": tool_info, "agent_finish": agent_finish, "message_id": message_id},
                    ensure_ascii=False)
            else:
                yield json.dumps(
                    {"text": text, "message_id": message_id},
                    ensure_ascii=False)
        # Surface any exception raised by the background chain run.
        await task

    # NOTE(review): EventSourceResponse is not imported in the visible part of
    # this module (only fastapi's StreamingResponse is) - confirm that an import
    # for it exists, otherwise this line raises NameError.
    return EventSourceResponse(chat_iterator())