diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py index a0661ec6..593fe394 100644 --- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py +++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py @@ -185,9 +185,17 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): provider=provider, model_type=model_type ) ) + # 获取预定义模型 llm_models.extend( provider_model_bundle.model_type_instance.predefined_models() ) + # 获取自定义模型 + for model in provider_model_bundle.configuration.custom_configuration.models: + + llm_models.append(provider_model_bundle.model_type_instance.get_model_schema( + model=model.model, + credentials=model.credentials, + )) except Exception as e: logger.error( f"Error while fetching models for provider: {provider}, model_type: {model_type}" diff --git a/model-providers/tests/integration_tests/xinference_providers_test/model_providers.yaml b/model-providers/tests/integration_tests/xinference_providers_test/model_providers.yaml new file mode 100644 index 00000000..fc26fa5f --- /dev/null +++ b/model-providers/tests/integration_tests/xinference_providers_test/model_providers.yaml @@ -0,0 +1,15 @@ + + +xinference: + model_credential: + - model: 'chatglm3-6b' + model_type: 'llm' + model_credentials: + server_url: 'http://127.0.0.1:9997/' + model_uid: 'chatglm3-6b' + + - model: 'chatglm31-6b' + model_type: 'llm' + model_credentials: + server_url: 'http://127.0.0.1:9997/' + model_uid: 'chatglm3-6b' diff --git a/model-providers/tests/integration_tests/xinference_providers_test/test_service.py b/model-providers/tests/integration_tests/xinference_providers_test/test_service.py new file mode 100644 index 00000000..831406ff --- /dev/null +++ b/model-providers/tests/integration_tests/xinference_providers_test/test_service.py @@ -0,0 +1,37 @@ +from langchain.chains import LLMChain +from langchain_core.prompts import 
@pytest.mark.requires("xinference_client")
def test_llm(init_server: str) -> None:
    """Smoke-test a chat completion through the OpenAI-compatible xinference endpoint.

    Args:
        init_server: base URL of the bootstrap web server fixture.
    """
    llm = ChatOpenAI(
        model_name="glm-4",
        openai_api_key="YOUR_API_KEY",
        openai_api_base=f"{init_server}/xinference/v1",
    )
    template = """Question: {question}

    Answer: Let's think step by step."""

    prompt = PromptTemplate.from_template(template)

    llm_chain = LLMChain(prompt=prompt, llm=llm)
    responses = llm_chain.run("你好")
    logger.info("\033[1;32m" + f"llm_chain: {responses}" + "\033[0m")


# fix: marker was "xinference-client" (hyphen) — pytest.mark.requires takes the
# importable module name, so it must match test_llm's "xinference_client".
@pytest.mark.requires("xinference_client")
def test_embedding(init_server: str) -> None:
    """Smoke-test an embedding call through the OpenAI-compatible xinference endpoint.

    Args:
        init_server: base URL of the bootstrap web server fixture.
    """
    embeddings = OpenAIEmbeddings(
        model="text_embedding",
        openai_api_key="YOUR_API_KEY",
        openai_api_base=f"{init_server}/xinference/v1",
    )

    text = "你好"

    query_result = embeddings.embed_query(text)

    logger.info("\033[1;32m" + f"embeddings: {query_result}" + "\033[0m")
def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
    """List deepseek LLM models: custom-configured models first, then predefined ones.

    Args:
        logging_conf: dictConfig-style logging configuration (fixture).
        providers_file: path to the model_providers.yaml to load (fixture).
    """
    logging.config.dictConfig(logging_conf)  # type: ignore
    # Load provider configuration
    cfg = OmegaConf.load(
        providers_file
    )
    # Convert configuration into provider / model credential records
    (
        provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict,
    ) = _to_custom_provide_configuration(cfg)
    # Build the provider manager
    provider_manager = ProviderManager(
        provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
    )

    provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
        provider="deepseek", model_type=ModelType.LLM
    )
    # Custom (user-configured) models resolved from their credentials
    llm_models: List[AIModelEntity] = []
    for model in provider_model_bundle_llm.configuration.custom_configuration.models:
        llm_models.append(
            provider_model_bundle_llm.model_type_instance.get_model_schema(
                model=model.model,
                credentials=model.credentials,
            )
        )

    # Predefined models shipped with the provider
    llm_models.extend(
        provider_model_bundle_llm.model_type_instance.predefined_models()
    )

    # fix: label said "predefined_models" but the list also contains custom models
    logger.info(f"llm_models: {llm_models}")
