This commit is contained in:
khazic 2024-03-26 14:49:26 +08:00
parent 596a0f5fa1
commit 8ebfb34a51
6 changed files with 921 additions and 22 deletions

View File

@ -1,37 +1,28 @@
import os
from typing import cast, Generator
from chatchat_model_providers.core.model_manager import ModelManager
from chatchat_model_providers.core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
from chatchat_model_providers.core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
from chatchat_model_providers.core.model_runtime.entities.model_entities import ModelType
if __name__ == '__main__':
# 基于配置管理器创建的模型实例
# provider_manager = ProviderManager()
provider_configurations = ProviderConfigurations(
tenant_id=tenant_id
provider_manager = ModelManager(
provider_name_to_provider_records_dict={
'openai': {
'openai_api_key': "sk- ",
}
},
provider_name_to_provider_model_records_dict={}
)
#
# model_instance = ModelInstance(
# provider_model_bundle=provider_model_bundle,
# model=model_config.model,
# )
# 直接通过模型加载器创建的模型实例
from chatchat_model_providers.core.model_runtime.model_providers import model_provider_factory
model_provider_factory.get_providers(provider_name='openai')
provider_instance = model_provider_factory.get_provider_instance('openai')
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
print(model_type_instance)
response = model_type_instance.invoke(
model='gpt-4',
credentials={
'openai_api_key': "sk-",
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
},
# Invoke model
model_instance = provider_manager.get_model_instance(provider='openai', model_type=ModelType.LLM, model='gpt-4')
response = model_instance.invoke_llm(
prompt_messages=[
UserPromptMessage(
content='北京今天的天气怎么样'

View File

@ -0,0 +1,257 @@
from collections.abc import Generator
from typing import IO, Optional, Union, cast
from model_providers.core.entities.provider_configuration import ProviderModelBundle
from model_providers.core.errors.error import ProviderTokenNotInitError
from model_providers.core.model_runtime.callbacks.base_callback import Callback
from model_providers.core.model_runtime.entities.llm_entities import LLMResult
from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.entities.rerank_entities import RerankResult
from model_providers.core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from model_providers.core.model_runtime.model_providers.__base.moderation_model import ModerationModel
from model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel
from model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel
from model_providers.core.provider_manager import ProviderManager
def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict:
"""
Fetch credentials from provider model bundle
:param provider_model_bundle: provider model bundle
:param model: model name
:return:
"""
credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=provider_model_bundle.model_type_instance.model_type,
model=model
)
if credentials is None:
raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")
return credentials
class ModelInstance:
"""
Model instance class
"""
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
self._provider_model_bundle = provider_model_bundle
self.model = model
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = _fetch_credentials_from_bundle(provider_model_bundle, model)
self.model_type_instance = self._provider_model_bundle.model_type_instance
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
)
def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke large language model
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
texts=texts,
user=user
)
def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
Invoke rerank model
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
self.model_type_instance = cast(RerankModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user
)
def invoke_moderation(self, text: str, user: Optional[str] = None) \
-> bool:
"""
Invoke moderation model
:param text: text to moderate
:param user: unique user id
:return: false if text is safe, true otherwise
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
text=text,
user=user
)
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
-> str:
"""
Invoke large language model
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
file=file,
user=user
)
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \
-> str:
"""
Invoke large language tts model
:param content_text: text content to be translated
:param tenant_id: user tenant id
:param user: unique user id
:param voice: model timbre
:param streaming: output is streaming
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
content_text=content_text,
user=user,
tenant_id=tenant_id,
voice=voice,
streaming=streaming
)
def get_tts_voices(self, language: str) -> list:
"""
Invoke large language tts model voices
:param language: tts language
:return: tts model voices
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices(
model=self.model,
credentials=self.credentials,
language=language
)
class ModelManager:
def __init__(self,
provider_name_to_provider_records_dict: dict,
provider_name_to_provider_model_records_dict: dict) -> None:
self._provider_manager = ProviderManager(
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict)
def get_model_instance(self, provider: str, model_type: ModelType, model: str) -> ModelInstance:
"""
Get model instance
:param provider: provider name
:param model_type: model type
:param model: model name
:return:
"""
if not provider:
return self.get_default_model_instance(model_type)
provider_model_bundle = self._provider_manager.get_provider_model_bundle(
provider=provider,
model_type=model_type
)
return ModelInstance(provider_model_bundle, model)
def get_default_model_instance(self, model_type: ModelType) -> ModelInstance:
"""
Get default model instance
:param model_type: model type
:return:
"""
default_model_entity = self._provider_manager.get_default_model(
model_type=model_type
)
if not default_model_entity:
raise ProviderTokenNotInitError(f"Default model not found for {model_type}")
return self.get_model_instance(
provider=default_model_entity.provider.provider,
model_type=model_type,
model=default_model_entity.model
)

View File

@ -0,0 +1,256 @@
import json
from collections import defaultdict
from json import JSONDecodeError
from typing import Optional
from sqlalchemy.exc import IntegrityError
from model_providers.core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from model_providers.core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, \
ProviderModelBundle
from model_providers.core.entities.provider_entities import (
CustomConfiguration,
CustomModelConfiguration,
CustomProviderConfiguration,
)
from model_providers.core.model_runtime.entities.model_entities import ModelType
from model_providers.core.model_runtime.entities.provider_entities import (
CredentialFormSchema,
FormType,
ProviderEntity,
)
from model_providers.core.model_runtime.model_providers import model_provider_factory
class ProviderManager:
"""
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
"""
def __init__(self,
provider_name_to_provider_records_dict: dict,
provider_name_to_provider_model_records_dict: dict) -> None:
self.provider_name_to_provider_records_dict = provider_name_to_provider_records_dict
self.provider_name_to_provider_model_records_dict = provider_name_to_provider_model_records_dict
def get_configurations(self, provider: str) -> ProviderConfigurations:
"""
Get model provider configurations.
Construct ProviderConfiguration objects for each provider
Including:
1. Basic information of the provider
2. Hosting configuration information, including:
(1. Whether to enable (support) hosting type, if enabled, the following information exists
(2. List of hosting type provider configurations
(including quota type, quota limit, current remaining quota, etc.)
(3. The current hosting type in use (whether there is a quota or not)
paid quotas > provider free quotas > hosting trial quotas
(4. Unified credentials for hosting providers
3. Custom configuration information, including:
(1. Whether to enable (support) custom type, if enabled, the following information exists
(2. Custom provider configuration (including credentials)
(3. List of custom provider model configurations (including credentials)
4. Hosting/custom preferred provider type.
Provide methods:
- Get the current configuration (including credentials)
- Get the availability and status of the hosting configuration: active available,
quota_exceeded insufficient quota, unsupported hosting
- Get the availability of custom configuration
Custom provider available conditions:
(1. custom provider credentials available
(2. at least one custom model credentials available
- Verify, update, and delete custom provider configuration
- Verify, update, and delete custom provider model configuration
- Get the list of available models (optional provider filtering, model type filtering)
Append custom provider models to the list
- Get provider instance
- Switch selection priority
:param: provider
:return:
"""
# Get all provider entities
provider_entities = model_provider_factory.get_providers(provider_name=provider)
provider_configurations = ProviderConfigurations()
# Construct ProviderConfiguration objects for each provider
for provider_entity in provider_entities:
provider_name = provider_entity.provider
provider_credentials = self.provider_name_to_provider_records_dict.get(provider_entity.provider)
if not provider_credentials:
provider_credentials = {}
provider_model_records = self.provider_name_to_provider_model_records_dict.get(provider_entity.provider)
if not provider_model_records:
provider_model_records = []
# Convert to custom configuration
custom_configuration = self._to_custom_configuration(
provider_entity,
provider_credentials,
provider_model_records
)
provider_configuration = ProviderConfiguration(
provider=provider_entity,
custom_configuration=custom_configuration
)
provider_configurations[provider_name] = provider_configuration
# Return the encapsulated object
return provider_configurations
def get_provider_model_bundle(self, provider: str, model_type: ModelType) -> ProviderModelBundle:
"""
Get provider model bundle.
:param provider: provider name
:param model_type: model type
:return:
"""
provider_configurations = self.get_configurations(provider=provider)
# get provider instance
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
provider_instance = provider_configuration.get_provider_instance()
model_type_instance = provider_instance.get_model_instance(model_type)
return ProviderModelBundle(
configuration=provider_configuration,
provider_instance=provider_instance,
model_type_instance=model_type_instance
)
def get_default_model(self, model_type: ModelType) -> Optional[DefaultModelEntity]:
"""
Get default model.
:param model_type: model type
:return:
"""
default_model = {}
# Get provider configurations
provider_configurations = self.get_configurations()
# get available models from provider_configurations
available_models = provider_configurations.get_models(
model_type=model_type,
only_active=True
)
if available_models:
found = False
for available_model in available_models:
if available_model.model == "gpt-3.5-turbo-1106":
default_model = {
'provider_name': available_model.provider.provider,
'model_name': available_model.model
}
found = True
break
if not found:
available_model = available_models[0]
default_model = {
'provider_name': available_model.provider.provider,
'model_name': available_model.model
}
provider_instance = model_provider_factory.get_provider_instance(default_model.get('provider_name'))
provider_schema = provider_instance.get_provider_schema()
return DefaultModelEntity(
model=default_model.get("model_name"),
model_type=model_type,
provider=DefaultModelProviderEntity(
provider=provider_schema.provider,
label=provider_schema.label,
icon_small=provider_schema.icon_small,
icon_large=provider_schema.icon_large,
supported_model_types=provider_schema.supported_model_types
)
)
def _to_custom_configuration(self,
provider_entity: ProviderEntity,
provider_credentials: dict,
provider_model_records: list[dict]) -> CustomConfiguration:
"""
Convert to custom configuration.
:param provider_entity: provider entity
:param provider_credentials: provider records_credentials
:param provider_model_records: provider model records_credentials
:return:
"""
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
provider_entity.provider_credential_schema.credential_form_schemas
if provider_entity.provider_credential_schema else []
)
for variable in provider_credential_secret_variables:
if variable in provider_credentials:
try:
provider_credentials[variable] = provider_credentials.get(variable)
except ValueError:
pass
custom_provider_configuration = CustomProviderConfiguration(
credentials=provider_credentials
)
# Get provider model credential secret variables
model_credential_secret_variables = self._extract_secret_variables(
provider_entity.model_credential_schema.credential_form_schemas
if provider_entity.model_credential_schema else []
)
# Get custom provider model credentials
custom_model_configurations = []
for provider_model_record in provider_model_records:
if not provider_model_record.get('model_credentials'):
continue
provider_model_credentials = {}
for variable in model_credential_secret_variables:
if variable in provider_model_record.get('model_credentials'):
try:
provider_model_credentials[variable] = provider_model_record.get('model_credentials').get(
variable)
except ValueError:
pass
custom_model_configurations.append(
CustomModelConfiguration(
model=provider_model_record.get('model_name'),
model_type=ModelType.value_of(provider_model_record.get('model_type')),
credentials=provider_model_credentials
)
)
return CustomConfiguration(
provider=custom_provider_configuration,
models=custom_model_configurations
)
def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
"""
Extract secret input form variables.
:param credential_form_schemas:
:return:
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables

View File

@ -0,0 +1,38 @@
from typing import Optional
class LLMError(Exception):
"""Base class for all LLM exceptions."""
description: Optional[str] = None
def __init__(self, description: Optional[str] = None) -> None:
self.description = description
class LLMBadRequestError(LLMError):
"""Raised when the LLM returns bad request."""
description = "Bad Request"
class ProviderTokenNotInitError(Exception):
"""
Custom exception raised when the provider token is not initialized.
"""
description = "Provider Token Not Init"
def __init__(self, *args, **kwargs):
self.description = args[0] if args else self.description
class QuotaExceededError(Exception):
"""
Custom exception raised when the quota for a provider has been exceeded.
"""
description = "Quota Exceeded"
class ModelCurrentlyNotSupportError(Exception):
"""
Custom exception raised when the model not support
"""
description = "Model Currently Not Support"

View File

@ -0,0 +1,320 @@
import datetime
import json
import logging
from collections.abc import Iterator
from json import JSONDecodeError
from typing import Optional
from pydantic import BaseModel
from model_providers.core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from model_providers.core.entities.provider_entities import CustomConfiguration
from model_providers.core.model_runtime.entities.model_entities import FetchFrom, ModelType
from model_providers.core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
CredentialFormSchema,
FormType,
ProviderEntity,
)
from model_providers.core.model_runtime.model_providers import model_provider_factory
from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class ProviderConfiguration(BaseModel):
"""
Model class for provider configuration.
"""
provider: ProviderEntity
custom_configuration: CustomConfiguration
def __init__(self, **data):
super().__init__(**data)
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
"""
Get current credentials.
:param model_type: model type
:param model: model name
:return:
"""
if self.custom_configuration.models:
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model:
return model_configuration.credentials
if self.custom_configuration.provider:
return self.custom_configuration.provider.credentials
else:
return None
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
Get custom credentials.
:param obfuscated: obfuscated secret data in credentials
:return:
"""
if self.custom_configuration.provider is None:
return None
credentials = self.custom_configuration.provider.credentials
if not obfuscated:
return credentials
# Obfuscate provider credentials
copy_credentials = credentials.copy()
return copy_credentials
def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
-> Optional[dict]:
"""
Get custom model credentials.
:param model_type: model type
:param model: model name
:param obfuscated: obfuscated secret data in credentials
:return:
"""
if not self.custom_configuration.models:
return None
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model:
credentials = model_configuration.credentials
if not obfuscated:
return credentials
copy_credentials = credentials.copy()
# Obfuscate credentials
return copy_credentials
return None
def get_provider_instance(self) -> ModelProvider:
"""
Get provider instance.
:return:
"""
return model_provider_factory.get_provider_instance(self.provider.provider)
def get_model_type_instance(self, model_type: ModelType) -> AIModel:
"""
Get current model type instance.
:param model_type: model type
:return:
"""
# Get provider instance
provider_instance = self.get_provider_instance()
# Get model instance of LLM
return provider_instance.get_model_instance(model_type)
def get_provider_model(self, model_type: ModelType,
model: str,
only_active: bool = False) -> Optional[ModelWithProviderEntity]:
"""
Get provider model.
:param model_type: model type
:param model: model name
:param only_active: return active model only
:return:
"""
provider_models = self.get_provider_models(model_type, only_active)
for provider_model in provider_models:
if provider_model.model == model:
return provider_model
return None
def get_provider_models(self, model_type: Optional[ModelType] = None,
only_active: bool = False) -> list[ModelWithProviderEntity]:
"""
Get provider models.
:param model_type: model type
:param only_active: only active models
:return:
"""
provider_instance = self.get_provider_instance()
model_types = []
if model_type:
model_types.append(model_type)
else:
model_types = provider_instance.get_provider_schema().supported_model_types
provider_models = self._get_custom_provider_models(
model_types=model_types,
provider_instance=provider_instance
)
if only_active:
provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
# resort provider_models
return sorted(provider_models, key=lambda x: x.model_type.value)
def _get_custom_provider_models(self,
model_types: list[ModelType],
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
"""
Get custom provider models.
:param model_types: model types
:param provider_instance: provider instance
:return:
"""
provider_models = []
credentials = None
if self.custom_configuration.provider:
credentials = self.custom_configuration.provider.credentials
for model_type in model_types:
if model_type not in self.provider.supported_model_types:
continue
models = provider_instance.models(model_type)
for m in models:
provider_models.append(
ModelWithProviderEntity(
model=m.model,
label=m.label,
model_type=m.model_type,
features=m.features,
fetch_from=m.fetch_from,
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
)
)
# custom models
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type not in model_types:
continue
try:
custom_model_schema = (
provider_instance.get_model_instance(model_configuration.model_type)
.get_customizable_model_schema_from_credentials(
model_configuration.model,
model_configuration.credentials
)
)
except Exception as ex:
logger.warning(f'get custom model schema failed, {ex}')
continue
if not custom_model_schema:
continue
provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=custom_model_schema.fetch_from,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)
)
return provider_models
class ProviderConfigurations(BaseModel):
"""
Model class for provider configuration dict.
"""
configurations: dict[str, ProviderConfiguration] = {}
def __init__(self):
super().__init__()
def get_models(self,
provider: Optional[str] = None,
model_type: Optional[ModelType] = None,
only_active: bool = False) \
-> list[ModelWithProviderEntity]:
"""
Get available models.
If preferred provider type is `system`:
Get the current **system mode** if provider supported,
if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
If there is no model configured in custom mode, it is treated as no_configure.
system > custom > no_configure
If preferred provider type is `custom`:
If custom credentials are configured, it is treated as custom mode.
Otherwise, get the current **system mode** if supported,
If all system modes are not available (no quota), it is treated as no_configure.
custom > system > no_configure
If real mode is `system`, use system credentials to get models,
paid quotas > provider free quotas > system free quotas
include pre-defined models (exclude GPT-4, status marked as `no_permission`).
If real mode is `custom`, use workspace custom credentials to get models,
include pre-defined models, custom models(manual append).
If real mode is `no_configure`, only return pre-defined models from `model runtime`.
(model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
model status marked as `active` is available.
:param provider: provider name
:param model_type: model type
:param only_active: only active models
:return:
"""
all_models = []
for provider_configuration in self.values():
if provider and provider_configuration.provider.provider != provider:
continue
all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
return all_models
def to_list(self) -> list[ProviderConfiguration]:
"""
Convert to list.
:return:
"""
return list(self.values())
def __getitem__(self, key):
return self.configurations[key]
def __setitem__(self, key, value):
self.configurations[key] = value
def __iter__(self):
return iter(self.configurations)
def values(self) -> Iterator[ProviderConfiguration]:
return self.configurations.values()
def get(self, key, default=None):
return self.configurations.get(key, default)
class ProviderModelBundle(BaseModel):
"""
Provider model bundle.
"""
configuration: ProviderConfiguration
provider_instance: ModelProvider
model_type_instance: AIModel
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True

View File

@ -0,0 +1,37 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from model_providers.core.model_runtime.entities.model_entities import ModelType
class RestrictModel(BaseModel):
model: str
base_model_name: Optional[str] = None
model_type: ModelType
class CustomProviderConfiguration(BaseModel):
"""
Model class for provider custom configuration.
"""
credentials: dict
class CustomModelConfiguration(BaseModel):
"""
Model class for provider custom model configuration.
"""
model: str
model_type: ModelType
credentials: dict
class CustomConfiguration(BaseModel):
"""
Model class for provider custom configuration.
"""
provider: Optional[CustomProviderConfiguration] = None
models: list[CustomModelConfiguration] = []