provider_configuration.py

查询所有的平台信息,包含计费策略和配置schema_validators(参数必填信息校验规则)
/workspaces/current/model-providers
查询平台模型分类的详细默认信息,包含了模型类型,模型参数,模型状态
workspaces/current/models/model-types/{model_type}
This commit is contained in:
glide-the 2024-04-01 20:09:12 +08:00
parent bfcf2775f5
commit a1fe8d714f
8 changed files with 474 additions and 26 deletions

View File

@ -0,0 +1,183 @@
from enum import Enum
from typing import List, Literal, Optional
from pydantic import BaseModel
from model_providers.core.entities.model_entities import (
ModelStatus,
ModelWithProviderEntity,
)
from model_providers.core.entities.provider_entities import (
ProviderQuotaType,
ProviderType,
QuotaConfiguration,
)
from model_providers.core.model_runtime.entities.common_entities import I18nObject
from model_providers.core.model_runtime.entities.model_entities import (
ModelType,
ProviderModel,
)
from model_providers.core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
ModelCredentialSchema,
ProviderCredentialSchema,
ProviderHelpEntity,
SimpleProviderEntity,
)
class CustomConfigurationStatus(Enum):
"""
Enum class for custom configuration status.
"""
ACTIVE = "active"
NO_CONFIGURE = "no-configure"
class CustomConfigurationResponse(BaseModel):
"""
Model class for provider custom configuration response.
"""
status: CustomConfigurationStatus
class SystemConfigurationResponse(BaseModel):
"""
Model class for provider system configuration response.
"""
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
quota_configurations: list[QuotaConfiguration] = []
class ProviderResponse(BaseModel):
"""
Model class for provider response.
"""
provider: str
label: I18nObject
description: Optional[I18nObject] = None
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None
supported_model_types: list[ModelType]
configurate_methods: list[ConfigurateMethod]
provider_credential_schema: Optional[ProviderCredentialSchema] = None
model_credential_schema: Optional[ModelCredentialSchema] = None
preferred_provider_type: ProviderType
custom_configuration: CustomConfigurationResponse
system_configuration: SystemConfigurationResponse
def __init__(self, **data) -> None:
super().__init__(**data)
#
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
# if self.icon_small is not None:
# self.icon_small = I18nObject(
# en_US=f"{url_prefix}/icon_small/en_US",
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
# )
#
# if self.icon_large is not None:
# self.icon_large = I18nObject(
# en_US=f"{url_prefix}/icon_large/en_US",
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
# )
class ProviderListResponse(BaseModel):
object: Literal["list"] = "list"
data: List[ProviderResponse] = []
class ModelResponse(ProviderModel):
"""
Model class for model response.
"""
status: ModelStatus
class ProviderWithModelsResponse(BaseModel):
"""
Model class for provider with models response.
"""
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
status: CustomConfigurationStatus
models: list[ModelResponse]
def __init__(self, **data) -> None:
super().__init__(**data)
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
# if self.icon_small is not None:
# self.icon_small = I18nObject(
# en_US=f"{url_prefix}/icon_small/en_US",
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
# )
#
# if self.icon_large is not None:
# self.icon_large = I18nObject(
# en_US=f"{url_prefix}/icon_large/en_US",
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
# )
class ProviderModelTypeResponse(BaseModel):
object: Literal["list"] = "list"
data: List[ProviderWithModelsResponse] = []
class SimpleProviderEntityResponse(SimpleProviderEntity):
"""
Simple provider entity response.
"""
def __init__(self, **data) -> None:
super().__init__(**data)
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
# if self.icon_small is not None:
# self.icon_small = I18nObject(
# en_US=f"{url_prefix}/icon_small/en_US",
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
# )
#
# if self.icon_large is not None:
# self.icon_large = I18nObject(
# en_US=f"{url_prefix}/icon_large/en_US",
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
# )
class DefaultModelResponse(BaseModel):
"""
Default model entity.
"""
model: str
model_type: ModelType
provider: SimpleProviderEntityResponse
class ModelWithProviderEntityResponse(ModelWithProviderEntity):
"""
Model with provider entity.
"""
provider: SimpleProviderEntityResponse
def __init__(self, model: ModelWithProviderEntity) -> None:
super().__init__(**model.dict())

