mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-08 07:53:29 +08:00
添加文生图工具
This commit is contained in:
parent
6f155aec1f
commit
17ba487074
@ -55,6 +55,11 @@ LLM_MODEL_CONFIG = {
|
|||||||
"callbacks": True
|
"callbacks": True
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"image_model": {
|
||||||
|
"sd-turbo": {
|
||||||
|
"size": "1024x1024",
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
LLM_DEVICE = "auto"
|
LLM_DEVICE = "auto"
|
||||||
ONLINE_LLM_MODEL = {
|
ONLINE_LLM_MODEL = {
|
||||||
@ -64,6 +69,11 @@ ONLINE_LLM_MODEL = {
|
|||||||
"api_key": "sk-",
|
"api_key": "sk-",
|
||||||
"openai_proxy": "",
|
"openai_proxy": "",
|
||||||
},
|
},
|
||||||
|
"sd-turbo": {
|
||||||
|
"model_name": "sd-turbo",
|
||||||
|
"api_base_url": "http://127.0.0.1:9997/v1",
|
||||||
|
"api_key": "EMPTY",
|
||||||
|
},
|
||||||
"zhipu-api": {
|
"zhipu-api": {
|
||||||
"api_key": "",
|
"api_key": "",
|
||||||
"version": "chatglm_turbo",
|
"version": "chatglm_turbo",
|
||||||
|
|||||||
@ -171,6 +171,9 @@ TOOL_CONFIG = {
|
|||||||
"calculate": {
|
"calculate": {
|
||||||
"use": False,
|
"use": False,
|
||||||
},
|
},
|
||||||
|
"text2images": {
|
||||||
|
"use": False,
|
||||||
|
},
|
||||||
|
|
||||||
# Use THUDM/cogvlm-chat-hf as default
|
# Use THUDM/cogvlm-chat-hf as default
|
||||||
|
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from .search_internet import search_internet, SearchInternetInput
|
|||||||
from .wolfram import wolfram, WolframInput
|
from .wolfram import wolfram, WolframInput
|
||||||
from .search_youtube import search_youtube, YoutubeInput
|
from .search_youtube import search_youtube, YoutubeInput
|
||||||
from .arxiv import arxiv, ArxivInput
|
from .arxiv import arxiv, ArxivInput
|
||||||
|
from .text2image import text2images
|
||||||
|
|
||||||
from .vision_factory import *
|
from .vision_factory import *
|
||||||
from .audio_factory import *
|
from .audio_factory import *
|
||||||
|
|||||||
58
server/agent/tools_factory/text2image.py
Normal file
58
server/agent/tools_factory/text2image.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
from PIL import Image
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain.agents import tool
|
||||||
|
from langchain.pydantic_v1 import Field
|
||||||
|
import openai
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_model_config() -> dict:
|
||||||
|
from configs.model_config import LLM_MODEL_CONFIG, ONLINE_LLM_MODEL
|
||||||
|
|
||||||
|
model = LLM_MODEL_CONFIG.get("image_model")
|
||||||
|
if model:
|
||||||
|
name = list(model.keys())[0]
|
||||||
|
if config := ONLINE_LLM_MODEL.get(name):
|
||||||
|
config = {**list(model.values())[0], **config}
|
||||||
|
config.setdefault("model_name", name)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def text2images(
|
||||||
|
prompt: str,
|
||||||
|
n: int = Field(1, description="需生成图片的数量"),
|
||||||
|
width: int = Field(512, description="生成图片的宽度"),
|
||||||
|
height: int = Field(512, description="生成图片的高度"),
|
||||||
|
) -> List[str]:
|
||||||
|
'''根据用户的描述生成图片'''
|
||||||
|
model_config = get_image_model_config()
|
||||||
|
assert model_config is not None, "请正确配置文生图模型"
|
||||||
|
|
||||||
|
client = openai.Client(
|
||||||
|
base_url=model_config["api_base_url"],
|
||||||
|
api_key=model_config["api_key"],
|
||||||
|
timeout=600,
|
||||||
|
)
|
||||||
|
resp = client.images.generate(prompt=prompt,
|
||||||
|
n=n,
|
||||||
|
size=f"{width}*{height}",
|
||||||
|
response_format="url",
|
||||||
|
model=model_config["model_name"],
|
||||||
|
)
|
||||||
|
return [x.url for x in resp.data]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent.parent.parent))
|
||||||
|
|
||||||
|
prompt = "draw a house with trees and river"
|
||||||
|
prompt = "画一个带树、草、河流的山中小屋"
|
||||||
|
params = text2images.args_schema.parse_obj({"prompt": prompt}).dict()
|
||||||
|
print(params)
|
||||||
|
image = text2images.invoke(params)[0]
|
||||||
|
plt.imshow(image)
|
||||||
|
plt.show()
|
||||||
@ -70,3 +70,5 @@ all_tools = [
|
|||||||
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
all_tools.append(text2images)
|
||||||
|
|||||||
@ -447,8 +447,8 @@ def get_prompt_template(type: str, name: str) -> Optional[str]:
|
|||||||
|
|
||||||
from configs import prompt_config
|
from configs import prompt_config
|
||||||
import importlib
|
import importlib
|
||||||
importlib.reload(prompt_config)
|
importlib.reload(prompt_config) # TODO: 检查configs/prompt_config.py文件有修改再重新加载
|
||||||
return prompt_config.PROMPT_TEMPLATES[type].get(name)
|
return prompt_config.PROMPT_TEMPLATES.get(type, {}).get(name)
|
||||||
|
|
||||||
|
|
||||||
def set_httpx_config(
|
def set_httpx_config(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user