liunux4odoo 42aa900566
优化工具定义;添加 openai 兼容的统一 chat 接口 (#3570)
- 修复:
    - Qwen Agent 的 OutputParser 不再抛出异常,遇到非 COT 文本直接返回
    - CallbackHandler 正确处理工具调用信息

- 重写 tool 定义方式:
    - 添加 regist_tool 简化 tool 定义:
        - 可以指定一个用户友好的名称
        - 自动将函数的 __doc__ 作为 tool.description
	- 支持用 Field 定义参数,不再需要额外定义 ModelSchema
        - 添加 BaseToolOutput 封装 tool	返回结果,以便同时获取原始值、给LLM的字符串值
        - 支持工具热加载(有待测试)

- 增加 openai 兼容的统一 chat 接口,通过 tools/tool_choice/extra_body 不同参数组合支持:
    - Agent 对话
    - 指定工具调用(如知识库RAG)
    - LLM 对话

- 根据后端功能更新 webui
2024-03-29 11:55:32 +08:00

78 lines
2.5 KiB
Python

import base64
import json
import os
from PIL import Image
from typing import List
import uuid
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import regist_tool, BaseToolOutput
import openai
from chatchat.configs.basic_config import MEDIA_PATH
from chatchat.server.utils import MsgType
def get_image_model_config() -> dict:
from chatchat.configs.model_config import LLM_MODEL_CONFIG, ONLINE_LLM_MODEL
model = LLM_MODEL_CONFIG.get("image_model")
if model:
name = list(model.keys())[0]
if config := ONLINE_LLM_MODEL.get(name):
config = {**list(model.values())[0], **config}
config.setdefault("model_name", name)
return config
@regist_tool(title="文生图", return_direct=True)
def text2images(
prompt: str,
n: int = Field(1, description="需生成图片的数量"),
width: int = Field(512, description="生成图片的宽度"),
height: int = Field(512, description="生成图片的高度"),
) -> List[str]:
'''根据用户的描述生成图片'''
model_config = get_image_model_config()
assert model_config is not None, "请正确配置文生图模型"
client = openai.Client(
base_url=model_config["api_base_url"],
api_key=model_config["api_key"],
timeout=600,
)
resp = client.images.generate(prompt=prompt,
n=n,
size=f"{width}*{height}",
response_format="b64_json",
model=model_config["model_name"],
)
images = []
for x in resp.data:
uid = uuid.uuid4().hex
filename = f"image/{uid}.png"
with open(os.path.join(MEDIA_PATH, filename), "wb") as fp:
fp.write(base64.b64decode(x.b64_json))
images.append(filename)
return BaseToolOutput({"message_type": MsgType.IMAGE, "images": images}, format="json")
if __name__ == "__main__":
from io import BytesIO
from matplotlib import pyplot as plt
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent.parent.parent.parent))
prompt = "draw a house with trees and river"
prompt = "画一个带树、草、河流的山中小屋"
params = text2images.args_schema.parse_obj({"prompt": prompt}).dict()
print(params)
image = text2images.invoke(params)[0]
buffer = BytesIO(base64.b64decode(image))
image = Image.open(buffer)
plt.imshow(image)
plt.show()