统一模型类型编码

This commit is contained in:
glide-the 2024-05-21 15:39:15 +08:00
parent 1f378256e4
commit 7b6cc81ff5
9 changed files with 162 additions and 13 deletions

View File

@ -69,10 +69,7 @@ LLM_MODEL_CONFIG = {
"sd-turbo": {
"size": "256*256",
}
},
"multimodal_model": {
"qwen-vl": {}
},
}
}
# 可以通过 model_providers 提供转换不同平台的接口为openai endpoint的能力启动后下面变量会自动增加相应的平台
@ -116,7 +113,9 @@ MODEL_PLATFORMS = [
"Embedding-V1",
],
"image_models": [],
"multimodal_models": [],
"reranking_models": [],
"speech2text_models": [],
"tts_models": [],
},

View File

@ -74,7 +74,9 @@ def init_provider_platforms(provider_manager: ProviderManager)-> List[Dict]:
provider_dict["llm_models"] = []
provider_dict["embed_models"] = []
provider_dict["image_models"] = []
provider_dict["multimodal_models"] = []
provider_dict["reranking_models"] = []
provider_dict["speech2text_models"] = []
provider_dict["tts_models"] = []
supported_model_str_types = [model_type.to_origin_model_type() for model_type in
provider.supported_model_types]
@ -93,12 +95,16 @@ def init_provider_platforms(provider_manager: ProviderManager)-> List[Dict]:
if cur_model_type:
if model_type == "text-generation":
provider_dict["llm_models"] = cur_model_type
elif model_type == "text-embedding":
elif model_type == "embeddings":
provider_dict["embed_models"] = cur_model_type
elif model_type == "text2img":
provider_dict["image_models"] = cur_model_type
elif model_type == "multimodal":
provider_dict["multimodal_models"] = cur_model_type
elif model_type == "reranking":
provider_dict["reranking_models"] = cur_model_type
elif model_type == "speech2text":
provider_dict["speech2text_models"] = cur_model_type
elif model_type == "tts":
provider_dict["tts_models"] = cur_model_type
else:
logger.warning(f"Unsupported model type: {model_type}")

View File

@ -60,7 +60,7 @@ def get_config_platforms() -> Dict[str, Dict]:
def get_config_models(
model_name: str = None,
model_type: Literal["llm", "embed", "image", "multimodal"] = None,
model_type: Literal["llm", "embed", "image", "reranking","speech2text","tts"] = None,
platform_name: str = None,
) -> Dict[str, Dict]:
'''
@ -88,7 +88,14 @@ def get_config_models(
continue
if model_type is None:
model_types = ["llm_models", "embed_models", "image_models", "multimodal_models"]
model_types = [
"llm_models",
"embed_models",
"image_models",
"reranking_models",
"speech2text_models",
"tts_models",
]
else:
model_types = [f"{model_type}_models"]

View File

@ -5,6 +5,7 @@ from typing import List
import pytest
from omegaconf import OmegaConf
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import (
@ -47,3 +48,30 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
logger.info(f"predefined_models: {llm_models}")
def test_provider_wrapper_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件
cfg = OmegaConf.load(providers_file)
# 转换配置文件
(
provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
) = _to_custom_provide_configuration(cfg)
# 创建模型管理器
provider_manager = ProviderManager(
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
)
for model_type in ModelType.__members__.values():
models_by_model_type = ProvidersWrapper(
provider_manager=provider_manager
).get_models_by_model_type(model_type=model_type.to_origin_model_type())
print(f"{model_type.to_origin_model_type()}:{models_by_model_type}")

View File

@ -4,6 +4,7 @@ import logging
import pytest
from omegaconf import OmegaConf
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
@ -43,3 +44,30 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
logger.info(f"predefined_models: {llm_models}")
def test_provider_wrapper_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件
cfg = OmegaConf.load(providers_file)
# 转换配置文件
(
provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
) = _to_custom_provide_configuration(cfg)
# 创建模型管理器
provider_manager = ProviderManager(
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
)
for model_type in ModelType.__members__.values():
models_by_model_type = ProvidersWrapper(
provider_manager=provider_manager
).get_models_by_model_type(model_type=model_type.to_origin_model_type())
print(f"{model_type.to_origin_model_type()}:{models_by_model_type}")

View File

@ -4,6 +4,7 @@ import logging
import pytest
from omegaconf import OmegaConf
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
@ -43,3 +44,28 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
logger.info(f"predefined_models: {llm_models}")
def test_provider_wrapper_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件
cfg = OmegaConf.load(providers_file)
# 转换配置文件
(
provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
) = _to_custom_provide_configuration(cfg)
# 创建模型管理器
provider_manager = ProviderManager(
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
)
for model_type in ModelType.__members__.values():
models_by_model_type = ProvidersWrapper(
provider_manager=provider_manager
).get_models_by_model_type(model_type=model_type.to_origin_model_type())
print(f"{model_type.to_origin_model_type()}:{models_by_model_type}")

View File

@ -1,5 +1,3 @@
xinference:
model_credential:
- model: 'chatglm3-6b'
@ -7,3 +5,8 @@ xinference:
model_credentials:
server_url: 'http://127.0.0.1:9997/'
model_uid: 'chatglm3-6b'
- model: 'bge-m3'
model_type: 'embeddings'
model_credentials:
server_url: 'http://127.0.0.1:9997/'
model_uid: 'bge-m3'

View File

@ -4,6 +4,7 @@ import logging
import pytest
from omegaconf import OmegaConf
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
@ -43,3 +44,28 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
logger.info(f"predefined_models: {llm_models}")
def test_provider_wrapper_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件
cfg = OmegaConf.load(providers_file)
# 转换配置文件
(
provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
) = _to_custom_provide_configuration(cfg)
# 创建模型管理器
provider_manager = ProviderManager(
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
)
for model_type in ModelType.__members__.values():
models_by_model_type = ProvidersWrapper(
provider_manager=provider_manager
).get_models_by_model_type(model_type=model_type.to_origin_model_type())
print(f"{model_type.to_origin_model_type()}:{models_by_model_type}")

View File

@ -4,6 +4,7 @@ import logging
import pytest
from omegaconf import OmegaConf
from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
@ -43,3 +44,28 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non
llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models())
logger.info(f"predefined_models: {llm_models}")
def test_provider_wrapper_models(logging_conf: dict, providers_file: str) -> None:
logging.config.dictConfig(logging_conf) # type: ignore
# 读取配置文件
cfg = OmegaConf.load(providers_file)
# 转换配置文件
(
provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict,
) = _to_custom_provide_configuration(cfg)
# 创建模型管理器
provider_manager = ProviderManager(
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
)
for model_type in ModelType.__members__.values():
models_by_model_type = ProvidersWrapper(
provider_manager=provider_manager
).get_models_by_model_type(model_type=model_type.to_origin_model_type())
print(f"{model_type.to_origin_model_type()}:{models_by_model_type}")