import asyncio
import json
from typing import AsyncIterable, Dict, List, Union

from fastapi import Body
from langchain.agents import initialize_agent, AgentType
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableBranch
from sse_starlette.sse import EventSourceResponse

from server.agent.agent_factory import initialize_glm3_agent
from server.agent.tools_factory.tools_registry import all_tools
from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
from server.chat.utils import History
from server.db.repository import add_message_to_db
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template


def create_models_from_config(configs: dict = None, callbacks: list = None):
    """Build one model instance and one prompt template per model type.

    ``configs`` maps a model type (e.g. "llm_model") to ``{model_name: params}``.
    Callbacks are attached only to models whose params set ``callbacks`` truthy.
    """
    configs = configs or {}
    callbacks = callbacks or []
    models = {}
    prompts = {}
    for model_type, model_configs in configs.items():
        for model_name, params in model_configs.items():
            callback = callbacks if params.get("callbacks", False) else None
            model_instance = get_ChatOpenAI(
                model_name=model_name,
                temperature=params.get("temperature", 0.5),
                max_tokens=params.get("max_tokens", 1000),
                callbacks=callback,
            )
            models[model_type] = model_instance
            prompt_name = params.get("prompt_name", "default")
            prompts[model_type] = get_prompt_template(type=model_type, name=prompt_name)
    return models, prompts
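
# --- Example configuration ---------------------------------------------------
# A minimal sketch of the nested shape create_models_from_config expects: a
# model-type -> {model_name: params} mapping. The three type keys below are the
# ones this module actually reads; the model name and parameter values are
# illustrative assumptions, not values taken from the project's config files.
EXAMPLE_MODEL_CONFIG = {
    "preprocess_model": {  # classifier that routes between agent and plain chain
        "chatglm3-6b": {"temperature": 0.05, "max_tokens": 4096,
                        "prompt_name": "default", "callbacks": False},
    },
    "llm_model": {  # model that answers plain chat queries
        "chatglm3-6b": {"temperature": 0.9, "max_tokens": 4096,
                        "prompt_name": "default", "callbacks": True},
    },
    "action_model": {  # model that drives the tool-calling agent
        "chatglm3-6b": {"temperature": 0.01, "max_tokens": 4096,
                        "prompt_name": "default", "callbacks": True},
    },
}
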
async def chat(query: str = Body(..., description="User input", examples=["恼羞成怒"]),
               conversation_id: str = Body("", description="Conversation ID"),
               history_len: int = Body(-1, description="Number of history messages to load from the database"),
               history: Union[int, List[History]] = Body(
                   [],
                   description="Chat history; pass an integer to load that many messages from the database",
                   examples=[[
                       {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
                       {"role": "assistant", "content": "虎头虎脑"}]],
               ),
               stream: bool = Body(False, description="Stream the output"),
               model_config: Dict = Body({}, description="LLM model configuration"),
               tool_config: Dict = Body({}, description="Tool configuration"),
               ):
    async def chat_iterator() -> AsyncIterable[str]:
        nonlocal history
        memory = None
        message_id = None
        chat_prompt = None

        callback = CustomAsyncIteratorCallbackHandler()
        callbacks = [callback]
        models, prompts = create_models_from_config(configs=model_config, callbacks=callbacks)

        if conversation_id:
            message_id = add_message_to_db(chat_type="llm_chat", query=query,
                                           conversation_id=conversation_id)

        # Only expose tools that the request explicitly enables.
        tools = [tool for tool in all_tools if tool.name in tool_config]

        if history:
            # History supplied in the request body takes precedence.
            history = [History.from_data(h) for h in history]
            input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
            chat_prompt = ChatPromptTemplate.from_messages(
                [i.to_msg_template() for i in history] + [input_msg])
        elif conversation_id and history_len > 0:
            # Otherwise load up to history_len past messages from the database.
            memory = ConversationBufferDBMemory(conversation_id=conversation_id,
                                                llm=models["llm_model"],
                                                message_limit=history_len)
        if chat_prompt is None:
            # LLMChain requires a prompt, so fall back to a bare user-message
            # template; in the memory branch past turns come from memory.
            input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
            chat_prompt = ChatPromptTemplate.from_messages([input_msg])

        # Plain LLM chain, used when the classifier routes away from the agent.
        chain = LLMChain(prompt=chat_prompt, llm=models["llm_model"], memory=memory)

        # The preprocess model decides whether the query needs tool use.
        classifier_chain = (
            PromptTemplate.from_template(prompts["preprocess_model"])
            | models["preprocess_model"]
            | StrOutputParser()
        )

        if "chatglm3" in models["action_model"].model_name.lower():
            # ChatGLM3 uses a dedicated agent wrapper.
            agent_executor = initialize_glm3_agent(
                llm=models["action_model"],
                tools=tools,
                prompt=prompts["action_model"],
                input_variables=["input", "intermediate_steps", "history"],
                memory=memory,
                callback_manager=BaseCallbackManager(handlers=callbacks),
                verbose=True,
            )
        else:
            agent_executor = initialize_agent(
                llm=models["action_model"],
                tools=tools,
                callbacks=callbacks,
                agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
                memory=memory,
                verbose=True,
            )

        # Route to the agent when the classifier outputs "1", else to the chain.
        branch = RunnableBranch(
            (lambda x: "1" in x["topic"].lower(), agent_executor),
            chain,
        )
        full_chain = {"topic": classifier_chain, "input": lambda x: x["input"]} | branch

        task = asyncio.create_task(wrap_done(
            full_chain.ainvoke({"input": query}, callbacks=callbacks),
            callback.done))

        if stream:
            # Relay callback events to the client as they arrive, one SSE event each.
            async for chunk in callback.aiter():
                data = json.loads(chunk)
                if data["status"] == Status.start:
                    continue
                elif data["status"] == Status.agent_action:
                    tool_info = {
                        "tool_name": data["tool_name"],
                        "tool_input": data["tool_input"],
                    }
                    yield json.dumps({"agent_action": tool_info, "message_id": message_id},
                                     ensure_ascii=False)
                elif data["status"] == Status.agent_finish:
                    yield json.dumps({"agent_finish": data["agent_finish"], "message_id": message_id},
                                     ensure_ascii=False)
                else:
                    yield json.dumps({"text": data["llm_token"], "message_id": message_id},
                                     ensure_ascii=False)
        else:
            # Accumulate everything and answer with a single event.
            text = ""
            agent_finish = ""
            tool_info = None
            async for chunk in callback.aiter():
                data = json.loads(chunk)
                if data["status"] == Status.start:
                    continue
                elif data["status"] == Status.agent_action:
                    tool_info = {
                        "tool_name": data["tool_name"],
                        "tool_input": data["tool_input"],
                    }
                elif data["status"] == Status.agent_finish:
                    agent_finish = data["agent_finish"]
                else:
                    text += data["llm_token"]
            if tool_info:
                yield json.dumps({"text": text,
                                  "agent_action": tool_info,
                                  "agent_finish": agent_finish,
                                  "message_id": message_id},
                                 ensure_ascii=False)
            else:
                yield json.dumps({"text": text, "message_id": message_id},
                                 ensure_ascii=False)
        await task

    return EventSourceResponse(chat_iterator())
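

if __name__ == "__main__":
    # Usage sketch, not part of the handler. It assumes an API server is
    # already running on localhost:7861 and that this endpoint is mounted at
    # POST /chat/chat; the port and route are assumptions, since route
    # registration lives elsewhere in the project.
    import httpx

    payload = {
        "query": "恼羞成怒",
        "conversation_id": "",
        "history": [],
        "stream": True,
        "model_config": EXAMPLE_MODEL_CONFIG,  # illustrative shape defined above
        "tool_config": {},
    }
    # EventSourceResponse emits "data: {...}" lines; print each as it arrives.
    with httpx.stream("POST", "http://127.0.0.1:7861/chat/chat",
                      json=payload, timeout=60) as resp:
        for line in resp.iter_lines():
            if line:
                print(line)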