provider_configuration.py

查询所有的平台信息,包含计费策略和配置schema_validators(参数必填信息校验规则)
/workspaces/current/model-providers
查询平台模型分类的详细默认信息,包含了模型类型,模型参数,模型状态
workspaces/current/models/model-types/{model_type}
This commit is contained in:
glide-the 2024-04-01 20:09:12 +08:00
parent bfcf2775f5
commit a1fe8d714f
8 changed files with 474 additions and 26 deletions

View File

@ -0,0 +1,183 @@
from enum import Enum
from typing import List, Literal, Optional
from pydantic import BaseModel
from model_providers.core.entities.model_entities import (
ModelStatus,
ModelWithProviderEntity,
)
from model_providers.core.entities.provider_entities import (
ProviderQuotaType,
ProviderType,
QuotaConfiguration,
)
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import (
ModelType,
ProviderModel,
)
from model_providers.core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
ModelCredentialSchema,
ProviderCredentialSchema,
ProviderHelpEntity,
SimpleProviderEntity,
)
class CustomConfigurationStatus(Enum):
"""
Enum class for custom configuration status.
"""
ACTIVE = "active"
NO_CONFIGURE = "no-configure"
class CustomConfigurationResponse(BaseModel):
"""
Model class for provider custom configuration response.
"""
status: CustomConfigurationStatus
class SystemConfigurationResponse(BaseModel):
"""
Model class for provider system configuration response.
"""
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
quota_configurations: list[QuotaConfiguration] = []
class ProviderResponse(BaseModel):
"""
Model class for provider response.
"""
provider: str
label: I18nObject
description: Optional[I18nObject] = None
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None
supported_model_types: list[ModelType]
configurate_methods: list[ConfigurateMethod]
provider_credential_schema: Optional[ProviderCredentialSchema] = None
model_credential_schema: Optional[ModelCredentialSchema] = None
preferred_provider_type: ProviderType
custom_configuration: CustomConfigurationResponse
system_configuration: SystemConfigurationResponse
def __init__(self, **data) -> None:
super().__init__(**data)
#
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
# if self.icon_small is not None:
# self.icon_small = I18nObject(
# en_US=f"{url_prefix}/icon_small/en_US",
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
# )
#
# if self.icon_large is not None:
# self.icon_large = I18nObject(
# en_US=f"{url_prefix}/icon_large/en_US",
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
# )
class ProviderListResponse(BaseModel):
object: Literal["list"] = "list"
data: List[ProviderResponse] = []
class ModelResponse(ProviderModel):
"""
Model class for model response.
"""
status: ModelStatus
class ProviderWithModelsResponse(BaseModel):
"""
Model class for provider with models response.
"""
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
status: CustomConfigurationStatus
models: list[ModelResponse]
def __init__(self, **data) -> None:
super().__init__(**data)
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
# if self.icon_small is not None:
# self.icon_small = I18nObject(
# en_US=f"{url_prefix}/icon_small/en_US",
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
# )
#
# if self.icon_large is not None:
# self.icon_large = I18nObject(
# en_US=f"{url_prefix}/icon_large/en_US",
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
# )
class ProviderModelTypeResponse(BaseModel):
object: Literal["list"] = "list"
data: List[ProviderWithModelsResponse] = []
class SimpleProviderEntityResponse(SimpleProviderEntity):
"""
Simple provider entity response.
"""
def __init__(self, **data) -> None:
super().__init__(**data)
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
# if self.icon_small is not None:
# self.icon_small = I18nObject(
# en_US=f"{url_prefix}/icon_small/en_US",
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
# )
#
# if self.icon_large is not None:
# self.icon_large = I18nObject(
# en_US=f"{url_prefix}/icon_large/en_US",
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
# )
class DefaultModelResponse(BaseModel):
"""
Default model entity.
"""
model: str
model_type: ModelType
provider: SimpleProviderEntityResponse
class ModelWithProviderEntityResponse(ModelWithProviderEntity):
"""
Model with provider entity.
"""
provider: SimpleProviderEntityResponse
def __init__(self, model: ModelWithProviderEntity) -> None:
super().__init__(**model.dict())

View File

