diff --git a/model-providers/model_providers/__init__.py b/model-providers/model_providers/__init__.py index ebbdf71e..1be0f328 100644 --- a/model-providers/model_providers/__init__.py +++ b/model-providers/model_providers/__init__.py @@ -1,9 +1,11 @@ -from model_providers.bootstrap_web.openai_bootstrap_web import RESTFulOpenAIBootstrapBaseWeb +from omegaconf import DictConfig, OmegaConf + +from model_providers.bootstrap_web.openai_bootstrap_web import ( + RESTFulOpenAIBootstrapBaseWeb, +) from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb from model_providers.core.model_manager import ModelManager -from omegaconf import OmegaConf, DictConfig - def _to_custom_provide_configuration(cfg: DictConfig): """ @@ -34,8 +36,8 @@ def _to_custom_provide_configuration(cfg: DictConfig): provider_name_to_provider_model_records_dict = {} for key, item in cfg.items(): - model_credential = item.get('model_credential') - provider_credential = item.get('provider_credential') + model_credential = item.get("model_credential") + provider_credential = item.get("provider_credential") # 转换omegaconf对象为基本属性 if model_credential: model_credential = OmegaConf.to_container(model_credential) @@ -44,13 +46,17 @@ def _to_custom_provide_configuration(cfg: DictConfig): provider_credential = OmegaConf.to_container(provider_credential) provider_name_to_provider_records_dict[key] = provider_credential - return provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict + return ( + provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict, + ) class BootstrapWebBuilder: """ 创建一个模型实例创建工具 """ + _model_providers_cfg_path: str _host: str _port: int @@ -68,20 +74,26 @@ class BootstrapWebBuilder: return self def build(self) -> OpenAIBootstrapBaseWeb: - assert self._model_providers_cfg_path is not None and self._host is not None and self._port is not None + assert ( + self._model_providers_cfg_path is not None + and self._host is not None + and self._port is not None + ) # 读取配置文件 cfg = OmegaConf.load(self._model_providers_cfg_path) # 转换配置文件 - provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration( - cfg) + ( + provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict, + ) = _to_custom_provide_configuration(cfg) # 创建模型管理器 - provider_manager = ModelManager(provider_name_to_provider_records_dict, - provider_name_to_provider_model_records_dict) + provider_manager = ModelManager( + provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict, + ) # 创建web服务 - restful = RESTFulOpenAIBootstrapBaseWeb.from_config(cfg={ - "host": self._host, - "port": self._port - - }) + restful = RESTFulOpenAIBootstrapBaseWeb.from_config( + cfg={"host": self._host, "port": self._port} + ) restful.provider_manager = provider_manager return restful diff --git a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py index e47ed678..6e692cee 100644 --- a/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py +++ b/model-providers/model_providers/bootstrap_web/openai_bootstrap_web.py @@ -21,13 +21,17 @@ from model_providers.core.bootstrap.openai_protocol import ( EmbeddingsRequest, EmbeddingsResponse, FunctionAvailable, - ModelList, ModelCard, + ModelCard, + ModelList, ) -from model_providers.core.model_manager import ModelManager, ModelInstance +from model_providers.core.model_manager import ModelInstance, ModelManager from model_providers.core.model_runtime.entities.message_entities import ( UserPromptMessage, ) -from model_providers.core.model_runtime.entities.model_entities import ModelType, AIModelEntity +from model_providers.core.model_runtime.entities.model_entities import ( + AIModelEntity, + ModelType, +) from model_providers.core.utils.generic import dictify, jsonify logger = logging.getLogger(__name__) @@ -115,24 +119,31 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): llm_models: list[AIModelEntity] = [] for model_type in ModelType.__members__.values(): try: - provider_model_bundle = self._provider_manager.provider_manager.get_provider_model_bundle( - provider=provider, model_type=model_type + provider_model_bundle = ( + self._provider_manager.provider_manager.get_provider_model_bundle( + provider=provider, model_type=model_type + ) + ) + llm_models.extend( + provider_model_bundle.model_type_instance.predefined_models() ) - llm_models.extend(provider_model_bundle.model_type_instance.predefined_models()) except Exception as e: - logger.error(f"Error while fetching models for provider: {provider}, model_type: {model_type}") + logger.error( + f"Error while fetching models for provider: {provider}, model_type: {model_type}" + ) logger.error(e) # models list[AIModelEntity]转换称List[ModelCard] - models_list = [ModelCard(id=model.model, object=model.model_type.to_origin_model_type()) for model in llm_models] + models_list = [ + ModelCard(id=model.model, object=model.model_type.to_origin_model_type()) + for model in llm_models + ] - return ModelList( - data=models_list - ) + return ModelList(data=models_list) async def create_embeddings( - self, provider: str, request: Request, embeddings_request: EmbeddingsRequest + self, provider: str, request: Request, embeddings_request: EmbeddingsRequest ): logger.info( f"Received create_embeddings request: {pprint.pformat(embeddings_request.dict())}" @@ -142,7 +153,7 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): return EmbeddingsResponse(**dictify(response)) async def create_chat_completion( - self, provider: str, request: Request, chat_request: ChatCompletionRequest + self, provider: str, request: Request, chat_request: ChatCompletionRequest ): logger.info( f"Received chat completion request: {pprint.pformat(chat_request.dict())}" @@ -180,9 +191,9 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb): def run( - cfg: Dict, - logging_conf: Optional[dict] = None, - started_event: mp.Event = None, + cfg: Dict, + logging_conf: Optional[dict] = None, + started_event: mp.Event = None, ): logging.config.dictConfig(logging_conf) # type: ignore try: diff --git a/model-providers/model_providers/core/bootstrap/base.py b/model-providers/model_providers/core/bootstrap/base.py index 795eb565..8572323e 100644 --- a/model-providers/model_providers/core/bootstrap/base.py +++ b/model-providers/model_providers/core/bootstrap/base.py @@ -3,7 +3,10 @@ from collections import deque from fastapi import Request -from model_providers.core.bootstrap.openai_protocol import ChatCompletionRequest, EmbeddingsRequest +from model_providers.core.bootstrap.openai_protocol import ( + ChatCompletionRequest, + EmbeddingsRequest, +) from model_providers.core.model_manager import ModelManager @@ -60,12 +63,12 @@ class OpenAIBootstrapBaseWeb(Bootstrap): @abstractmethod async def create_embeddings( - self, provider: str, request: Request, embeddings_request: EmbeddingsRequest + self, provider: str, request: Request, embeddings_request: EmbeddingsRequest ): pass @abstractmethod async def create_chat_completion( - self, provider: str, request: Request, chat_request: ChatCompletionRequest + self, provider: str, request: Request, chat_request: ChatCompletionRequest ): pass diff --git a/model-providers/model_providers/core/bootstrap/openai_protocol.py b/model-providers/model_providers/core/bootstrap/openai_protocol.py index 558dc95b..2753ad5d 100644 --- a/model-providers/model_providers/core/bootstrap/openai_protocol.py +++ b/model-providers/model_providers/core/bootstrap/openai_protocol.py @@ -22,7 +22,15 @@ class Finish(str, Enum): class ModelCard(BaseModel): id: str - object: Literal["text-generation","embeddings","reranking", "speech2text", "moderation", "tts", "text2img"] = "llm" + object: Literal[ + "text-generation", + "embeddings", + "reranking", + "speech2text", + "moderation", + "tts", + "text2img", + ] = "llm" created: int = Field(default_factory=lambda: int(time.time())) owned_by: Literal["owner"] = "owner" diff --git a/model-providers/model_providers/core/model_runtime/entities/model_entities.py b/model-providers/model_providers/core/model_runtime/entities/model_entities.py index ae5938a4..0fea8c1d 100644 --- a/model-providers/model_providers/core/model_runtime/entities/model_entities.py +++ b/model-providers/model_providers/core/model_runtime/entities/model_entities.py @@ -1,6 +1,6 @@ from decimal import Decimal from enum import Enum -from typing import Any, Optional, List +from typing import Any, List, Optional from pydantic import BaseModel @@ -74,7 +74,6 @@ class ModelType(Enum): raise ValueError(f"invalid model type {self}") - class FetchFrom(Enum): """ Enum class for fetch from. diff --git a/model-providers/model_providers/core/provider_manager.py b/model-providers/model_providers/core/provider_manager.py index 5cb5ef43..5a99e314 100644 --- a/model-providers/model_providers/core/provider_manager.py +++ b/model-providers/model_providers/core/provider_manager.py @@ -253,9 +253,11 @@ class ProviderManager: custom_model_configurations.append( CustomModelConfiguration( - model=provider_model_record.get('model'), - model_type=ModelType.value_of(provider_model_record.get('model_type')), - credentials=provider_model_credentials + model=provider_model_record.get("model"), + model_type=ModelType.value_of( + provider_model_record.get("model_type") + ), + credentials=provider_model_credentials, ) ) diff --git a/model-providers/model_providers/core/utils/utils.py b/model-providers/model_providers/core/utils/utils.py index 2bc3fde8..dbfd7c6b 100644 --- a/model-providers/model_providers/core/utils/utils.py +++ b/model-providers/model_providers/core/utils/utils.py @@ -1,7 +1,6 @@ import logging - -import time import os +import time logger = logging.getLogger(__name__) @@ -26,7 +25,7 @@ def get_log_file(log_path: str, sub_dir: str): def get_config_dict( - log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int + log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int ) -> dict: # for windows, the path should be a raw string. log_file_path = ( diff --git a/model-providers/tests/server_unit_test/conftest.py b/model-providers/tests/server_unit_test/conftest.py index dc1af083..eea02a65 100644 --- a/model-providers/tests/server_unit_test/conftest.py +++ b/model-providers/tests/server_unit_test/conftest.py @@ -1,11 +1,16 @@ """Configuration for unit tests.""" -from importlib import util -from typing import Dict, Sequence, List import logging +from importlib import util +from typing import Dict, List, Sequence + import pytest from pytest import Config, Function, Parser -from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file +from model_providers.core.utils.utils import ( + get_config_dict, + get_log_file, + get_timestamp_ms, +) def pytest_addoption(parser: Parser) -> None: diff --git a/model-providers/tests/server_unit_test/test_init_server.py b/model-providers/tests/server_unit_test/test_init_server.py index d4145c77..96210b89 100644 --- a/model-providers/tests/server_unit_test/test_init_server.py +++ b/model-providers/tests/server_unit_test/test_init_server.py @@ -1,21 +1,26 @@ -from model_providers import BootstrapWebBuilder -import logging import asyncio +import logging import pytest + +from model_providers import BootstrapWebBuilder + logger = logging.getLogger(__name__) @pytest.mark.requires("fastapi") def test_init_server(logging_conf: dict) -> None: try: - boot = BootstrapWebBuilder() \ + boot = ( + BootstrapWebBuilder() .model_providers_cfg_path( - model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers" - "/model_providers.yaml") \ - .host(host="127.0.0.1") \ - .port(port=20000) \ + model_providers_cfg_path="/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers" + "/model_providers.yaml" + ) + .host(host="127.0.0.1") + .port(port=20000) .build() + ) boot.set_app_event(started_event=None) boot.serve(logging_conf=logging_conf) diff --git a/model-providers/tests/unit_test/conftest.py b/model-providers/tests/unit_test/conftest.py index dc1af083..eea02a65 100644 --- a/model-providers/tests/unit_test/conftest.py +++ b/model-providers/tests/unit_test/conftest.py @@ -1,11 +1,16 @@ """Configuration for unit tests.""" -from importlib import util -from typing import Dict, Sequence, List import logging +from importlib import util +from typing import Dict, List, Sequence + import pytest from pytest import Config, Function, Parser -from model_providers.core.utils.utils import get_config_dict, get_timestamp_ms, get_log_file +from model_providers.core.utils.utils import ( + get_config_dict, + get_log_file, + get_timestamp_ms, +) def pytest_addoption(parser: Parser) -> None: diff --git a/model-providers/tests/unit_test/test_provider_manager_models.py b/model-providers/tests/unit_test/test_provider_manager_models.py index c7afd8dc..023c48ec 100644 --- a/model-providers/tests/unit_test/test_provider_manager_models.py +++ b/model-providers/tests/unit_test/test_provider_manager_models.py @@ -1,11 +1,10 @@ +import asyncio +import logging + +import pytest from omegaconf import OmegaConf from model_providers import BootstrapWebBuilder, _to_custom_provide_configuration -import logging -import asyncio - -import pytest - from model_providers.core.model_manager import ModelManager from model_providers.core.model_runtime.entities.model_entities import ModelType from model_providers.core.provider_manager import ProviderManager @@ -14,14 +13,17 @@ logger = logging.getLogger(__name__) def test_provider_manager_models(logging_conf: dict) -> None: - logging.config.dictConfig(logging_conf) # type: ignore # 读取配置文件 - cfg = OmegaConf.load("/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers" - "/model_providers.yaml") + cfg = OmegaConf.load( + "/media/gpt4-pdf-chatbot-langchain/langchain-ChatGLM/model-providers" + "/model_providers.yaml" + ) # 转换配置文件 - provider_name_to_provider_records_dict, provider_name_to_provider_model_records_dict = _to_custom_provide_configuration( - cfg) + ( + provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict, + ) = _to_custom_provide_configuration(cfg) # 创建模型管理器 provider_manager = ProviderManager( provider_name_to_provider_records_dict=provider_name_to_provider_records_dict, @@ -34,6 +36,8 @@ def test_provider_manager_models(logging_conf: dict) -> None: provider_model_bundle_emb = provider_manager.get_provider_model_bundle( provider="openai", model_type=ModelType.TEXT_EMBEDDING ) - predefined_models = provider_model_bundle_emb.model_type_instance.predefined_models() + predefined_models = ( + provider_model_bundle_emb.model_type_instance.predefined_models() + ) logger.info(f"predefined_models: {predefined_models}")