liunux4odoo 9818bd2a88
- Rewrite the tool module: (#3553)
    - Simplify how tools are defined
    - All tools and tool_config support hot reloading
    - Fix: json_schema_extra warning
2024-03-28 13:08:51 +08:00

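The constructor below reads each tool's settings via get_tool_config, expecting the keys use, tokenizer_path, model_path, and device. A minimal sketch of what the two entries might look like, assuming a dict-shaped tool_config (the key names come from the code; every value here is a placeholder, not a project default):

# Hedged sketch of the expected tool_config shape; all values are placeholders.
tool_config = {
    "vqa_processor": {
        "use": True,                               # load the vision model on startup
        "model_path": "path/to/vqa-model",         # placeholder checkpoint path
        "tokenizer_path": "path/to/vqa-tokenizer", # placeholder tokenizer path
        "device": "cuda",
    },
    "aqa_processor": {
        "use": False,                              # audio model stays unloaded
        "model_path": "path/to/aqa-model",
        "tokenizer_path": "path/to/aqa-tokenizer",
        "device": "cuda",
    },
}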

from chatchat.configs import logger
from chatchat.server.utils import get_tool_config


class ModelContainer:
    """Process-wide holder for optionally loaded multimodal models."""

    def __init__(self):
        self.model = None
        self.metadata = None
        self.vision_model = None
        self.vision_tokenizer = None
        self.audio_tokenizer = None
        self.audio_model = None

        # Load the vision (VQA) model only when its tool config enables it;
        # any load failure is logged and leaves the attributes as None.
        vqa_config = get_tool_config("vqa_processor")
        if vqa_config["use"]:
            try:
                import torch
                from transformers import AutoModelForCausalLM, LlamaTokenizer

                self.vision_tokenizer = LlamaTokenizer.from_pretrained(
                    vqa_config["tokenizer_path"],
                    trust_remote_code=True,
                )
                self.vision_model = AutoModelForCausalLM.from_pretrained(
                    pretrained_model_name_or_path=vqa_config["model_path"],
                    torch_dtype=torch.bfloat16,
                    low_cpu_mem_usage=True,
                    trust_remote_code=True,
                ).to(vqa_config["device"]).eval()
            except Exception as e:
                logger.error(e, exc_info=True)

        # Same pattern for the audio (AQA) model, reading its own config entry.
        aqa_config = get_tool_config("aqa_processor")
        if aqa_config["use"]:
            try:
                import torch
                from transformers import AutoModelForCausalLM, AutoTokenizer

                self.audio_tokenizer = AutoTokenizer.from_pretrained(
                    aqa_config["tokenizer_path"],
                    trust_remote_code=True,
                )
                self.audio_model = AutoModelForCausalLM.from_pretrained(
                    pretrained_model_name_or_path=aqa_config["model_path"],
                    torch_dtype=torch.bfloat16,
                    low_cpu_mem_usage=True,
                    trust_remote_code=True,
                ).to(aqa_config["device"]).eval()
            except Exception as e:
                logger.error(e, exc_info=True)


# Module-level singleton shared by the server and its tools.
container = ModelContainer()
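Downstream tools can share the loaded weights through the container singleton. A hedged usage sketch of the guard pattern callers would need; answer_with_vision_model and the pre-tokenized inputs dict are hypothetical, while container and its attributes come from the file above:

# Hypothetical consumer of the singleton; only generate()/decode() from the
# standard transformers API are assumed.
def answer_with_vision_model(inputs) -> str:
    if container.vision_model is None or container.vision_tokenizer is None:
        # vision_model stays None when vqa_processor is disabled in the tool
        # config or failed to load, so callers must degrade gracefully.
        return "vqa_processor is not available; enable it in its tool config."
    output_ids = container.vision_model.generate(**inputs)
    return container.vision_tokenizer.decode(output_ids[0], skip_special_tokens=True)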