liunux4odoo 5c650a8dc3
优化目录结构 (#4058)
* 优化目录结构

* 修改一些测试问题

---------

Co-authored-by: glide-the <2533736852@qq.com>
2024-05-22 13:11:45 +08:00

78 lines
2.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import json
import os
from PIL import Image
from typing import List
import uuid
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import regist_tool, BaseToolOutput
import openai
from chatchat.configs import MEDIA_PATH
from chatchat.server.utils import MsgType
def get_image_model_config() -> dict:
    """Return the configuration of the text-to-image model.

    The lookup is currently disabled, so this function always returns
    ``None``. The previous implementation is preserved in the comment
    block below for reference.
    """
    # TODO: the ONLINE_LLM_MODEL configuration was removed; the lookup below
    # has to be reworked against the new configuration layout.
    #
    # from chatchat.configs import LLM_MODEL_CONFIG, ONLINE_LLM_MODEL
    # model = LLM_MODEL_CONFIG.get("image_model")
    # if model:
    #     name = list(model.keys())[0]
    #     if config := ONLINE_LLM_MODEL.get(name):
    #         config = {**list(model.values())[0], **config}
    #         config.setdefault("model_name", name)
    #         return config
    return None
@regist_tool(title="文生图", return_direct=True)
def text2images(
    prompt: str,
    n: int = Field(1, description="需生成图片的数量"),
    width: int = Field(512, description="生成图片的宽度"),
    height: int = Field(512, description="生成图片的高度"),
) -> BaseToolOutput:
    '''根据用户的描述生成图片'''
    # Generate images from a text prompt and save them as PNG files.
    #
    # NOTE(review): the docstring above is kept verbatim because the tool
    # registry presumably exposes it as the tool description — confirm
    # before translating it.
    #
    # Args:
    #   prompt: text description of the picture to generate.
    #   n: number of images to request.
    #   width, height: requested dimensions, sent as "{width}*{height}".
    # Returns:
    #   BaseToolOutput built from
    #   {"message_type": MsgType.IMAGE, "images": [...]} with format="json";
    #   each entry of "images" is a path relative to MEDIA_PATH of the form
    #   "image/<uuid hex>.png".
    # Raises:
    #   AssertionError: if get_image_model_config() returns None.
    model_config = get_image_model_config()
    if model_config is None:
        # Raised explicitly instead of via `assert` so the check is not
        # stripped under `python -O`; exception type and message are the
        # same as the assert produced.
        raise AssertionError("请正确配置文生图模型")

    client = openai.Client(
        base_url=model_config["api_base_url"],
        api_key=model_config["api_key"],
        timeout=600,
    )
    resp = client.images.generate(
        prompt=prompt,
        n=n,
        # NOTE(review): OpenAI documents sizes as "WIDTHxHEIGHT"; the "*"
        # separator is kept as-is, presumably for the configured backend —
        # confirm against the server behind api_base_url.
        size=f"{width}*{height}",
        response_format="b64_json",
        model=model_config["model_name"],
    )

    # Make sure the target sub-directory exists before writing into it;
    # otherwise open(..., "wb") fails with FileNotFoundError.
    os.makedirs(os.path.join(MEDIA_PATH, "image"), exist_ok=True)

    images = []
    for x in resp.data:
        uid = uuid.uuid4().hex
        filename = f"image/{uid}.png"
        # The response carries base64 text (response_format="b64_json");
        # decode it and store the raw bytes.
        with open(os.path.join(MEDIA_PATH, filename), "wb") as fp:
            fp.write(base64.b64decode(x.b64_json))
        images.append(filename)
    return BaseToolOutput({"message_type": MsgType.IMAGE, "images": images}, format="json")
if __name__ == "__main__":
    # Manual demo: run the tool once and display the result with matplotlib.
    import sys
    from io import BytesIO
    from pathlib import Path

    from matplotlib import pyplot as plt

    # Append the directory reached by four `.parent` hops from this file
    # to the import path.
    project_root = Path(__file__).parent.parent.parent.parent
    sys.path.append(str(project_root))

    prompt = "draw a house with trees and river"
    # The English prompt above is immediately replaced by the Chinese one.
    prompt = "画一个带树、草、河流的山中小屋"

    # Build the argument dict through the tool's args schema; only `prompt`
    # is supplied here.
    tool_args = text2images.args_schema.parse_obj({"prompt": prompt}).dict()
    print(tool_args)

    # NOTE(review): text2images is written to return a BaseToolOutput that
    # holds saved file names, so indexing [0] and base64-decoding the result
    # looks stale — confirm and update this demo if so.
    first_result = text2images.invoke(tool_args)[0]
    picture = Image.open(BytesIO(base64.b64decode(first_result)))
    plt.imshow(picture)
    plt.show()