mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-27 01:03:35 +08:00
49 lines
1.7 KiB
Python
49 lines
1.7 KiB
Python
import asyncio
|
|
import logging
|
|
|
|
import pytest
|
|
from omegaconf import OmegaConf
|
|
|
|
from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
|
|
from model_providers.core.model_manager import ModelManager
|
|
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
|
from model_providers.core.provider_manager import ProviderManager
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
|
|
logging.config.dictConfig(logging_conf) # type: ignore
|
|
# 读取配置文件
|
|
cfg = OmegaConf.load(
|
|
providers_file
|
|
)
|
|
# 转换配置文件
|
|
(
|
|
provider_name_to_provider_records_dict,
|
|
provider_name_to_provider_model_records_dict,
|
|
) = _to_custom_provide_configuration(cfg)
|
|
# 创建模型管理器
|
|
provider_manager = ProviderManager(
|
|
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
|
|
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
|
|
)
|
|
|
|
provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
|
|
provider="xinference", model_type=ModelType.LLM
|
|
)
|
|
llm_models: List[AIModelEntity] = []
|
|
for model in provider_model_bundle_llm.configuration.custom_configuration.models:
|
|
|
|
llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema(
|
|
model=model.model,
|
|
credentials=model.credentials,
|
|
))
|
|
|
|
# 获取预定义模型
|
|
llm_models.extend(
|
|
provider_model_bundle_llm.model_type_instance.predefined_models()
|
|
)
|
|
|
|
logger.info(f"predefined_models: {llm_models}")
|