diff --git a/chatchat-server/chatchat/configs/_model_config.py b/chatchat-server/chatchat/configs/_model_config.py index 0ed1950b..856981e9 100644 --- a/chatchat-server/chatchat/configs/_model_config.py +++ b/chatchat-server/chatchat/configs/_model_config.py @@ -69,10 +69,7 @@ LLM_MODEL_CONFIG = { "sd-turbo": { "size": "256*256", } - }, - "multimodal_model": { - "qwen-vl": {} - }, + } } # 可以通过 model_providers 提供转换不同平台的接口为openai endpoint的能力,启动后下面变量会自动增加相应的平台 @@ -116,7 +113,9 @@ MODEL_PLATFORMS = [ "Embedding-V1", ], "image_models": [], - "multimodal_models": [], + "reranking_models": [], + "speech2text_models": [], + "tts_models": [], }, diff --git a/chatchat-server/chatchat/model_loaders/init_server.py b/chatchat-server/chatchat/model_loaders/init_server.py index 78e50965..2441b8de 100644 --- a/chatchat-server/chatchat/model_loaders/init_server.py +++ b/chatchat-server/chatchat/model_loaders/init_server.py @@ -74,7 +74,9 @@ def init_provider_platforms(provider_manager: ProviderManager)-> List[Dict]: provider_dict["llm_models"] = [] provider_dict["embed_models"] = [] provider_dict["image_models"] = [] - provider_dict["multimodal_models"] = [] + provider_dict["reranking_models"] = [] + provider_dict["speech2text_models"] = [] + provider_dict["tts_models"] = [] supported_model_str_types = [model_type.to_origin_model_type() for model_type in provider.supported_model_types] @@ -93,12 +95,16 @@ def init_provider_platforms(provider_manager: ProviderManager)-> List[Dict]: if cur_model_type: if model_type == "text-generation": provider_dict["llm_models"] = cur_model_type - elif model_type == "text-embedding": + elif model_type == "embeddings": provider_dict["embed_models"] = cur_model_type elif model_type == "text2img": provider_dict["image_models"] = cur_model_type - elif model_type == "multimodal": - provider_dict["multimodal_models"] = cur_model_type + elif model_type == "reranking": + provider_dict["reranking_models"] = cur_model_type + elif model_type == "speech2text": + provider_dict["speech2text_models"] = cur_model_type + elif model_type == "tts": + provider_dict["tts_models"] = cur_model_type else: logger.warning(f"Unsupported model type: {model_type}") diff --git a/chatchat-server/chatchat/server/utils.py b/chatchat-server/chatchat/server/utils.py index ac441157..784e94d6 100644 --- a/chatchat-server/chatchat/server/utils.py +++ b/chatchat-server/chatchat/server/utils.py @@ -60,7 +60,7 @@ def get_config_platforms() -> Dict[str, Dict]: def get_config_models( model_name: str = None, - model_type: Literal["llm", "embed", "image", "multimodal"] = None, + model_type: Literal["llm", "embed", "image", "reranking","speech2text","tts"] = None, platform_name: str = None, ) -> Dict[str, Dict]: ''' @@ -88,7 +88,14 @@ def get_config_models( continue if model_type is None: - model_types = ["llm_models", "embed_models", "image_models", "multimodal_models"] + model_types = [ + "llm_models", + "embed_models", + "image_models", + "reranking_models", + "speech2text_models", + "tts_models", + ] else: model_types = [f"{model_type}_models"] diff --git a/model-providers/tests/unit_tests/deepseek/test_deepseek_provider_manager_models.py b/model-providers/tests/unit_tests/deepseek/test_deepseek_provider_manager_models.py index 77569b4b..3d39aeb1 100644 --- a/model-providers/tests/unit_tests/deepseek/test_deepseek_provider_manager_models.py +++ b/model-providers/tests/unit_tests/deepseek/test_deepseek_provider_manager_models.py @@ -5,6 +5,7 @@ from typing import List import pytest from omegaconf import OmegaConf +from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration from model_providers.core.model_manager import ModelManager from model_providers.core.model_runtime.entities.model_entities import ( @@ -47,3 +48,30 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models()) logger.info(f"predefined_models: {llm_models}") + + + +def test_provider_wrapper_models(logging_conf: dict, providers_file: str) -> None: + logging.config.dictConfig(logging_conf) # type: ignore + # 读取配置文件 + cfg = OmegaConf.load(providers_file) + # 转换配置文件 + ( + provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict, + ) = _to_custom_provide_configuration(cfg) + # 创建模型管理器 + provider_manager = ProviderManager( + provider_name_to_provider_records_dict=provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict, + ) + + + for model_type in ModelType.__members__.values(): + models_by_model_type = ProvidersWrapper( + provider_manager=provider_manager + ).get_models_by_model_type(model_type=model_type.to_origin_model_type()) + + print(f"{model_type.to_origin_model_type()}:{models_by_model_type}") + + diff --git a/model-providers/tests/unit_tests/ollama/test_ollama_provider_manager_models.py b/model-providers/tests/unit_tests/ollama/test_ollama_provider_manager_models.py index eaaa14fe..d78dddd1 100644 --- a/model-providers/tests/unit_tests/ollama/test_ollama_provider_manager_models.py +++ b/model-providers/tests/unit_tests/ollama/test_ollama_provider_manager_models.py @@ -4,6 +4,7 @@ import logging import pytest from omegaconf import OmegaConf +from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration from model_providers.core.model_manager import ModelManager from model_providers.core.model_runtime.entities.model_entities import ModelType @@ -43,3 +44,30 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models()) logger.info(f"predefined_models: {llm_models}") + + + +def test_provider_wrapper_models(logging_conf: dict, providers_file: str) -> None: + logging.config.dictConfig(logging_conf) # type: ignore + # 读取配置文件 + cfg = OmegaConf.load(providers_file) + # 转换配置文件 + ( + provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict, + ) = _to_custom_provide_configuration(cfg) + # 创建模型管理器 + provider_manager = ProviderManager( + provider_name_to_provider_records_dict=provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict, + ) + + + for model_type in ModelType.__members__.values(): + models_by_model_type = ProvidersWrapper( + provider_manager=provider_manager + ).get_models_by_model_type(model_type=model_type.to_origin_model_type()) + + print(f"{model_type.to_origin_model_type()}:{models_by_model_type}") + + diff --git a/model-providers/tests/unit_tests/openai/test_openai_provider_manager_models.py b/model-providers/tests/unit_tests/openai/test_openai_provider_manager_models.py index 21707821..8e7b3c0f 100644 --- a/model-providers/tests/unit_tests/openai/test_openai_provider_manager_models.py +++ b/model-providers/tests/unit_tests/openai/test_openai_provider_manager_models.py @@ -4,6 +4,7 @@ import logging import pytest from omegaconf import OmegaConf +from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration from model_providers.core.model_manager import ModelManager from model_providers.core.model_runtime.entities.model_entities import ModelType @@ -43,3 +44,28 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models()) logger.info(f"predefined_models: {llm_models}") + + +def test_provider_wrapper_models(logging_conf: dict, providers_file: str) -> None: + logging.config.dictConfig(logging_conf) # type: ignore + # 读取配置文件 + cfg = OmegaConf.load(providers_file) + # 转换配置文件 + ( + provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict, + ) = _to_custom_provide_configuration(cfg) + # 创建模型管理器 + provider_manager = ProviderManager( + provider_name_to_provider_records_dict=provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict, + ) + + + for model_type in ModelType.__members__.values(): + models_by_model_type = ProvidersWrapper( + provider_manager=provider_manager + ).get_models_by_model_type(model_type=model_type.to_origin_model_type()) + + print(f"{model_type.to_origin_model_type()}:{models_by_model_type}") + diff --git a/model-providers/tests/unit_tests/xinference/model_providers.yaml b/model-providers/tests/unit_tests/xinference/model_providers.yaml index 95209dd4..03eb750a 100644 --- a/model-providers/tests/unit_tests/xinference/model_providers.yaml +++ b/model-providers/tests/unit_tests/xinference/model_providers.yaml @@ -1,5 +1,3 @@ - - xinference: model_credential: - model: 'chatglm3-6b' @@ -7,3 +5,8 @@ xinference: model_credentials: server_url: 'http://127.0.0.1:9997/' model_uid: 'chatglm3-6b' + - model: 'bge-m3' + model_type: 'embeddings' + model_credentials: + server_url: 'http://127.0.0.1:9997/' + model_uid: 'bge-m3' \ No newline at end of file diff --git a/model-providers/tests/unit_tests/xinference/test_xinference_provider_manager_models.py b/model-providers/tests/unit_tests/xinference/test_xinference_provider_manager_models.py index 88b98ed4..173ecf51 100644 --- a/model-providers/tests/unit_tests/xinference/test_xinference_provider_manager_models.py +++ b/model-providers/tests/unit_tests/xinference/test_xinference_provider_manager_models.py @@ -4,6 +4,7 @@ import logging import pytest from omegaconf import OmegaConf +from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration from model_providers.core.model_manager import ModelManager from model_providers.core.model_runtime.entities.model_entities import ModelType @@ -43,3 +44,28 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models()) logger.info(f"predefined_models: {llm_models}") + + +def test_provider_wrapper_models(logging_conf: dict, providers_file: str) -> None: + logging.config.dictConfig(logging_conf) # type: ignore + # 读取配置文件 + cfg = OmegaConf.load(providers_file) + # 转换配置文件 + ( + provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict, + ) = _to_custom_provide_configuration(cfg) + # 创建模型管理器 + provider_manager = ProviderManager( + provider_name_to_provider_records_dict=provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict, + ) + + + for model_type in ModelType.__members__.values(): + models_by_model_type = ProvidersWrapper( + provider_manager=provider_manager + ).get_models_by_model_type(model_type=model_type.to_origin_model_type()) + + print(f"{model_type.to_origin_model_type()}:{models_by_model_type}") + diff --git a/model-providers/tests/unit_tests/zhipuai/test_zhipuai_provider_manager_models.py b/model-providers/tests/unit_tests/zhipuai/test_zhipuai_provider_manager_models.py index 38999168..b2311ccf 100644 --- a/model-providers/tests/unit_tests/zhipuai/test_zhipuai_provider_manager_models.py +++ b/model-providers/tests/unit_tests/zhipuai/test_zhipuai_provider_manager_models.py @@ -4,6 +4,7 @@ import logging import pytest from omegaconf import OmegaConf +from model_providers.core.bootstrap.providers_wapper import ProvidersWrapper from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration from model_providers.core.model_manager import ModelManager from model_providers.core.model_runtime.entities.model_entities import ModelType @@ -43,3 +44,28 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non llm_models.extend(provider_model_bundle_llm.model_type_instance.predefined_models()) logger.info(f"predefined_models: {llm_models}") + + +def test_provider_wrapper_models(logging_conf: dict, providers_file: str) -> None: + logging.config.dictConfig(logging_conf) # type: ignore + # 读取配置文件 + cfg = OmegaConf.load(providers_file) + # 转换配置文件 + ( + provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict, + ) = _to_custom_provide_configuration(cfg) + # 创建模型管理器 + provider_manager = ProviderManager( + provider_name_to_provider_records_dict=provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict, + ) + + + for model_type in ModelType.__members__.values(): + models_by_model_type = ProvidersWrapper( + provider_manager=provider_manager + ).get_models_by_model_type(model_type=model_type.to_origin_model_type()) + + print(f"{model_type.to_origin_model_type()}:{models_by_model_type}") +