mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-07 07:23:29 +08:00
provider_configuration.py
查询所有的平台信息,包含计费策略和配置schema_validators(参数必填信息校验规则)
/workspaces/current/model-providers
查询平台模型分类的详细默认信息,包含了模型类型,模型参数,模型状态
workspaces/current/models/model-types/{model_type}
This commit is contained in:
parent
bfcf2775f5
commit
a1fe8d714f
@ -0,0 +1,183 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import List, Literal, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from model_providers.core.entities.model_entities import (
|
||||||
|
ModelStatus,
|
||||||
|
ModelWithProviderEntity,
|
||||||
|
)
|
||||||
|
from model_providers.core.entities.provider_entities import (
|
||||||
|
ProviderQuotaType,
|
||||||
|
ProviderType,
|
||||||
|
QuotaConfiguration,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from model_providers.core.model_runtime.entities.model_entities import (
|
||||||
|
ModelType,
|
||||||
|
ProviderModel,
|
||||||
|
)
|
||||||
|
from model_providers.core.model_runtime.entities.provider_entities import (
|
||||||
|
ConfigurateMethod,
|
||||||
|
ModelCredentialSchema,
|
||||||
|
ProviderCredentialSchema,
|
||||||
|
ProviderHelpEntity,
|
||||||
|
SimpleProviderEntity,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomConfigurationStatus(Enum):
|
||||||
|
"""
|
||||||
|
Enum class for custom configuration status.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ACTIVE = "active"
|
||||||
|
NO_CONFIGURE = "no-configure"
|
||||||
|
|
||||||
|
|
||||||
|
class CustomConfigurationResponse(BaseModel):
|
||||||
|
"""
|
||||||
|
Model class for provider custom configuration response.
|
||||||
|
"""
|
||||||
|
|
||||||
|
status: CustomConfigurationStatus
|
||||||
|
|
||||||
|
|
||||||
|
class SystemConfigurationResponse(BaseModel):
|
||||||
|
"""
|
||||||
|
Model class for provider system configuration response.
|
||||||
|
"""
|
||||||
|
|
||||||
|
enabled: bool
|
||||||
|
current_quota_type: Optional[ProviderQuotaType] = None
|
||||||
|
quota_configurations: list[QuotaConfiguration] = []
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderResponse(BaseModel):
|
||||||
|
"""
|
||||||
|
Model class for provider response.
|
||||||
|
"""
|
||||||
|
|
||||||
|
provider: str
|
||||||
|
label: I18nObject
|
||||||
|
description: Optional[I18nObject] = None
|
||||||
|
icon_small: Optional[I18nObject] = None
|
||||||
|
icon_large: Optional[I18nObject] = None
|
||||||
|
background: Optional[str] = None
|
||||||
|
help: Optional[ProviderHelpEntity] = None
|
||||||
|
supported_model_types: list[ModelType]
|
||||||
|
configurate_methods: list[ConfigurateMethod]
|
||||||
|
provider_credential_schema: Optional[ProviderCredentialSchema] = None
|
||||||
|
model_credential_schema: Optional[ModelCredentialSchema] = None
|
||||||
|
preferred_provider_type: ProviderType
|
||||||
|
custom_configuration: CustomConfigurationResponse
|
||||||
|
system_configuration: SystemConfigurationResponse
|
||||||
|
|
||||||
|
def __init__(self, **data) -> None:
|
||||||
|
super().__init__(**data)
|
||||||
|
#
|
||||||
|
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||||
|
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
|
||||||
|
# if self.icon_small is not None:
|
||||||
|
# self.icon_small = I18nObject(
|
||||||
|
# en_US=f"{url_prefix}/icon_small/en_US",
|
||||||
|
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# if self.icon_large is not None:
|
||||||
|
# self.icon_large = I18nObject(
|
||||||
|
# en_US=f"{url_prefix}/icon_large/en_US",
|
||||||
|
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderListResponse(BaseModel):
|
||||||
|
object: Literal["list"] = "list"
|
||||||
|
data: List[ProviderResponse] = []
|
||||||
|
|
||||||
|
|
||||||
|
class ModelResponse(ProviderModel):
|
||||||
|
"""
|
||||||
|
Model class for model response.
|
||||||
|
"""
|
||||||
|
|
||||||
|
status: ModelStatus
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderWithModelsResponse(BaseModel):
|
||||||
|
"""
|
||||||
|
Model class for provider with models response.
|
||||||
|
"""
|
||||||
|
|
||||||
|
provider: str
|
||||||
|
label: I18nObject
|
||||||
|
icon_small: Optional[I18nObject] = None
|
||||||
|
icon_large: Optional[I18nObject] = None
|
||||||
|
status: CustomConfigurationStatus
|
||||||
|
models: list[ModelResponse]
|
||||||
|
|
||||||
|
def __init__(self, **data) -> None:
|
||||||
|
super().__init__(**data)
|
||||||
|
|
||||||
|
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||||
|
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
|
||||||
|
# if self.icon_small is not None:
|
||||||
|
# self.icon_small = I18nObject(
|
||||||
|
# en_US=f"{url_prefix}/icon_small/en_US",
|
||||||
|
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# if self.icon_large is not None:
|
||||||
|
# self.icon_large = I18nObject(
|
||||||
|
# en_US=f"{url_prefix}/icon_large/en_US",
|
||||||
|
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderModelTypeResponse(BaseModel):
|
||||||
|
object: Literal["list"] = "list"
|
||||||
|
data: List[ProviderWithModelsResponse] = []
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleProviderEntityResponse(SimpleProviderEntity):
|
||||||
|
"""
|
||||||
|
Simple provider entity response.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **data) -> None:
|
||||||
|
super().__init__(**data)
|
||||||
|
|
||||||
|
# url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||||
|
# + f"/console/api/workspaces/current/model-providers/{self.provider}")
|
||||||
|
# if self.icon_small is not None:
|
||||||
|
# self.icon_small = I18nObject(
|
||||||
|
# en_US=f"{url_prefix}/icon_small/en_US",
|
||||||
|
# zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# if self.icon_large is not None:
|
||||||
|
# self.icon_large = I18nObject(
|
||||||
|
# en_US=f"{url_prefix}/icon_large/en_US",
|
||||||
|
# zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultModelResponse(BaseModel):
|
||||||
|
"""
|
||||||
|
Default model entity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model: str
|
||||||
|
model_type: ModelType
|
||||||
|
provider: SimpleProviderEntityResponse
|
||||||
|
|
||||||
|
|
||||||
|
class ModelWithProviderEntityResponse(ModelWithProviderEntity):
|
||||||
|
"""
|
||||||
|
Model with provider entity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
provider: SimpleProviderEntityResponse
|
||||||
|
|
||||||
|
def __init__(self, model: ModelWithProviderEntity) -> None:
|
||||||
|
super().__init__(**model.dict())
|
||||||
@ -24,6 +24,10 @@ from sse_starlette import EventSourceResponse
|
|||||||
from uvicorn import Config, Server
|
from uvicorn import Config, Server
|
||||||
|
|
||||||
from model_providers.bootstrap_web.common import create_stream_chunk
|
from model_providers.bootstrap_web.common import create_stream_chunk
|
||||||
|
from model_providers.bootstrap_web.entities.model_provider_entities import (
|
||||||
|
ProviderListResponse,
|
||||||
|
ProviderModelTypeResponse,
|
||||||
|
)
|
||||||
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
|
from model_providers.core.bootstrap import OpenAIBootstrapBaseWeb
|
||||||
from model_providers.core.bootstrap.openai_protocol import (
|
from model_providers.core.bootstrap.openai_protocol import (
|
||||||
ChatCompletionMessage,
|
ChatCompletionMessage,
|
||||||
@ -301,6 +305,18 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._router.add_api_route(
|
||||||
|
"/workspaces/current/model-providers",
|
||||||
|
self.workspaces_model_providers,
|
||||||
|
response_model=ProviderListResponse,
|
||||||
|
methods=["GET"],
|
||||||
|
)
|
||||||
|
self._router.add_api_route(
|
||||||
|
"/workspaces/current/models/model-types/{model_type}",
|
||||||
|
self.workspaces_model_types,
|
||||||
|
response_model=ProviderModelTypeResponse,
|
||||||
|
methods=["GET"],
|
||||||
|
)
|
||||||
self._router.add_api_route(
|
self._router.add_api_route(
|
||||||
"/{provider}/v1/models",
|
"/{provider}/v1/models",
|
||||||
self.list_models,
|
self.list_models,
|
||||||
@ -345,6 +361,14 @@ class RESTFulOpenAIBootstrapBaseWeb(OpenAIBootstrapBaseWeb):
|
|||||||
if started_event is not None:
|
if started_event is not None:
|
||||||
started_event.set()
|
started_event.set()
|
||||||
|
|
||||||
|
async def workspaces_model_providers(self, request: Request):
|
||||||
|
provider_list = self.get_provider_list(model_type=request.get("model_type"))
|
||||||
|
return ProviderListResponse(data=provider_list)
|
||||||
|
|
||||||
|
async def workspaces_model_types(self, model_type: str, request: Request):
|
||||||
|
models_by_model_type = self.get_models_by_model_type(model_type=model_type)
|
||||||
|
return ProviderModelTypeResponse(data=models_by_model_type)
|
||||||
|
|
||||||
async def list_models(self, provider: str, request: Request):
|
async def list_models(self, provider: str, request: Request):
|
||||||
logger.info(f"Received list_models request for provider: {provider}")
|
logger.info(f"Received list_models request for provider: {provider}")
|
||||||
# 返回ModelType所有的枚举
|
# 返回ModelType所有的枚举
|
||||||
|
|||||||
@ -1,13 +1,25 @@
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
|
from model_providers.bootstrap_web.entities.model_provider_entities import (
|
||||||
|
CustomConfigurationResponse,
|
||||||
|
CustomConfigurationStatus,
|
||||||
|
ModelResponse,
|
||||||
|
ProviderResponse,
|
||||||
|
ProviderWithModelsResponse,
|
||||||
|
SystemConfigurationResponse,
|
||||||
|
)
|
||||||
from model_providers.core.bootstrap.openai_protocol import (
|
from model_providers.core.bootstrap.openai_protocol import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
EmbeddingsRequest,
|
EmbeddingsRequest,
|
||||||
)
|
)
|
||||||
|
from model_providers.core.entities.model_entities import ModelStatus
|
||||||
|
from model_providers.core.entities.provider_entities import ProviderType
|
||||||
from model_providers.core.model_manager import ModelManager
|
from model_providers.core.model_manager import ModelManager
|
||||||
|
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||||
|
|
||||||
|
|
||||||
class Bootstrap:
|
class Bootstrap:
|
||||||
@ -18,9 +30,150 @@ class Bootstrap:
|
|||||||
"""任务队列"""
|
"""任务队列"""
|
||||||
_QUEUE: deque = deque()
|
_QUEUE: deque = deque()
|
||||||
|
|
||||||
|
_provider_manager: ModelManager
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._version = "v0.0.1"
|
self._version = "v0.0.1"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_manager(self) -> ModelManager:
|
||||||
|
return self._provider_manager
|
||||||
|
|
||||||
|
@provider_manager.setter
|
||||||
|
def provider_manager(self, provider_manager: ModelManager):
|
||||||
|
self._provider_manager = provider_manager
|
||||||
|
|
||||||
|
def get_provider_list(
|
||||||
|
self, model_type: Optional[str] = None
|
||||||
|
) -> List[ProviderResponse]:
|
||||||
|
"""
|
||||||
|
get provider list.
|
||||||
|
|
||||||
|
:param model_type: model type
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# 合并两个字典的键
|
||||||
|
provider = set(
|
||||||
|
self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
|
||||||
|
)
|
||||||
|
provider.update(
|
||||||
|
self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
|
||||||
|
)
|
||||||
|
# Get all provider configurations of the current workspace
|
||||||
|
provider_configurations = (
|
||||||
|
self.provider_manager.provider_manager.get_configurations(provider=provider)
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_responses = []
|
||||||
|
for provider_configuration in provider_configurations.values():
|
||||||
|
if model_type:
|
||||||
|
model_type_entity = ModelType.value_of(model_type)
|
||||||
|
if (
|
||||||
|
model_type_entity
|
||||||
|
not in provider_configuration.provider.supported_model_types
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
provider_response = ProviderResponse(
|
||||||
|
provider=provider_configuration.provider.provider,
|
||||||
|
label=provider_configuration.provider.label,
|
||||||
|
description=provider_configuration.provider.description,
|
||||||
|
icon_small=provider_configuration.provider.icon_small,
|
||||||
|
icon_large=provider_configuration.provider.icon_large,
|
||||||
|
background=provider_configuration.provider.background,
|
||||||
|
help=provider_configuration.provider.help,
|
||||||
|
supported_model_types=provider_configuration.provider.supported_model_types,
|
||||||
|
configurate_methods=provider_configuration.provider.configurate_methods,
|
||||||
|
provider_credential_schema=provider_configuration.provider.provider_credential_schema,
|
||||||
|
model_credential_schema=provider_configuration.provider.model_credential_schema,
|
||||||
|
preferred_provider_type=ProviderType.value_of("custom"),
|
||||||
|
custom_configuration=CustomConfigurationResponse(
|
||||||
|
status=CustomConfigurationStatus.ACTIVE
|
||||||
|
if provider_configuration.is_custom_configuration_available()
|
||||||
|
else CustomConfigurationStatus.NO_CONFIGURE
|
||||||
|
),
|
||||||
|
system_configuration=SystemConfigurationResponse(enabled=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_responses.append(provider_response)
|
||||||
|
|
||||||
|
return provider_responses
|
||||||
|
|
||||||
|
def get_models_by_model_type(
|
||||||
|
self, model_type: str
|
||||||
|
) -> List[ProviderWithModelsResponse]:
|
||||||
|
"""
|
||||||
|
get models by model type.
|
||||||
|
|
||||||
|
:param model_type: model type
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# 合并两个字典的键
|
||||||
|
provider = set(
|
||||||
|
self.provider_manager.provider_manager.provider_name_to_provider_records_dict.keys()
|
||||||
|
)
|
||||||
|
provider.update(
|
||||||
|
self.provider_manager.provider_manager.provider_name_to_provider_model_records_dict.keys()
|
||||||
|
)
|
||||||
|
# Get all provider configurations of the current workspace
|
||||||
|
provider_configurations = (
|
||||||
|
self.provider_manager.provider_manager.get_configurations(provider=provider)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get provider available models
|
||||||
|
models = provider_configurations.get_models(
|
||||||
|
model_type=ModelType.value_of(model_type)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Group models by provider
|
||||||
|
provider_models = {}
|
||||||
|
for model in models:
|
||||||
|
if model.provider.provider not in provider_models:
|
||||||
|
provider_models[model.provider.provider] = []
|
||||||
|
|
||||||
|
if model.deprecated:
|
||||||
|
continue
|
||||||
|
|
||||||
|
provider_models[model.provider.provider].append(model)
|
||||||
|
|
||||||
|
# convert to ProviderWithModelsResponse list
|
||||||
|
providers_with_models: list[ProviderWithModelsResponse] = []
|
||||||
|
for provider, models in provider_models.items():
|
||||||
|
if not models:
|
||||||
|
continue
|
||||||
|
|
||||||
|
first_model = models[0]
|
||||||
|
|
||||||
|
has_active_models = any(
|
||||||
|
[model.status == ModelStatus.ACTIVE for model in models]
|
||||||
|
)
|
||||||
|
|
||||||
|
providers_with_models.append(
|
||||||
|
ProviderWithModelsResponse(
|
||||||
|
provider=provider,
|
||||||
|
label=first_model.provider.label,
|
||||||
|
icon_small=first_model.provider.icon_small,
|
||||||
|
icon_large=first_model.provider.icon_large,
|
||||||
|
status=CustomConfigurationStatus.ACTIVE
|
||||||
|
if has_active_models
|
||||||
|
else CustomConfigurationStatus.NO_CONFIGURE,
|
||||||
|
models=[
|
||||||
|
ModelResponse(
|
||||||
|
model=model.model,
|
||||||
|
label=model.label,
|
||||||
|
model_type=model.model_type,
|
||||||
|
features=model.features,
|
||||||
|
fetch_from=model.fetch_from,
|
||||||
|
model_properties=model.model_properties,
|
||||||
|
status=model.status,
|
||||||
|
)
|
||||||
|
for model in models
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return providers_with_models
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def from_config(cls, cfg=None):
|
def from_config(cls, cfg=None):
|
||||||
@ -44,19 +197,9 @@ class Bootstrap:
|
|||||||
|
|
||||||
|
|
||||||
class OpenAIBootstrapBaseWeb(Bootstrap):
|
class OpenAIBootstrapBaseWeb(Bootstrap):
|
||||||
_provider_manager: ModelManager
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@property
|
|
||||||
def provider_manager(self) -> ModelManager:
|
|
||||||
return self._provider_manager
|
|
||||||
|
|
||||||
@provider_manager.setter
|
|
||||||
def provider_manager(self, provider_manager: ModelManager):
|
|
||||||
self._provider_manager = provider_manager
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def list_models(self, provider: str, request: Request):
|
async def list_models(self, provider: str, request: Request):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -66,6 +66,16 @@ class ProviderConfiguration(BaseModel):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def is_custom_configuration_available(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check custom configuration available.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
self.custom_configuration.provider is not None
|
||||||
|
or len(self.custom_configuration.models) > 0
|
||||||
|
)
|
||||||
|
|
||||||
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
|
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
|
||||||
"""
|
"""
|
||||||
Get custom credentials.
|
Get custom credentials.
|
||||||
|
|||||||
@ -6,12 +6,82 @@ from pydantic import BaseModel
|
|||||||
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
from model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderType(Enum):
|
||||||
|
CUSTOM = "custom"
|
||||||
|
SYSTEM = "system"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def value_of(value):
|
||||||
|
for member in ProviderType:
|
||||||
|
if member.value == value:
|
||||||
|
return member
|
||||||
|
raise ValueError(f"No matching enum found for value '{value}'")
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderQuotaType(Enum):
|
||||||
|
PAID = "paid"
|
||||||
|
"""hosted paid quota"""
|
||||||
|
|
||||||
|
FREE = "free"
|
||||||
|
"""third-party free quota"""
|
||||||
|
|
||||||
|
TRIAL = "trial"
|
||||||
|
"""hosted trial quota"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def value_of(value):
|
||||||
|
for member in ProviderQuotaType:
|
||||||
|
if member.value == value:
|
||||||
|
return member
|
||||||
|
raise ValueError(f"No matching enum found for value '{value}'")
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaUnit(Enum):
|
||||||
|
TIMES = "times"
|
||||||
|
TOKENS = "tokens"
|
||||||
|
CREDITS = "credits"
|
||||||
|
|
||||||
|
|
||||||
|
class SystemConfigurationStatus(Enum):
|
||||||
|
"""
|
||||||
|
Enum class for system configuration status.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ACTIVE = "active"
|
||||||
|
QUOTA_EXCEEDED = "quota-exceeded"
|
||||||
|
UNSUPPORTED = "unsupported"
|
||||||
|
|
||||||
|
|
||||||
class RestrictModel(BaseModel):
|
class RestrictModel(BaseModel):
|
||||||
model: str
|
model: str
|
||||||
base_model_name: Optional[str] = None
|
base_model_name: Optional[str] = None
|
||||||
model_type: ModelType
|
model_type: ModelType
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaConfiguration(BaseModel):
|
||||||
|
"""
|
||||||
|
Model class for provider quota configuration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
quota_type: ProviderQuotaType
|
||||||
|
quota_unit: QuotaUnit
|
||||||
|
quota_limit: int
|
||||||
|
quota_used: int
|
||||||
|
is_valid: bool
|
||||||
|
restrict_models: list[RestrictModel] = []
|
||||||
|
|
||||||
|
|
||||||
|
class SystemConfiguration(BaseModel):
|
||||||
|
"""
|
||||||
|
Model class for provider system configuration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
enabled: bool
|
||||||
|
current_quota_type: Optional[ProviderQuotaType] = None
|
||||||
|
quota_configurations: list[QuotaConfiguration] = []
|
||||||
|
credentials: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
class CustomProviderConfiguration(BaseModel):
|
class CustomProviderConfiguration(BaseModel):
|
||||||
"""
|
"""
|
||||||
Model class for provider custom configuration.
|
Model class for provider custom configuration.
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -49,7 +49,9 @@ class ModelProviderFactory:
|
|||||||
if init_cache:
|
if init_cache:
|
||||||
self.get_providers()
|
self.get_providers()
|
||||||
|
|
||||||
def get_providers(self, provider_name: str = "") -> list[ProviderEntity]:
|
def get_providers(
|
||||||
|
self, provider_name: Union[str, set] = ""
|
||||||
|
) -> list[ProviderEntity]:
|
||||||
"""
|
"""
|
||||||
Get all providers
|
Get all providers
|
||||||
:return: list of providers
|
:return: list of providers
|
||||||
@ -60,20 +62,36 @@ class ModelProviderFactory:
|
|||||||
# traverse all model_provider_extensions
|
# traverse all model_provider_extensions
|
||||||
providers = []
|
providers = []
|
||||||
for name, model_provider_extension in model_provider_extensions.items():
|
for name, model_provider_extension in model_provider_extensions.items():
|
||||||
if provider_name in (name, ""):
|
if isinstance(provider_name, str):
|
||||||
# get model_provider instance
|
if provider_name in (name, ""):
|
||||||
model_provider_instance = model_provider_extension.provider_instance
|
# get model_provider instance
|
||||||
|
model_provider_instance = model_provider_extension.provider_instance
|
||||||
|
|
||||||
# get provider schema
|
# get provider schema
|
||||||
provider_schema = model_provider_instance.get_provider_schema()
|
provider_schema = model_provider_instance.get_provider_schema()
|
||||||
|
|
||||||
for model_type in provider_schema.supported_model_types:
|
for model_type in provider_schema.supported_model_types:
|
||||||
# get predefined models for given model type
|
# get predefined models for given model type
|
||||||
models = model_provider_instance.models(model_type)
|
models = model_provider_instance.models(model_type)
|
||||||
if models:
|
if models:
|
||||||
provider_schema.models.extend(models)
|
provider_schema.models.extend(models)
|
||||||
|
|
||||||
providers.append(provider_schema)
|
providers.append(provider_schema)
|
||||||
|
elif isinstance(provider_name, set):
|
||||||
|
if name in provider_name:
|
||||||
|
# get model_provider instance
|
||||||
|
model_provider_instance = model_provider_extension.provider_instance
|
||||||
|
|
||||||
|
# get provider schema
|
||||||
|
provider_schema = model_provider_instance.get_provider_schema()
|
||||||
|
|
||||||
|
for model_type in provider_schema.supported_model_types:
|
||||||
|
# get predefined models for given model type
|
||||||
|
models = model_provider_instance.models(model_type)
|
||||||
|
if models:
|
||||||
|
provider_schema.models.extend(models)
|
||||||
|
|
||||||
|
providers.append(provider_schema)
|
||||||
|
|
||||||
# return providers
|
# return providers
|
||||||
return providers
|
return providers
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ class ProviderManager:
|
|||||||
provider_name_to_provider_model_records_dict
|
provider_name_to_provider_model_records_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_configurations(self, provider: str) -> ProviderConfigurations:
|
def get_configurations(self, provider: Union[str, set]) -> ProviderConfigurations:
|
||||||
"""
|
"""
|
||||||
Get model provider configurations.
|
Get model provider configurations.
|
||||||
|
|
||||||
@ -155,7 +155,7 @@ class ProviderManager:
|
|||||||
|
|
||||||
default_model = {}
|
default_model = {}
|
||||||
# Get provider configurations
|
# Get provider configurations
|
||||||
provider_configurations = self.get_configurations()
|
provider_configurations = self.get_configurations(provider="openai")
|
||||||
|
|
||||||
# get available models from provider_configurations
|
# get available models from provider_configurations
|
||||||
available_models = provider_configurations.get_models(
|
available_models = provider_configurations.get_models(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user