mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-01 11:53:24 +08:00
添加文生图工具
This commit is contained in:
parent
6f155aec1f
commit
17ba487074
@ -55,6 +55,11 @@ LLM_MODEL_CONFIG = {
|
||||
"callbacks": True
|
||||
}
|
||||
},
|
||||
"image_model": {
|
||||
"sd-turbo": {
|
||||
"size": "1024x1024",
|
||||
}
|
||||
}
|
||||
}
|
||||
LLM_DEVICE = "auto"
|
||||
ONLINE_LLM_MODEL = {
|
||||
@ -64,6 +69,11 @@ ONLINE_LLM_MODEL = {
|
||||
"api_key": "sk-",
|
||||
"openai_proxy": "",
|
||||
},
|
||||
"sd-turbo": {
|
||||
"model_name": "sd-turbo",
|
||||
"api_base_url": "http://127.0.0.1:9997/v1",
|
||||
"api_key": "EMPTY",
|
||||
},
|
||||
"zhipu-api": {
|
||||
"api_key": "",
|
||||
"version": "chatglm_turbo",
|
||||
|
||||
@ -171,6 +171,9 @@ TOOL_CONFIG = {
|
||||
"calculate": {
|
||||
"use": False,
|
||||
},
|
||||
"text2images": {
|
||||
"use": False,
|
||||
},
|
||||
|
||||
# Use THUDM/cogvlm-chat-hf as default
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ from .search_internet import search_internet, SearchInternetInput
|
||||
from .wolfram import wolfram, WolframInput
|
||||
from .search_youtube import search_youtube, YoutubeInput
|
||||
from .arxiv import arxiv, ArxivInput
|
||||
from .text2image import text2images
|
||||
|
||||
from .vision_factory import *
|
||||
from .audio_factory import *
|
||||
|
||||
58
server/agent/tools_factory/text2image.py
Normal file
58
server/agent/tools_factory/text2image.py
Normal file
@ -0,0 +1,58 @@
|
||||
from PIL import Image
|
||||
from typing import List
|
||||
|
||||
from langchain.agents import tool
|
||||
from langchain.pydantic_v1 import Field
|
||||
import openai
|
||||
|
||||
|
||||
def get_image_model_config() -> dict:
|
||||
from configs.model_config import LLM_MODEL_CONFIG, ONLINE_LLM_MODEL
|
||||
|
||||
model = LLM_MODEL_CONFIG.get("image_model")
|
||||
if model:
|
||||
name = list(model.keys())[0]
|
||||
if config := ONLINE_LLM_MODEL.get(name):
|
||||
config = {**list(model.values())[0], **config}
|
||||
config.setdefault("model_name", name)
|
||||
return config
|
||||
|
||||
|
||||
@tool
|
||||
def text2images(
|
||||
prompt: str,
|
||||
n: int = Field(1, description="需生成图片的数量"),
|
||||
width: int = Field(512, description="生成图片的宽度"),
|
||||
height: int = Field(512, description="生成图片的高度"),
|
||||
) -> List[str]:
|
||||
'''根据用户的描述生成图片'''
|
||||
model_config = get_image_model_config()
|
||||
assert model_config is not None, "请正确配置文生图模型"
|
||||
|
||||
client = openai.Client(
|
||||
base_url=model_config["api_base_url"],
|
||||
api_key=model_config["api_key"],
|
||||
timeout=600,
|
||||
)
|
||||
resp = client.images.generate(prompt=prompt,
|
||||
n=n,
|
||||
size=f"{width}*{height}",
|
||||
response_format="url",
|
||||
model=model_config["model_name"],
|
||||
)
|
||||
return [x.url for x in resp.data]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from matplotlib import pyplot as plt
|
||||
from pathlib import Path
|
||||
import sys
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent.parent))
|
||||
|
||||
prompt = "draw a house with trees and river"
|
||||
prompt = "画一个带树、草、河流的山中小屋"
|
||||
params = text2images.args_schema.parse_obj({"prompt": prompt}).dict()
|
||||
print(params)
|
||||
image = text2images.invoke(params)[0]
|
||||
plt.imshow(image)
|
||||
plt.show()
|
||||
@ -70,3 +70,5 @@ all_tools = [
|
||||
|
||||
)
|
||||
]
|
||||
|
||||
all_tools.append(text2images)
|
||||
|
||||
@ -447,8 +447,8 @@ def get_prompt_template(type: str, name: str) -> Optional[str]:
|
||||
|
||||
from configs import prompt_config
|
||||
import importlib
|
||||
importlib.reload(prompt_config)
|
||||
return prompt_config.PROMPT_TEMPLATES[type].get(name)
|
||||
importlib.reload(prompt_config) # TODO: 检查configs/prompt_config.py文件有修改再重新加载
|
||||
return prompt_config.PROMPT_TEMPLATES.get(type, {}).get(name)
|
||||
|
||||
|
||||
def set_httpx_config(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user