liunux4odoo 42aa900566
Improve tool definitions; add an OpenAI-compatible unified chat endpoint (#3570)
- Fixes:
    - Qwen Agent's OutputParser no longer raises an exception; non-CoT text is returned as-is
    - CallbackHandler now handles tool-call information correctly

- Rewrote how tools are defined (a sketch follows this list):
    - Added regist_tool to simplify tool definitions:
        - A user-friendly display name can be specified
        - The function's __doc__ is used as tool.description automatically
        - Parameters can be declared with Field; a separate ModelSchema is no longer needed
    - Added BaseToolOutput to wrap tool results, so both the raw value and the LLM-facing string are available
    - Supports hot-reloading of tools (still to be tested)
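  A minimal sketch of the new tool style, for illustration only (the import path and the title keyword are assumptions inferred from this description, not verified signatures):

    from pydantic import Field
    from chatchat.server.agent.tools_factory.tools_registry import regist_tool, BaseToolOutput

    @regist_tool(title="Weather Check")  # user-friendly display name (assumed keyword)
    def weather_check(city: str = Field(description="City name, e.g. Beijing")):
        """Look up the current weather for a city."""  # becomes tool.description
        data = {"city": city, "weather": "sunny"}      # placeholder result
        # BaseToolOutput keeps the raw value and renders a string for the LLM
        return BaseToolOutput(data)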

- Added an OpenAI-compatible unified chat endpoint; different combinations of tools/tool_choice/extra_body select the feature (see the usage sketch after this list):
    - Agent chat
    - A specific tool call (e.g. knowledge-base RAG)
    - Plain LLM chat
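  A hedged usage sketch via the openai SDK (host/port, model name and tool names are placeholders; plain tool names are expanded server-side, as chat_completions below shows):

    from openai import OpenAI

    client = OpenAI(base_url="http://127.0.0.1:7861/chat", api_key="EMPTY")

    # specific tool call (e.g. knowledge-base RAG): tool_choice + tool_input via extra_body
    resp = client.chat.completions.create(
        model="qwen",
        messages=[{"role": "user", "content": "What is RAG?"}],
        tool_choice="search_local_knowledgebase",
        extra_body={"tool_input": {"query": "What is RAG?"}},
    )

    # agent chat: pass registered tool names via tools
    resp = client.chat.completions.create(
        model="qwen",
        messages=[{"role": "user", "content": "What's the weather in Beijing?"}],
        tools=["weather_check"],
    )

    # plain LLM chat: neither tools nor tool_choice
    resp = client.chat.completions.create(
        model="qwen",
        messages=[{"role": "user", "content": "Hello"}],
    )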

- Updated the webui to match the new backend features
2024-03-29 11:55:32 +08:00


from __future__ import annotations

from typing import List, Dict

from fastapi import APIRouter, Request
from langchain.prompts.prompt import PromptTemplate

from chatchat.server.api_server.api_schemas import OpenAIChatInput, MsgType, AgentStatus
from chatchat.server.chat.chat import chat
from chatchat.server.chat.feedback import chat_feedback
from chatchat.server.chat.file_chat import file_chat
from chatchat.server.db.repository import add_message_to_db
from chatchat.server.utils import get_OpenAIClient, get_tool, get_tool_config, get_prompt_template

from .openai_routes import openai_request
chat_router = APIRouter(prefix="/chat", tags=["ChatChat Chat"])

chat_router.post("/chat",
                 summary="Chat with an LLM model (via LLMChain)",
                 )(chat)

chat_router.post("/feedback",
                 summary="Submit a rating for an LLM chat response",
                 )(chat_feedback)

chat_router.post("/file_chat",
                 summary="File chat",
                 )(file_chat)
@chat_router.post("/chat/completions", summary="OpenAI-compatible unified chat endpoint")
async def chat_completions(
    request: Request,
    body: OpenAIChatInput,
) -> Dict:
    '''
    Request parameters are identical to openai.chat.completions.create; extra
    parameters can be passed via extra_body.
    tools and tool_choice may be given as plain tool names, which are converted
    using the tools registered in this project.
    Different parameter combinations invoke different chat features:
    - tool_choice
        - extra_body contains tool_input: call tool_choice(tool_input) directly
        - extra_body has no tool_input: invoke tool_choice through an agent
    - tools: agent chat
    - otherwise: plain LLM chat
    Other combinations (e.g. file chat) are left for later.
    Returns an OpenAI-compatible Dict.
    '''
    client = get_OpenAIClient(model_name=body.model, is_async=True)

    # pop non-OpenAI fields (sent via extra_body) off the request body
    extra = {**body.model_extra} or {}
    for key in list(extra):
        delattr(body, key)
    # check tools & tool_choice in request body
    if isinstance(body.tool_choice, str):
        if t := get_tool(body.tool_choice):
            body.tool_choice = {"function": {"name": t.name}, "type": "function"}

    if isinstance(body.tools, list):
        for i in range(len(body.tools)):
            if isinstance(body.tools[i], str):
                if t := get_tool(body.tools[i]):
                    body.tools[i] = {
                        "type": "function",
                        "function": {
                            "name": t.name,
                            "description": t.description,
                            "parameters": t.args,
                        },
                    }
    conversation_id = extra.get("conversation_id")

    # chat based on the result from one chosen tool
    if body.tool_choice:
        tool = get_tool(body.tool_choice["function"]["name"])
        if not body.tools:
            body.tools = [{
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.args,
                },
            }]

        if tool_input := extra.get("tool_input"):
            message_id = add_message_to_db(
                chat_type="tool_call",
                query=body.messages[-1]["content"],
                conversation_id=conversation_id,
            ) if conversation_id else None
            tool_result = await tool.ainvoke(tool_input)
            # inject the tool result into the last user message using the RAG prompt template
            prompt_template = PromptTemplate.from_template(get_prompt_template("llm_model", "rag"))
            body.messages[-1]["content"] = prompt_template.format(
                context=tool_result, question=body.messages[-1]["content"])
            del body.tools
            del body.tool_choice
            extra_json = {
                "message_id": message_id,
                "status": None,
            }
            header = [{**extra_json,
                       "content": f"Knowledge base references:\n\n{tool_result}\n\n",
                       "tool_output": tool_result.data,
                       "is_ref": True,
                       }]
            return await openai_request(client.chat.completions.create, body,
                                        extra_json=extra_json, header=header)
    # agent chat with tool calls
    if body.tools:
        message_id = add_message_to_db(
            chat_type="agent_chat",
            query=body.messages[-1]["content"],
            conversation_id=conversation_id,
        ) if conversation_id else None
        chat_model_config = {}  # TODO: let the frontend configure the model
        tool_names = [x["function"]["name"] for x in body.tools]
        tool_config = {name: get_tool_config(name) for name in tool_names}
        result = await chat(query=body.messages[-1]["content"],
                            metadata=extra.get("metadata", {}),
                            conversation_id=extra.get("conversation_id", ""),
                            message_id=message_id,
                            history_len=-1,
                            history=body.messages[:-1],
                            stream=body.stream,
                            chat_model_config=extra.get("chat_model_config", chat_model_config),
                            tool_config=extra.get("tool_config", tool_config),
                            )
        return result
    else:  # LLM chat directly
        message_id = add_message_to_db(
            chat_type="llm_chat",
            query=body.messages[-1]["content"],
            conversation_id=conversation_id,
        ) if conversation_id else None
        extra_json = {
            "message_id": message_id,
            "status": None,
        }
        return await openai_request(client.chat.completions.create, body, extra_json=extra_json)
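
# Usage sketch (illustrative, not part of the original module): since OpenAIChatInput
# allows extra fields, a raw HTTP client can place extra_body keys at the top level of
# the JSON payload; they surface above via body.model_extra. The host/port, model name
# and tool name below are assumptions, not values taken from this commit.
if __name__ == "__main__":
    import requests

    resp = requests.post(
        "http://127.0.0.1:7861/chat/chat/completions",  # router prefix "/chat" + route "/chat/completions"
        json={
            "model": "qwen",
            "messages": [{"role": "user", "content": "What is RAG?"}],
            "tool_choice": "search_local_knowledgebase",  # resolved server-side via get_tool()
            "tool_input": {"query": "What is RAG?"},      # picked up from body.model_extra
            "stream": False,
        },
    )
    print(resp.json())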