View File

@ -24,6 +24,10 @@ from sse_starlette import EventSourceResponse
from uvicorn import Config, Server from uvicorn import Config, Server
from model_providers.bootstrap_web.common import create_stream_chunk from model_providers.bootstrap_web.common import create_stream_chunk
from model_providers.bootstrap_web.entities.model_provider_entities import (
ProviderListResponse,
ProviderModelTypeResponse,
)
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
from model_providers.core.bootstrap.openai_protocol import ( from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionMessage, ChatCompletionMessage,
@ -301,6 +305,18 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
allow_headers=["*"], allow_headers=["*"],
) )
self._router.add_api_route(
"/workspaces/current/model-providers",
self.workspaces_model_providers,
response_model=ProviderListResponse,
methods=["GET"],
)
self._router.add_api_route(
"/workspaces/current/models/model-types/{model_type}",
self.workspaces_model_types,
response_model=ProviderModelTypeResponse,
methods=["GET"],
)
self._router.add_api_route( self._router.add_api_route(
"/{provider}/v1/models", "/{provider}/v1/models",
self.list_models, self.list_models,
@ -345,6 +361,14 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
if started_event is not None: if started_event is not None:
started_event.set() started_event.set()
async def workspaces_model_providers(self, request: Request):
provider_list = self.get_provider_list(model_type=request.get("model_type"))
return ProviderListResponse(data=provider_list)
async def workspaces_model_types(self, model_type: str, request: Request):
models_by_model_type = self.get_models_by_model_type(model_type=model_type)
return ProviderModelTypeResponse(data=models_by_model_type)
async def list_models(self, provider: str, request: Request): async def list_models(self, provider: str, request: Request):
logger.info(f"Received list_models request for provider: {provider}") logger.info(f"Received list_models request for provider: {provider}")
# 返回ModelType所有的枚举 # 返回ModelType所有的枚举

View File

@ -1,13 +1,25 @@
from abc import abstractmethod from abc import abstractmethod
from collections import deque from collections import deque
from typing import List, Optional
from fastapi import Request from fastapi import Request
from model_providers.bootstrap_web.entities.model_provider_entities import (
CustomConfigurationResponse,
CustomConfigurationStatus,
ModelResponse,
ProviderResponse,
ProviderWithModelsResponse,
SystemConfigurationResponse,
)
from model_providers.core.bootstrap.openai_protocol import ( from model_providers.core.bootstrap.openai_protocol import (
ChatCompletionRequest, ChatCompletionRequest,
EmbeddingsRequest, EmbeddingsRequest,
) )
from model_providers.core.entities.model_entities import ModelStatus
from model_providers.core.entities.provider_entities import ProviderType
from model_providers.core.model_manager import ModelManager from model_providers.core.model_manager import ModelManager
from model_providers.core.model_runtime.entities.model_entities import ModelType
class Bootstrap: class Bootstrap:
@ -18,9 +30,150 @@ class Bootstrap:
"""任务队列""" """任务队列"""
_QUEUE: deque = deque() _QUEUE: deque = deque()
_provider_manager: ModelManager
def __init__(self): def __init__(self):
self._version = "v0.0.1" self._version = "v0.0.1"
@property
def provider_manager(self) -> ModelManager:
return self._provider_manager
@provider_manager.setter
def provider_manager(self, provider_manager: ModelManager):
self._provider_manager = provider_manager
def get_provider_list(
self, model_type: Optional[str] = None
) -> List[ProviderResponse]:
"""
get provider list.
:param model_type: model type
:return:
"""
# 合并两个字典的键
provider = set(
self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
)
provider.update(
self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
)
# Get all provider configurations of the current workspace
provider_configurations = (
self.provider_manager.provider_manager.get_configurations(provider=provider)
)
provider_responses = []
for provider_configuration in provider_configurations.values():
if model_type:
model_type_entity = ModelType.value_of(model_type)
if (
model_type_entity
not in provider_configuration.provider.supported_model_types
):
continue
provider_response = ProviderResponse(
provider=provider_configuration.provider.provider,
label=provider_configuration.provider.label,
description=provider_configuration.provider.description,
icon_small=provider_configuration.provider.icon_small,
icon_large=provider_configuration.provider.icon_large,
background=provider_configuration.provider.background,
help=provider_configuration.provider.help,
supported_model_types=provider_configuration.provider.supported_model_types,
configurate_methods=provider_configuration.provider.configurate_methods,
provider_credential_schema=provider_configuration.provider.provider_credential_schema,
model_credential_schema=provider_configuration.provider.model_credential_schema,
preferred_provider_type=ProviderType.value_of("custom"),
custom_configuration=CustomConfigurationResponse(
status=CustomConfigurationStatus.ACTIVE
if provider_configuration.is_custom_configuration_available()
else CustomConfigurationStatus.NO_CONFIGURE
),
system_configuration=SystemConfigurationResponse(enabled=False),
)
provider_responses.append(provider_response)
return provider_responses
def get_models_by_model_type(
self, model_type: str
) -> List[ProviderWithModelsResponse]:
"""
get models by model type.
:param model_type: model type
:return:
"""
# 合并两个字典的键
provider = set(
self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
)
provider.update(
self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
)
# Get all provider configurations of the current workspace
provider_configurations = (
self.provider_manager.provider_manager.get_configurations(provider=provider)
)
# Get provider available models
models = provider_configurations.get_models(
model_type=ModelType.value_of(model_type)
)
# Group models by provider
provider_models = {}
for model in models:
if model.provider.provider not in provider_models:
provider_models[model.provider.provider] = []
if model.deprecated:
continue
provider_models[model.provider.provider].append(model)
# convert to ProviderWithModelsResponse list
providers_with_models: list[ProviderWithModelsResponse] = []
for provider, models in provider_models.items():
if not models:
continue
first_model = models[0]
has_active_models = any(
[model.status == ModelStatus.ACTIVE for model in models]
)
providers_with_models.append(
ProviderWithModelsResponse(
provider=provider,
label=first_model.provider.label,
icon_small=first_model.provider.icon_small,
icon_large=first_model.provider.icon_large,
status=CustomConfigurationStatus.ACTIVE
if has_active_models
else CustomConfigurationStatus.NO_CONFIGURE,
models=[
ModelResponse(
model=model.model,
label=model.label,
model_type=model.model_type,
features=model.features,
fetch_from=model.fetch_from,
model_properties=model.model_properties,
status=model.status,
)
for model in models
],
)
)
return providers_with_models
@classmethod @classmethod
@abstractmethod @abstractmethod
def from_config(cls, cfg=None): def from_config(cls, cfg=None):
@ -44,19 +197,9 @@ class Bootstrap:
class OpenAIBootstrapBaseWeb(Bootstrap): class OpenAIBootstrapBaseWeb(Bootstrap):
_provider_manager: ModelManager
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@property
def provider_manager(self) -> ModelManager:
return self._provider_manager
@provider_manager.setter
def provider_manager(self, provider_manager: ModelManager):
self._provider_manager = provider_manager
@abstractmethod @abstractmethod
async def list_models(self, provider: str, request: Request): async def list_models(self, provider: str, request: Request):
pass pass

View File

@ -66,6 +66,16 @@ class ProviderConfiguration(BaseModel):
else: else:
return None return None
def is_custom_configuration_available(self) -> bool:
"""
Check custom configuration available.
:return:
"""
return (
self.custom_configuration.provider is not None
or len(self.custom_configuration.models) > 0
)
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
""" """
Get custom credentials. Get custom credentials.

View File

@ -6,12 +6,82 @@ from pydantic import BaseModel
from model_providers.core.model_runtime.entities.model_entities import ModelType from model_providers.core.model_runtime.entities.model_entities import ModelType
class ProviderType(Enum):
CUSTOM = "custom"
SYSTEM = "system"
@staticmethod
def value_of(value):
for member in ProviderType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class ProviderQuotaType(Enum):
PAID = "paid"
"""hosted paid quota"""
FREE = "free"
"""third-party free quota"""
TRIAL = "trial"
"""hosted trial quota"""
@staticmethod
def value_of(value):
for member in ProviderQuotaType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class QuotaUnit(Enum):
TIMES = "times"
TOKENS = "tokens"
CREDITS = "credits"
class SystemConfigurationStatus(Enum):
"""
Enum class for system configuration status.
"""
ACTIVE = "active"
QUOTA_EXCEEDED = "quota-exceeded"
UNSUPPORTED = "unsupported"
class RestrictModel(BaseModel): class RestrictModel(BaseModel):
model: str model: str
base_model_name: Optional[str] = None base_model_name: Optional[str] = None
model_type: ModelType model_type: ModelType
class QuotaConfiguration(BaseModel):
"""
Model class for provider quota configuration.
"""
quota_type: ProviderQuotaType
quota_unit: QuotaUnit
quota_limit: int
quota_used: int
is_valid: bool
restrict_models: list[RestrictModel] = []
class SystemConfiguration(BaseModel):
"""
Model class for provider system configuration.
"""
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
quota_configurations: list[QuotaConfiguration] = []
credentials: Optional[dict] = None
class CustomProviderConfiguration(BaseModel): class CustomProviderConfiguration(BaseModel):
""" """
Model class for provider custom configuration. Model class for provider custom configuration.

View File

@ -1,7 +1,7 @@
import importlib import importlib
import logging import logging
import os import os
from typing import Optional from typing import Optional, Union
from pydantic import BaseModel from pydantic import BaseModel
@ -49,7 +49,9 @@ class ModelProviderFactory:
if init_cache: if init_cache:
self.get_providers() self.get_providers()
def get_providers(self, provider_name: str = "") -> list[ProviderEntity]: def get_providers(
self, provider_name: Union[str, set] = ""
) -> list[ProviderEntity]:
""" """
Get all providers Get all providers
:return: list of providers :return: list of providers
@ -60,20 +62,36 @@ class ModelProviderFactory:
# traverse all model_provider_extensions # traverse all model_provider_extensions
providers = [] providers = []
for name, model_provider_extension in model_provider_extensions.items(): for name, model_provider_extension in model_provider_extensions.items():
if provider_name in (name, ""): if isinstance(provider_name, str):
# get model_provider instance if provider_name in (name, ""):
model_provider_instance = model_provider_extension.provider_instance # get model_provider instance
model_provider_instance = model_provider_extension.provider_instance
# get provider schema # get provider schema
provider_schema = model_provider_instance.get_provider_schema() provider_schema = model_provider_instance.get_provider_schema()
for model_type in provider_schema.supported_model_types: for model_type in provider_schema.supported_model_types:
# get predefined models for given model type # get predefined models for given model type
models = model_provider_instance.models(model_type) models = model_provider_instance.models(model_type)
if models: if models:
provider_schema.models.extend(models) provider_schema.models.extend(models)
providers.append(provider_schema) providers.append(provider_schema)
elif isinstance(provider_name, set):
if name in provider_name:
# get model_provider instance
model_provider_instance = model_provider_extension.provider_instance
# get provider schema
provider_schema = model_provider_instance.get_provider_schema()
for model_type in provider_schema.supported_model_types:
# get predefined models for given model type
models = model_provider_instance.models(model_type)
if models:
provider_schema.models.extend(models)
providers.append(provider_schema)
# return providers # return providers
return providers return providers

View File

@ -1,7 +1,7 @@
import json import json
from collections import defaultdict from collections import defaultdict
from json import JSONDecodeError from json import JSONDecodeError
from typing import Optional from typing import Optional, Union
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@ -45,7 +45,7 @@ class ProviderManager:
provider_name_to_provider_model_records_dict provider_name_to_provider_model_records_dict
) )
def get_configurations(self, provider: str) -> ProviderConfigurations: def get_configurations(self, provider: Union[str, set]) -> ProviderConfigurations:
""" """
Get model provider configurations. Get model provider configurations.
@ -155,7 +155,7 @@ class ProviderManager:
default_model = {} default_model = {}
# Get provider configurations # Get provider configurations
provider_configurations = self.get_configurations() provider_configurations = self.get_configurations(provider="openai")
# get available models from provider_configurations # get available models from provider_configurations
available_models = provider_configurations.get_models( available_models = provider_configurations.get_models(