llm_models: List[AIModelEntity] = [] + for model in provider_model_bundle_llm.configuration.custom_configuration.models: + + llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema( + model=model.model, + credentials=model.credentials, + )) + + # 获取预定义模型 + llm_models.extend( provider_model_bundle_llm.model_type_instance.predefined_models() ) - logger.info(f"predefined_models: {predefined_models}") + logger.info(f"predefined_models: {llm_models}") + diff --git a/model-providers/tests/unit_tests/openai/model_providers.yaml b/model-providers/tests/unit_tests/openai/model_providers.yaml index 908883c7..4b7234d5 100644 --- a/model-providers/tests/unit_tests/openai/model_providers.yaml +++ b/model-providers/tests/unit_tests/openai/model_providers.yaml @@ -17,17 +17,3 @@ openai: openai_api_key: 'sk-' openai_organization: '' openai_api_base: '' - -xinference: - model_credential: - - model: 'chatglm3-6b' - model_type: 'llm' - model_credentials: - server_url: 'http://127.0.0.1:9997/' - model_uid: 'chatglm3-6b' - - -zhipuai: - - provider_credential: - api_key: 'd4fa0690b6dfa205204cae2e12aa6fb6.1' \ No newline at end of file diff --git a/model-providers/tests/unit_tests/openai/test_provider_manager_models.py b/model-providers/tests/unit_tests/openai/test_provider_manager_models.py index 9416fb9f..43c0a3d2 100644 --- a/model-providers/tests/unit_tests/openai/test_provider_manager_models.py +++ b/model-providers/tests/unit_tests/openai/test_provider_manager_models.py @@ -32,11 +32,17 @@ def test_provider_manager_models(logging_conf: dict, providers_file: str) -> Non provider_model_bundle_llm = provider_manager.get_provider_model_bundle( provider="openai", model_type=ModelType.LLM ) - provider_model_bundle_emb = provider_manager.get_provider_model_bundle( - provider="openai", model_type=ModelType.TEXT_EMBEDDING - ) - predefined_models = ( - provider_model_bundle_emb.model_type_instance.predefined_models() + llm_models: List[AIModelEntity] = [] + for model 
in provider_model_bundle_llm.configuration.custom_configuration.models: + + llm_models.append(provider_model_bundle_llm.model_type_instance.get_model_schema( + model=model.model, + credentials=model.credentials, + )) + + # 获取预定义模型 + llm_models.extend( + provider_model_bundle_llm.model_type_instance.predefined_models() ) - logger.info(f"predefined_models: {predefined_models}") + logger.info(f"predefined_models: {llm_models}") diff --git a/model-providers/tests/unit_tests/xinference/model_providers.yaml b/model-providers/tests/unit_tests/xinference/model_providers.yaml new file mode 100644 index 00000000..95209dd4 --- /dev/null +++ b/model-providers/tests/unit_tests/xinference/model_providers.yaml @@ -0,0 +1,9 @@ + + +xinference: + model_credential: + - model: 'chatglm3-6b' + model_type: 'llm' + model_credentials: + server_url: 'http://127.0.0.1:9997/' + model_uid: 'chatglm3-6b' diff --git a/model-providers/tests/unit_tests/xinference/test_provider_manager_models.py b/model-providers/tests/unit_tests/xinference/test_provider_manager_models.py new file mode 100644 index 00000000..a8c404ef --- /dev/null +++ b/model-providers/tests/unit_tests/xinference/test_provider_manager_models.py @@ -0,0 +1,48 @@ +import asyncio +import logging + +import pytest +from omegaconf import OmegaConf + +from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration +from model_providers.core.model_manager import ModelManager +from model_providers.core.model_runtime.entities.model_entities import ModelType +from model_providers.core.provider_manager import ProviderManager + +logger = logging.getLogger(__name__) + + +def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None: + logging.config.dictConfig(logging_conf) # type: ignore + # 读取配置文件 + cfg = OmegaConf.load( + providers_file + ) + # 转换配置文件 + ( + provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict, + ) = _to_custom_provide_configuration(cfg) + # 创建模型管理器 + 
import asyncio
import logging
# fix: logging.config is a submodule and is not guaranteed to be loaded by
# "import logging"; dictConfig below needs an explicit import.
import logging.config
# fix: List and AIModelEntity were referenced below but never imported -> NameError.
from typing import List

import pytest
from omegaconf import OmegaConf

from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import (
    AIModelEntity,
    ModelType,
)
from model_providers.core.provider_manager import ProviderManager

logger = logging.getLogger(__name__)


def test_provider_manager_models(logging_conf: dict, providers_file: str) -> None:
    """List zhipuai LLM models: custom-configured models first, then predefined ones.

    Args:
        logging_conf: dictConfig-style logging configuration (fixture).
        providers_file: path to the model_providers.yaml to load (fixture).
    """
    logging.config.dictConfig(logging_conf)  # type: ignore
    # Load provider configuration
    cfg = OmegaConf.load(
        providers_file
    )
    # Convert configuration into provider / model credential records
    (
        provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict,
    ) = _to_custom_provide_configuration(cfg)
    # Build the provider manager
    provider_manager = ProviderManager(
        provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
        provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict,
    )

    provider_model_bundle_llm = provider_manager.get_provider_model_bundle(
        provider="zhipuai", model_type=ModelType.LLM
    )
    # Custom (user-configured) models resolved from their credentials
    llm_models: List[AIModelEntity] = []
    for model in provider_model_bundle_llm.configuration.custom_configuration.models:
        llm_models.append(
            provider_model_bundle_llm.model_type_instance.get_model_schema(
                model=model.model,
                credentials=model.credentials,
            )
        )

    # Predefined models shipped with the provider
    llm_models.extend(
        provider_model_bundle_llm.model_type_instance.predefined_models()
    )

    # fix: label said "predefined_models" but the list also contains custom models
    logger.info(f"llm_models: {llm_models}")