From 17ba4870741a496e535174d3619af2025ff16eef Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Thu, 11 Jan 2024 11:09:18 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=96=87=E7=94=9F=E5=9B=BE?= =?UTF-8?q?=E5=B7=A5=E5=85=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- configs/model_config.py.example | 10 ++++ configs/prompt_config.py.example | 3 + server/agent/tools_factory/__init__.py | 1 + server/agent/tools_factory/text2image.py | 58 ++++++++++++++++++++ server/agent/tools_factory/tools_registry.py | 2 + server/utils.py | 4 +- 6 files changed, 76 insertions(+), 2 deletions(-) create mode 100644 server/agent/tools_factory/text2image.py diff --git a/configs/model_config.py.example b/configs/model_config.py.example index ea7ac90d..6b887c5f 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -55,6 +55,11 @@ LLM_MODEL_CONFIG = { "callbacks": True } }, + "image_model": { + "sd-turbo": { + "size": "1024x1024", + } + } } LLM_DEVICE = "auto" ONLINE_LLM_MODEL = { @@ -64,6 +69,11 @@ ONLINE_LLM_MODEL = { "api_key": "sk-", "openai_proxy": "", }, + "sd-turbo": { + "model_name": "sd-turbo", + "api_base_url": "http://127.0.0.1:9997/v1", + "api_key": "EMPTY", + }, "zhipu-api": { "api_key": "", "version": "chatglm_turbo", diff --git a/configs/prompt_config.py.example b/configs/prompt_config.py.example index 4f9c45f5..ec3ccf74 100644 --- a/configs/prompt_config.py.example +++ b/configs/prompt_config.py.example @@ -171,6 +171,9 @@ TOOL_CONFIG = { "calculate": { "use": False, }, + "text2images": { + "use": False, + }, # Use THUDM/cogvlm-chat-hf as default diff --git a/server/agent/tools_factory/__init__.py b/server/agent/tools_factory/__init__.py index 71367d42..9486a224 100644 --- a/server/agent/tools_factory/__init__.py +++ b/server/agent/tools_factory/__init__.py @@ -6,6 +6,7 @@ from .search_internet import search_internet, SearchInternetInput from .wolfram import wolfram, WolframInput from .search_youtube import search_youtube, YoutubeInput from .arxiv import arxiv, ArxivInput +from .text2image import text2images from .vision_factory import * from .audio_factory import * diff --git a/server/agent/tools_factory/text2image.py b/server/agent/tools_factory/text2image.py new file mode 100644 index 00000000..02d7dd3a --- /dev/null +++ b/server/agent/tools_factory/text2image.py @@ -0,0 +1,58 @@ +from PIL import Image +from typing import List + +from langchain.agents import tool +from langchain.pydantic_v1 import Field +import openai + + +def get_image_model_config() -> dict: + from configs.model_config import LLM_MODEL_CONFIG, ONLINE_LLM_MODEL + + model = LLM_MODEL_CONFIG.get("image_model") + if model: + name = list(model.keys())[0] + if config := ONLINE_LLM_MODEL.get(name): + config = {**list(model.values())[0], **config} + config.setdefault("model_name", name) + return config + + +@tool +def text2images( + prompt: str, + n: int = Field(1, description="需生成图片的数量"), + width: int = Field(512, description="生成图片的宽度"), + height: int = Field(512, description="生成图片的高度"), +) -> List[str]: + '''根据用户的描述生成图片''' + model_config = get_image_model_config() + assert model_config is not None, "请正确配置文生图模型" + + client = openai.Client( + base_url=model_config["api_base_url"], + api_key=model_config["api_key"], + timeout=600, + ) + resp = client.images.generate(prompt=prompt, + n=n, + size=f"{width}*{height}", + response_format="url", + model=model_config["model_name"], + ) + return [x.url for x in resp.data] + + +if __name__ == "__main__": + from matplotlib import pyplot as plt + from pathlib import Path + import sys + sys.path.append(str(Path(__file__).parent.parent.parent.parent)) + + prompt = "draw a house with trees and river" + prompt = "画一个带树、草、河流的山中小屋" + params = text2images.args_schema.parse_obj({"prompt": prompt}).dict() + print(params) + image = text2images.invoke(params)[0] + plt.imshow(image) + plt.show() diff --git a/server/agent/tools_factory/tools_registry.py b/server/agent/tools_factory/tools_registry.py index 7334fdd3..9fb6ee1f 100644 --- a/server/agent/tools_factory/tools_registry.py +++ b/server/agent/tools_factory/tools_registry.py @@ -70,3 +70,5 @@ all_tools = [ ) ] + +all_tools.append(text2images) diff --git a/server/utils.py b/server/utils.py index 990cd13a..0962b017 100644 --- a/server/utils.py +++ b/server/utils.py @@ -447,8 +447,8 @@ def get_prompt_template(type: str, name: str) -> Optional[str]: from configs import prompt_config import importlib - importlib.reload(prompt_config) - return prompt_config.PROMPT_TEMPLATES[type].get(name) + importlib.reload(prompt_config) # TODO: 检查configs/prompt_config.py文件有修改再重新加载 + return prompt_config.PROMPT_TEMPLATES.get(type, {}).get(name) def set_httpx_config(