@ -24,6 +24,10 @@ from sse_starlette import EventSourceResponse
from uvicorn import Config, Server
from model_providers.bootstrap_web.common import create_stream_chunk
from model_providers.bootstrap_web.entities.model_provider_entities import (
ProviderListResponse,
ProviderModelTypeResponse,
)
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionMessage,
@ -301,6 +305,18 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
allow_headers=["*"],
)
self._router.add_api_route(
"/workspaces/current/model-providers",
self.workspaces_model_providers,
response_model=ProviderListResponse,
methods=["GET"],
)
self._router.add_api_route(
"/workspaces/current/models/model-types/{model_type}",
self.workspaces_model_types,
response_model=ProviderModelTypeResponse,
methods=["GET"],
)
self._router.add_api_route(
"/{provider}/v1/models",
self.list_models,
@ -345,6 +361,14 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
if started_event is not None:
started_event.set()
async def workspaces_model_providers(self, request: Request):
provider_list = self.get_provider_list(model_type=request.get("model_type"))
return ProviderListResponse(data=provider_list)
async def workspaces_model_types(self, model_type: str, request: Request):
models_by_model_type = self.get_models_by_model_type(model_type=model_type)
return ProviderModelTypeResponse(data=models_by_model_type)
async def list_models(self, provider: str, request: Request):
logger.info(f"Received list_models request for provider: {provider}")
# 返回ModelType所有的枚举

View File

@ -1,13 +1,25 @@
from abc import abstractmethod
from collections import deque
from typing import List, Optional
from fastapi import Request
from model_providers.bootstrap_web.entities.model_provider_entities import (
CustomConfigurationResponse,
CustomConfigurationStatus,
ModelResponse,
ProviderResponse,
ProviderWithModelsResponse,
SystemConfigurationResponse,
)
from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionRequest,
EmbeddingsRequest,
)
from model_providers.core.entities.model_entities import ModelStatus
from model_providers.core.entities.provider_entities import ProviderType
from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
class Bootstrap:
@ -18,9 +30,150 @@ class Bootstrap:
"""任务队列"""
_QUEUE: deque = deque()
_provider_manager: ModelManager
def __init__(self):
self._version = "v0.0.1"
@property
def provider_manager(self) -> ModelManager:
return self._provider_manager
@provider_manager.setter
def provider_manager(self, provider_manager: ModelManager):
self._provider_manager = provider_manager
def get_provider_list(
self, model_type: Optional[str] = None
) -> List[ProviderResponse]:
"""
get provider list.
:param model_type: model type
:return:
"""
# 合并两个字典的键
provider = set(
self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
)
provider.update(
self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
)
# Get all provider configurations of the current workspace
provider_configurations = (
self.provider_manager.provider_manager.get_configurations(provider=provider)
)
provider_responses = []
for provider_configuration in provider_configurations.values():
if model_type:
model_type_entity = ModelType.value_of(model_type)
if (
model_type_entity
not in provider_configuration.provider.supported_model_types
):
continue
provider_response = ProviderResponse(
provider=provider_configuration.provider.provider,
label=provider_configuration.provider.label,
description=provider_configuration.provider.description,
icon_small=provider_configuration.provider.icon_small,
icon_large=provider_configuration.provider.icon_large,
background=provider_configuration.provider.background,
help=provider_configuration.provider.help,
supported_model_types=provider_configuration.provider.supported_model_types,
configurate_methods=provider_configuration.provider.configurate_methods,
provider_credential_schema=provider_configuration.provider.provider_credential_schema,
model_credential_schema=provider_configuration.provider.model_credential_schema,
preferred_provider_type=ProviderType.value_of("custom"),
custom_configuration=CustomConfigurationResponse(
status=CustomConfigurationStatus.ACTIVE
if provider_configuration.is_custom_configuration_available()
else CustomConfigurationStatus.NO_CONFIGURE
),
system_configuration=SystemConfigurationResponse(enabled=False),
)
provider_responses.append(provider_response)
return provider_responses
def get_models_by_model_type(
self, model_type: str
) -> List[ProviderWithModelsResponse]:
"""
get models by model type.
:param model_type: model type
:return:
"""
# 合并两个字典的键
provider = set(
self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
)
provider.update(
self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
)
# Get all provider configurations of the current workspace
provider_configurations = (
self.provider_manager.provider_manager.get_configurations(provider=provider)
)
# Get provider available models
models = provider_configurations.get_models(
model_type=ModelType.value_of(model_type)
)
# Group models by provider
provider_models = {}
for model in models:
if model.provider.provider not in provider_models:
provider_models[model.provider.provider] = []
if model.deprecated:
continue
provider_models[model.provider.provider].append(model)
# convert to ProviderWithModelsResponse list
providers_with_models: list[ProviderWithModelsResponse] = []
for provider, models in provider_models.items():
if not models:
continue
first_model = models[0]
has_active_models = any(
[model.status == ModelStatus.ACTIVE for model in models]
)
providers_with_models.append(
ProviderWithModelsResponse(
provider=provider,
label=first_model.provider.label,
icon_small=first_model.provider.icon_small,
icon_large=first_model.provider.icon_large,
status=CustomConfigurationStatus.ACTIVE
if has_active_models
else CustomConfigurationStatus.NO_CONFIGURE,
models=[
ModelResponse(
model=model.model,
label=model.label,
model_type=model.model_type,
features=model.features,
fetch_from=model.fetch_from,
model_properties=model.model_properties,
status=model.status,
)
for model in models
],
)
)
return providers_with_models
@classmethod
@abstractmethod
def from_config(cls, cfg=None):
@ -44,19 +197,9 @@ class Bootstrap:
class OpenAIBootstrapBaseWeb(Bootstrap):
_provider_manager: ModelManager
def __init__(self):
super().__init__()
@property
def provider_manager(self) -> ModelManager:
return self._provider_manager
@provider_manager.setter
def provider_manager(self, provider_manager: ModelManager):
self._provider_manager = provider_manager
@abstractmethod
async def list_models(self, provider: str, request: Request):
pass

