添加文生图工具

This commit is contained in:
liunux4odoo 2024-01-11 11:09:18 +08:00
parent 6f155aec1f
commit 17ba487074
6 changed files with 76 additions and 2 deletions

View File

@ -55,6 +55,11 @@ LLM_MODEL_CONFIG = {
"callbacks": True
}
},
"image_model": {
"sd-turbo": {
"size": "1024x1024",
}
}
}
LLM_DEVICE = "auto"
ONLINE_LLM_MODEL = {
@ -64,6 +69,11 @@ ONLINE_LLM_MODEL = {
"api_key": "sk-",
"openai_proxy": "",
},
"sd-turbo": {
"model_name": "sd-turbo",
"api_base_url": "http://127.0.0.1:9997/v1",
"api_key": "EMPTY",
},
"zhipu-api": {
"api_key": "",
"version": "chatglm_turbo",

View File

@ -171,6 +171,9 @@ TOOL_CONFIG = {
"calculate": {
"use": False,
},
"text2images": {
"use": False,
},
# Use THUDM/cogvlm-chat-hf as default

View File

@ -6,6 +6,7 @@ from .search_internet import search_internet, SearchInternetInput
from .wolfram import wolfram, WolframInput
from .search_youtube import search_youtube, YoutubeInput
from .arxiv import arxiv, ArxivInput
from .text2image import text2images
from .vision_factory import *
from .audio_factory import *

View File

@ -0,0 +1,58 @@
from PIL import Image
from typing import List
from langchain.agents import tool
from langchain.pydantic_v1 import Field
import openai
def get_image_model_config() -> dict:
from configs.model_config import LLM_MODEL_CONFIG, ONLINE_LLM_MODEL
model = LLM_MODEL_CONFIG.get("image_model")
if model:
name = list(model.keys())[0]
if config := ONLINE_LLM_MODEL.get(name):
config = {**list(model.values())[0], **config}
config.setdefault("model_name", name)
return config
@tool
def text2images(
prompt: str,
n: int = Field(1, description="需生成图片的数量"),
width: int = Field(512, description="生成图片的宽度"),
height: int = Field(512, description="生成图片的高度"),
) -> List[str]:
'''根据用户的描述生成图片'''
model_config = get_image_model_config()
assert model_config is not None, "请正确配置文生图模型"
client = openai.Client(
base_url=model_config["api_base_url"],
api_key=model_config["api_key"],
timeout=600,
)
resp = client.images.generate(prompt=prompt,
n=n,
size=f"{width}*{height}",
response_format="url",
model=model_config["model_name"],
)
return [x.url for x in resp.data]
if __name__ == "__main__":
from matplotlib import pyplot as plt
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent.parent.parent.parent))
prompt = "draw a house with trees and river"
prompt = "画一个带树、草、河流的山中小屋"
params = text2images.args_schema.parse_obj({"prompt": prompt}).dict()
print(params)
image = text2images.invoke(params)[0]
plt.imshow(image)
plt.show()

View File

@ -70,3 +70,5 @@ all_tools = [
)
]
all_tools.append(text2images)

View File

@ -447,8 +447,8 @@ def get_prompt_template(type: str, name: str) -> Optional[str]:
from configs import prompt_config
import importlib
importlib.reload(prompt_config)
return prompt_config.PROMPT_TEMPLATES[type].get(name)
importlib.reload(prompt_config) # TODO: 检查configs/prompt_config.py文件有修改再重新加载
return prompt_config.PROMPT_TEMPLATES.get(type, {}).get(name)
def set_httpx_config(