# 2024-03-06 13:44:43 +08:00

# 85 lines
# 2.7 KiB
# Python

import base64
import json
import os
import uuid
from typing import List, Optional

import openai
from langchain.agents import tool
from PIL import Image
from pydantic.fields import FieldInfo
from pydantic.v1 import BaseModel, Field

from configs.basic_config import MEDIA_PATH
from server.utils import MsgType
def get_image_model_config() -> Optional[dict]:
    """Return the merged configuration of the configured text-to-image model.

    Reads the ``"image_model"`` entry of ``LLM_MODEL_CONFIG``, takes the first
    model listed there, and merges that model's settings with the entry of the
    same name in ``ONLINE_LLM_MODEL`` (the online entry wins on key conflicts).
    ``"model_name"`` is set to the model's name unless a source already
    provides it.

    Returns:
        The merged configuration dict, or ``None`` when no image model is
        configured or the model has no non-empty ``ONLINE_LLM_MODEL`` entry.
    """
    # Imported inside the function so the config module is only loaded on use.
    from configs.model_config import LLM_MODEL_CONFIG, ONLINE_LLM_MODEL

    image_models = LLM_MODEL_CONFIG.get("image_model")
    if not image_models:
        return None

    # Only the first configured image model is considered.
    name, model_settings = next(iter(image_models.items()))
    online_config = ONLINE_LLM_MODEL.get(name)
    if not online_config:
        return None

    config = {**model_settings, **online_config}
    config.setdefault("model_name", name)
    return config
# Tool that asks an OpenAI-compatible image endpoint to generate images for a
# text prompt, saves each returned image as a PNG under MEDIA_PATH/image, and
# returns a JSON string of the form
#   {"message_type": MsgType.IMAGE, "images": ["image/<uuid>.png", ...]}
# where the image paths are relative to MEDIA_PATH.
#
# The docstring and the Field descriptions are handed to the `tool` decorator
# at runtime (presumably as the description shown to the agent), so they are
# deliberately kept exactly as they are.
# NOTE(review): the return annotation says List[str] but the function returns
# a JSON string; it is left unchanged in case the tool's published signature
# depends on it -- confirm and correct separately.
@tool(return_direct=True)
def text2images(
    prompt: str,
    n: int = Field(1, description="需生成图片的数量"),
    width: int = Field(512, description="生成图片的宽度"),
    height: int = Field(512, description="生成图片的高度"),
) -> List[str]:
    '''根据用户的描述生成图片'''
    # Workaround before langchain upgrading: a parameter left at its default
    # may arrive as the FieldInfo object itself, so unwrap it to the plain
    # default value.
    # NOTE(review): Field comes from pydantic.v1 while FieldInfo comes from
    # pydantic.fields -- confirm both refer to the same class under the
    # installed pydantic version, otherwise these checks never match.
    if isinstance(n, FieldInfo):
        n = n.default
    if isinstance(width, FieldInfo):
        width = width.default
    if isinstance(height, FieldInfo):
        height = height.default

    model_config = get_image_model_config()
    if model_config is None:
        # Raised explicitly (rather than via `assert`) so the check still runs
        # under `python -O`; AssertionError is kept for callers that catch it.
        raise AssertionError("请正确配置文生图模型")

    client = openai.Client(
        base_url=model_config["api_base_url"],
        api_key=model_config["api_key"],
        timeout=600,
    )
    # NOTE(review): the size string uses "*" as the separator; confirm the
    # configured endpoint expects this rather than "WIDTHxHEIGHT".
    resp = client.images.generate(
        prompt=prompt,
        n=n,
        size=f"{width}*{height}",
        response_format="b64_json",
        model=model_config["model_name"],
    )

    # Make sure the output directory exists before writing into it.
    os.makedirs(os.path.join(MEDIA_PATH, "image"), exist_ok=True)

    images = []
    for x in resp.data:
        uid = uuid.uuid4().hex
        filename = f"image/{uid}.png"
        with open(os.path.join(MEDIA_PATH, filename), "wb") as fp:
            fp.write(base64.b64decode(x.b64_json))
        images.append(filename)
    return json.dumps({"message_type": MsgType.IMAGE, "images": images})
if __name__ == "__main__":
    # Manual smoke test: generate one image for a sample prompt and display
    # it with matplotlib.
    from matplotlib import pyplot as plt
    from pathlib import Path
    import sys

    sys.path.append(str(Path(__file__).parent.parent.parent.parent))

    prompt = "画一个带树、草、河流的山中小屋"
    params = text2images.args_schema.parse_obj({"prompt": prompt}).dict()
    print(params)

    # text2images returns a JSON string whose "images" entries are file paths
    # relative to MEDIA_PATH, so parse it and open the first saved file.
    result = json.loads(text2images.invoke(params))
    image = Image.open(os.path.join(MEDIA_PATH, result["images"][0]))
    plt.imshow(image)
    plt.show()