View File

@ -66,6 +66,16 @@ class ProviderConfiguration(BaseModel):
else:
return None
def is_custom_configuration_available(self) -> bool:
"""
Check custom configuration available.
:return:
"""
return (
self.custom_configuration.provider is not None
or len(self.custom_configuration.models) > 0
)
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
Get custom credentials.

View File

@ -6,12 +6,82 @@ from pydantic import BaseModel
from model_providers.core.model_runtime.entities.model_entities import ModelType
class ProviderType(Enum):
CUSTOM = "custom"
SYSTEM = "system"
@staticmethod
def value_of(value):
for member in ProviderType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class ProviderQuotaType(Enum):
PAID = "paid"
"""hosted paid quota"""
FREE = "free"
"""third-party free quota"""
TRIAL = "trial"
"""hosted trial quota"""
@staticmethod
def value_of(value):
for member in ProviderQuotaType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class QuotaUnit(Enum):
TIMES = "times"
TOKENS = "tokens"
CREDITS = "credits"
class SystemConfigurationStatus(Enum):
"""
Enum class for system configuration status.
"""
ACTIVE = "active"
QUOTA_EXCEEDED = "quota-exceeded"
UNSUPPORTED = "unsupported"
class RestrictModel(BaseModel):
model: str
base_model_name: Optional[str] = None
model_type: ModelType
class QuotaConfiguration(BaseModel):
"""
Model class for provider quota configuration.
"""
quota_type: ProviderQuotaType
quota_unit: QuotaUnit
quota_limit: int
quota_used: int
is_valid: bool
restrict_models: list[RestrictModel] = []
class SystemConfiguration(BaseModel):
"""
Model class for provider system configuration.
"""
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
quota_configurations: list[QuotaConfiguration] = []
credentials: Optional[dict] = None
class CustomProviderConfiguration(BaseModel):
"""
Model class for provider custom configuration.

View File

@ -1,7 +1,7 @@
import importlib
import logging
import os
from typing import Optional
from typing import Optional, Union
from pydantic import BaseModel
@ -49,7 +49,9 @@ class ModelProviderFactory:
if init_cache:
self.get_providers()
def get_providers(self, provider_name: str = "") -> list[ProviderEntity]:
def get_providers(
self, provider_name: Union[str, set] = ""
) -> list[ProviderEntity]:
"""
Get all providers
:return: list of providers
@ -60,20 +62,36 @@ class ModelProviderFactory:
# traverse all model_provider_extensions
providers = []
for name, model_provider_extension in model_provider_extensions.items():
if provider_name in (name, ""):
# get model_provider instance
model_provider_instance = model_provider_extension.provider_instance
if isinstance(provider_name, str):
if provider_name in (name, ""):
# get model_provider instance
model_provider_instance = model_provider_extension.provider_instance
# get provider schema
provider_schema = model_provider_instance.get_provider_schema()
# get provider schema
provider_schema = model_provider_instance.get_provider_schema()
for model_type in provider_schema.supported_model_types:
# get predefined models for given model type
models = model_provider_instance.models(model_type)
if models:
provider_schema.models.extend(models)
for model_type in provider_schema.supported_model_types:
# get predefined models for given model type
models = model_provider_instance.models(model_type)
if models:
provider_schema.models.extend(models)
providers.append(provider_schema)
providers.append(provider_schema)
elif isinstance(provider_name, set):
if name in provider_name:
# get model_provider instance
model_provider_instance = model_provider_extension.provider_instance
# get provider schema
provider_schema = model_provider_instance.get_provider_schema()
for model_type in provider_schema.supported_model_types:
# get predefined models for given model type
models = model_provider_instance.models(model_type)
if models:
provider_schema.models.extend(models)
providers.append(provider_schema)
# return providers
return providers

View File

@ -1,7 +1,7 @@
import json
from collections import defaultdict
from json import JSONDecodeError
from typing import Optional
from typing import Optional, Union
from sqlalchemy.exc import IntegrityError
@ -45,7 +45,7 @@ class ProviderManager:
provider_name_to_provider_model_records_dict
)
def get_configurations(self, provider: str) -> ProviderConfigurations:
def get_configurations(self, provider: Union[str, set]) -> ProviderConfigurations:
"""
Get model provider configurations.
@ -155,7 +155,7 @@ class ProviderManager:
default_model = {}
# Get provider configurations
provider_configurations = self.get_configurations()
provider_configurations = self.get_configurations(provider="openai")
# get available models from provider_configurations
available_models = provider_configurations.get_models(