From 8ebfb34a51417c5c552d00e77a5c6218731aaa63 Mon Sep 17 00:00:00 2001 From: khazic Date: Tue, 26 Mar 2024 14:49:26 +0800 Subject: [PATCH 1/3] 20240326 --- .../chatchat_model_providers/__main__.py | 35 +- .../core/model_manager.py | 257 ++++++++++++++ .../core/provider_manager.py | 256 ++++++++++++++ .../chatchat_model_providers/errors/error.py | 38 +++ .../extensions/provider_configuration.py | 320 ++++++++++++++++++ .../extensions/provider_entities.py | 37 ++ 6 files changed, 921 insertions(+), 22 deletions(-) create mode 100644 model-providers/chatchat_model_providers/core/model_manager.py create mode 100644 model-providers/chatchat_model_providers/core/provider_manager.py create mode 100644 model-providers/chatchat_model_providers/errors/error.py create mode 100644 model-providers/chatchat_model_providers/extensions/provider_configuration.py create mode 100644 model-providers/chatchat_model_providers/extensions/provider_entities.py diff --git a/model-providers/chatchat_model_providers/__main__.py b/model-providers/chatchat_model_providers/__main__.py index fca23638..318b7e5f 100644 --- a/model-providers/chatchat_model_providers/__main__.py +++ b/model-providers/chatchat_model_providers/__main__.py @@ -1,37 +1,28 @@ import os from typing import cast, Generator +from chatchat_model_providers.core.model_manager import ModelManager from chatchat_model_providers.core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta from chatchat_model_providers.core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage from chatchat_model_providers.core.model_runtime.entities.model_entities import ModelType if __name__ == '__main__': # 基于配置管理器创建的模型实例 - # provider_manager = ProviderManager() - - provider_configurations = ProviderConfigurations( - tenant_id=tenant_id + provider_manager = ModelManager( + provider_name_to_provider_records_dict={ + 'openai': { + 'openai_api_key': "sk- ", + } + }, + provider_name_to_provider_model_records_dict={} ) - # - # model_instance = ModelInstance( - # provider_model_bundle=provider_model_bundle, - # model=model_config.model, - # ) - # 直接通过模型加载器创建的模型实例 - from chatchat_model_providers.core.model_runtime.model_providers import model_provider_factory - model_provider_factory.get_providers(provider_name='openai') - provider_instance = model_provider_factory.get_provider_instance('openai') - model_type_instance = provider_instance.get_model_instance(ModelType.LLM) - print(model_type_instance) - response = model_type_instance.invoke( - model='gpt-4', - credentials={ - 'openai_api_key': "sk-", - 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'), - 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID') - }, + # Invoke model + model_instance = provider_manager.get_model_instance(provider='openai', model_type=ModelType.LLM, model='gpt-4') + + response = model_instance.invoke_llm( + prompt_messages=[ UserPromptMessage( content='北京今天的天气怎么样' diff --git a/model-providers/chatchat_model_providers/core/model_manager.py b/model-providers/chatchat_model_providers/core/model_manager.py new file mode 100644 index 00000000..f7e9be10 --- /dev/null +++ b/model-providers/chatchat_model_providers/core/model_manager.py @@ -0,0 +1,257 @@ +from collections.abc import Generator +from typing import IO, Optional, Union, cast + +from model_providers.core.entities.provider_configuration import ProviderModelBundle +from model_providers.core.errors.error import ProviderTokenNotInitError +from model_providers.core.model_runtime.callbacks.base_callback import Callback +from model_providers.core.model_runtime.entities.llm_entities import LLMResult +from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from model_providers.core.model_runtime.entities.model_entities import ModelType +from model_providers.core.model_runtime.entities.rerank_entities import RerankResult +from model_providers.core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from model_providers.core.model_runtime.model_providers.__base.moderation_model import ModerationModel +from model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel +from model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel +from model_providers.core.provider_manager import ProviderManager + + +def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict: + """ + Fetch credentials from provider model bundle + :param provider_model_bundle: provider model bundle + :param model: model name + :return: + """ + credentials = provider_model_bundle.configuration.get_current_credentials( + model_type=provider_model_bundle.model_type_instance.model_type, + model=model + ) + + if credentials is None: + raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.") + + return credentials + + +class ModelInstance: + """ + Model instance class + """ + + def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None: + self._provider_model_bundle = provider_model_bundle + self.model = model + self.provider = provider_model_bundle.configuration.provider.provider + self.credentials = _fetch_credentials_from_bundle(provider_model_bundle, model) + self.model_type_instance = self._provider_model_bundle.model_type_instance + + def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ + -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :param callbacks: callbacks + :return: full response or stream response chunk generator result + """ + if not isinstance(self.model_type_instance, LargeLanguageModel): + raise Exception("Model type instance is not LargeLanguageModel") + + self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) + return self.model_type_instance.invoke( + model=self.model, + credentials=self.credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + callbacks=callbacks + ) + + def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: + """ + Invoke large language model + + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + if not isinstance(self.model_type_instance, TextEmbeddingModel): + raise Exception("Model type instance is not TextEmbeddingModel") + + self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) + return self.model_type_instance.invoke( + model=self.model, + credentials=self.credentials, + texts=texts, + user=user + ) + + def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None) \ + -> RerankResult: + """ + Invoke rerank model + + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + if not isinstance(self.model_type_instance, RerankModel): + raise Exception("Model type instance is not RerankModel") + + self.model_type_instance = cast(RerankModel, self.model_type_instance) + return self.model_type_instance.invoke( + model=self.model, + credentials=self.credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + user=user + ) + + def invoke_moderation(self, text: str, user: Optional[str] = None) \ + -> bool: + """ + Invoke moderation model + + :param text: text to moderate + :param user: unique user id + :return: false if text is safe, true otherwise + """ + if not isinstance(self.model_type_instance, ModerationModel): + raise Exception("Model type instance is not ModerationModel") + + self.model_type_instance = cast(ModerationModel, self.model_type_instance) + return self.model_type_instance.invoke( + model=self.model, + credentials=self.credentials, + text=text, + user=user + ) + + def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \ + -> str: + """ + Invoke large language model + + :param file: audio file + :param user: unique user id + :return: text for given audio file + """ + if not isinstance(self.model_type_instance, Speech2TextModel): + raise Exception("Model type instance is not Speech2TextModel") + + self.model_type_instance = cast(Speech2TextModel, self.model_type_instance) + return self.model_type_instance.invoke( + model=self.model, + credentials=self.credentials, + file=file, + user=user + ) + + def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \ + -> str: + """ + Invoke large language tts model + + :param content_text: text content to be translated + :param tenant_id: user tenant id + :param user: unique user id + :param voice: model timbre + :param streaming: output is streaming + :return: text for given audio file + """ + if not isinstance(self.model_type_instance, TTSModel): + raise Exception("Model type instance is not TTSModel") + + self.model_type_instance = cast(TTSModel, self.model_type_instance) + return self.model_type_instance.invoke( + model=self.model, + credentials=self.credentials, + content_text=content_text, + user=user, + tenant_id=tenant_id, + voice=voice, + streaming=streaming + ) + + def get_tts_voices(self, language: str) -> list: + """ + Invoke large language tts model voices + + :param language: tts language + :return: tts model voices + """ + if not isinstance(self.model_type_instance, TTSModel): + raise Exception("Model type instance is not TTSModel") + + self.model_type_instance = cast(TTSModel, self.model_type_instance) + return self.model_type_instance.get_tts_model_voices( + model=self.model, + credentials=self.credentials, + language=language + ) + + +class ModelManager: + def __init__(self, + provider_name_to_provider_records_dict: dict, + provider_name_to_provider_model_records_dict: dict) -> None: + self._provider_manager = ProviderManager( + provider_name_to_provider_records_dict=provider_name_to_provider_records_dict, + provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict) + + def get_model_instance(self, provider: str, model_type: ModelType, model: str) -> ModelInstance: + """ + Get model instance + :param provider: provider name + :param model_type: model type + :param model: model name + :return: + """ + if not provider: + return self.get_default_model_instance(model_type) + provider_model_bundle = self._provider_manager.get_provider_model_bundle( + provider=provider, + model_type=model_type + ) + + return ModelInstance(provider_model_bundle, model) + + def get_default_model_instance(self, model_type: ModelType) -> ModelInstance: + """ + Get default model instance + :param model_type: model type + :return: + """ + default_model_entity = self._provider_manager.get_default_model( + model_type=model_type + ) + + if not default_model_entity: + raise ProviderTokenNotInitError(f"Default model not found for {model_type}") + + return self.get_model_instance( + provider=default_model_entity.provider.provider, + model_type=model_type, + model=default_model_entity.model + ) diff --git a/model-providers/chatchat_model_providers/core/provider_manager.py b/model-providers/chatchat_model_providers/core/provider_manager.py new file mode 100644 index 00000000..a2016963 --- /dev/null +++ b/model-providers/chatchat_model_providers/core/provider_manager.py @@ -0,0 +1,256 @@ +import json +from collections import defaultdict +from json import JSONDecodeError +from typing import Optional + +from sqlalchemy.exc import IntegrityError + +from model_providers.core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity +from model_providers.core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, \ + ProviderModelBundle +from model_providers.core.entities.provider_entities import ( + CustomConfiguration, + CustomModelConfiguration, + CustomProviderConfiguration, +) +from model_providers.core.model_runtime.entities.model_entities import ModelType +from model_providers.core.model_runtime.entities.provider_entities import ( + CredentialFormSchema, + FormType, + ProviderEntity, +) +from model_providers.core.model_runtime.model_providers import model_provider_factory + + +class ProviderManager: + """ + ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. + """ + + def __init__(self, + provider_name_to_provider_records_dict: dict, + provider_name_to_provider_model_records_dict: dict) -> None: + self.provider_name_to_provider_records_dict = provider_name_to_provider_records_dict + self.provider_name_to_provider_model_records_dict = provider_name_to_provider_model_records_dict + + def get_configurations(self, provider: str) -> ProviderConfigurations: + """ + Get model provider configurations. + + Construct ProviderConfiguration objects for each provider + Including: + 1. Basic information of the provider + 2. Hosting configuration information, including: + (1. Whether to enable (support) hosting type, if enabled, the following information exists + (2. List of hosting type provider configurations + (including quota type, quota limit, current remaining quota, etc.) + (3. The current hosting type in use (whether there is a quota or not) + paid quotas > provider free quotas > hosting trial quotas + (4. Unified credentials for hosting providers + 3. Custom configuration information, including: + (1. Whether to enable (support) custom type, if enabled, the following information exists + (2. Custom provider configuration (including credentials) + (3. List of custom provider model configurations (including credentials) + 4. Hosting/custom preferred provider type. + Provide methods: + - Get the current configuration (including credentials) + - Get the availability and status of the hosting configuration: active available, + quota_exceeded insufficient quota, unsupported hosting + - Get the availability of custom configuration + Custom provider available conditions: + (1. custom provider credentials available + (2. at least one custom model credentials available + - Verify, update, and delete custom provider configuration + - Verify, update, and delete custom provider model configuration + - Get the list of available models (optional provider filtering, model type filtering) + Append custom provider models to the list + - Get provider instance + - Switch selection priority + + :param: provider + :return: + """ + + # Get all provider entities + provider_entities = model_provider_factory.get_providers(provider_name=provider) + + provider_configurations = ProviderConfigurations() + + # Construct ProviderConfiguration objects for each provider + for provider_entity in provider_entities: + provider_name = provider_entity.provider + + provider_credentials = self.provider_name_to_provider_records_dict.get(provider_entity.provider) + if not provider_credentials: + provider_credentials = {} + + provider_model_records = self.provider_name_to_provider_model_records_dict.get(provider_entity.provider) + if not provider_model_records: + provider_model_records = [] + + # Convert to custom configuration + custom_configuration = self._to_custom_configuration( + provider_entity, + provider_credentials, + provider_model_records + ) + + provider_configuration = ProviderConfiguration( + provider=provider_entity, + custom_configuration=custom_configuration + ) + + provider_configurations[provider_name] = provider_configuration + + # Return the encapsulated object + return provider_configurations + + def get_provider_model_bundle(self, provider: str, model_type: ModelType) -> ProviderModelBundle: + """ + Get provider model bundle. + :param provider: provider name + :param model_type: model type + :return: + """ + provider_configurations = self.get_configurations(provider=provider) + + # get provider instance + provider_configuration = provider_configurations.get(provider) + if not provider_configuration: + raise ValueError(f"Provider {provider} does not exist.") + + provider_instance = provider_configuration.get_provider_instance() + model_type_instance = provider_instance.get_model_instance(model_type) + + return ProviderModelBundle( + configuration=provider_configuration, + provider_instance=provider_instance, + model_type_instance=model_type_instance + ) + + def get_default_model(self, model_type: ModelType) -> Optional[DefaultModelEntity]: + """ + Get default model. + + :param model_type: model type + :return: + """ + + default_model = {} + # Get provider configurations + provider_configurations = self.get_configurations() + + # get available models from provider_configurations + available_models = provider_configurations.get_models( + model_type=model_type, + only_active=True + ) + + if available_models: + found = False + for available_model in available_models: + if available_model.model == "gpt-3.5-turbo-1106": + default_model = { + 'provider_name': available_model.provider.provider, + 'model_name': available_model.model + } + found = True + break + + if not found: + available_model = available_models[0] + default_model = { + 'provider_name': available_model.provider.provider, + 'model_name': available_model.model + } + + provider_instance = model_provider_factory.get_provider_instance(default_model.get('provider_name')) + provider_schema = provider_instance.get_provider_schema() + + return DefaultModelEntity( + model=default_model.get("model_name"), + model_type=model_type, + provider=DefaultModelProviderEntity( + provider=provider_schema.provider, + label=provider_schema.label, + icon_small=provider_schema.icon_small, + icon_large=provider_schema.icon_large, + supported_model_types=provider_schema.supported_model_types + ) + ) + + def _to_custom_configuration(self, + provider_entity: ProviderEntity, + provider_credentials: dict, + provider_model_records: list[dict]) -> CustomConfiguration: + """ + Convert to custom configuration. + + :param provider_entity: provider entity + :param provider_credentials: provider records_credentials + :param provider_model_records: provider model records_credentials + :return: + """ + # Get provider credential secret variables + provider_credential_secret_variables = self._extract_secret_variables( + provider_entity.provider_credential_schema.credential_form_schemas + if provider_entity.provider_credential_schema else [] + ) + + for variable in provider_credential_secret_variables: + if variable in provider_credentials: + try: + provider_credentials[variable] = provider_credentials.get(variable) + except ValueError: + pass + custom_provider_configuration = CustomProviderConfiguration( + credentials=provider_credentials + ) + + # Get provider model credential secret variables + model_credential_secret_variables = self._extract_secret_variables( + provider_entity.model_credential_schema.credential_form_schemas + if provider_entity.model_credential_schema else [] + ) + + # Get custom provider model credentials + custom_model_configurations = [] + for provider_model_record in provider_model_records: + if not provider_model_record.get('model_credentials'): + continue + + provider_model_credentials = {} + for variable in model_credential_secret_variables: + if variable in provider_model_record.get('model_credentials'): + try: + provider_model_credentials[variable] = provider_model_record.get('model_credentials').get( + variable) + except ValueError: + pass + + custom_model_configurations.append( + CustomModelConfiguration( + model=provider_model_record.get('model_name'), + model_type=ModelType.value_of(provider_model_record.get('model_type')), + credentials=provider_model_credentials + ) + ) + + return CustomConfiguration( + provider=custom_provider_configuration, + models=custom_model_configurations + ) + + def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: + """ + Extract secret input form variables. + + :param credential_form_schemas: + :return: + """ + secret_input_form_variables = [] + for credential_form_schema in credential_form_schemas: + if credential_form_schema.type == FormType.SECRET_INPUT: + secret_input_form_variables.append(credential_form_schema.variable) + + return secret_input_form_variables diff --git a/model-providers/chatchat_model_providers/errors/error.py b/model-providers/chatchat_model_providers/errors/error.py new file mode 100644 index 00000000..6ac95b39 --- /dev/null +++ b/model-providers/chatchat_model_providers/errors/error.py @@ -0,0 +1,38 @@ +from typing import Optional + + +class LLMError(Exception): + """Base class for all LLM exceptions.""" + description: Optional[str] = None + + def __init__(self, description: Optional[str] = None) -> None: + self.description = description + + +class LLMBadRequestError(LLMError): + """Raised when the LLM returns bad request.""" + description = "Bad Request" + + +class ProviderTokenNotInitError(Exception): + """ + Custom exception raised when the provider token is not initialized. + """ + description = "Provider Token Not Init" + + def __init__(self, *args, **kwargs): + self.description = args[0] if args else self.description + + +class QuotaExceededError(Exception): + """ + Custom exception raised when the quota for a provider has been exceeded. + """ + description = "Quota Exceeded" + + +class ModelCurrentlyNotSupportError(Exception): + """ + Custom exception raised when the model not support + """ + description = "Model Currently Not Support" diff --git a/model-providers/chatchat_model_providers/extensions/provider_configuration.py b/model-providers/chatchat_model_providers/extensions/provider_configuration.py new file mode 100644 index 00000000..0b05635c --- /dev/null +++ b/model-providers/chatchat_model_providers/extensions/provider_configuration.py @@ -0,0 +1,320 @@ +import datetime +import json +import logging +from collections.abc import Iterator +from json import JSONDecodeError +from typing import Optional + +from pydantic import BaseModel + +from model_providers.core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity +from model_providers.core.entities.provider_entities import CustomConfiguration +from model_providers.core.model_runtime.entities.model_entities import FetchFrom, ModelType +from model_providers.core.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from model_providers.core.model_runtime.model_providers import model_provider_factory +from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel +from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class ProviderConfiguration(BaseModel): + """ + Model class for provider configuration. + """ + provider: ProviderEntity + custom_configuration: CustomConfiguration + + def __init__(self, **data): + super().__init__(**data) + + def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: + """ + Get current credentials. + + :param model_type: model type + :param model: model name + :return: + """ + if self.custom_configuration.models: + for model_configuration in self.custom_configuration.models: + if model_configuration.model_type == model_type and model_configuration.model == model: + return model_configuration.credentials + + if self.custom_configuration.provider: + return self.custom_configuration.provider.credentials + else: + return None + + def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: + """ + Get custom credentials. + + :param obfuscated: obfuscated secret data in credentials + :return: + """ + if self.custom_configuration.provider is None: + return None + + credentials = self.custom_configuration.provider.credentials + if not obfuscated: + return credentials + + # Obfuscate provider credentials + copy_credentials = credentials.copy() + return copy_credentials + + def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \ + -> Optional[dict]: + """ + Get custom model credentials. + + :param model_type: model type + :param model: model name + :param obfuscated: obfuscated secret data in credentials + :return: + """ + if not self.custom_configuration.models: + return None + + for model_configuration in self.custom_configuration.models: + if model_configuration.model_type == model_type and model_configuration.model == model: + credentials = model_configuration.credentials + if not obfuscated: + return credentials + copy_credentials = credentials.copy() + # Obfuscate credentials + return copy_credentials + + return None + + def get_provider_instance(self) -> ModelProvider: + """ + Get provider instance. + :return: + """ + return model_provider_factory.get_provider_instance(self.provider.provider) + + def get_model_type_instance(self, model_type: ModelType) -> AIModel: + """ + Get current model type instance. + + :param model_type: model type + :return: + """ + # Get provider instance + provider_instance = self.get_provider_instance() + + # Get model instance of LLM + return provider_instance.get_model_instance(model_type) + + def get_provider_model(self, model_type: ModelType, + model: str, + only_active: bool = False) -> Optional[ModelWithProviderEntity]: + """ + Get provider model. + :param model_type: model type + :param model: model name + :param only_active: return active model only + :return: + """ + provider_models = self.get_provider_models(model_type, only_active) + + for provider_model in provider_models: + if provider_model.model == model: + return provider_model + + return None + + def get_provider_models(self, model_type: Optional[ModelType] = None, + only_active: bool = False) -> list[ModelWithProviderEntity]: + """ + Get provider models. + :param model_type: model type + :param only_active: only active models + :return: + """ + provider_instance = self.get_provider_instance() + + model_types = [] + if model_type: + model_types.append(model_type) + else: + model_types = provider_instance.get_provider_schema().supported_model_types + + provider_models = self._get_custom_provider_models( + model_types=model_types, + provider_instance=provider_instance + ) + if only_active: + provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE] + + # resort provider_models + return sorted(provider_models, key=lambda x: x.model_type.value) + + def _get_custom_provider_models(self, + model_types: list[ModelType], + provider_instance: ModelProvider) -> list[ModelWithProviderEntity]: + """ + Get custom provider models. + + :param model_types: model types + :param provider_instance: provider instance + :return: + """ + provider_models = [] + + credentials = None + if self.custom_configuration.provider: + credentials = self.custom_configuration.provider.credentials + + for model_type in model_types: + if model_type not in self.provider.supported_model_types: + continue + + models = provider_instance.models(model_type) + for m in models: + provider_models.append( + ModelWithProviderEntity( + model=m.model, + label=m.label, + model_type=m.model_type, + features=m.features, + fetch_from=m.fetch_from, + model_properties=m.model_properties, + deprecated=m.deprecated, + provider=SimpleModelProviderEntity(self.provider), + status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE + ) + ) + + # custom models + for model_configuration in self.custom_configuration.models: + if model_configuration.model_type not in model_types: + continue + + try: + custom_model_schema = ( + provider_instance.get_model_instance(model_configuration.model_type) + .get_customizable_model_schema_from_credentials( + model_configuration.model, + model_configuration.credentials + ) + ) + except Exception as ex: + logger.warning(f'get custom model schema failed, {ex}') + continue + + if not custom_model_schema: + continue + + provider_models.append( + ModelWithProviderEntity( + model=custom_model_schema.model, + label=custom_model_schema.label, + model_type=custom_model_schema.model_type, + features=custom_model_schema.features, + fetch_from=custom_model_schema.fetch_from, + model_properties=custom_model_schema.model_properties, + deprecated=custom_model_schema.deprecated, + provider=SimpleModelProviderEntity(self.provider), + status=ModelStatus.ACTIVE + ) + ) + + return provider_models + + +class ProviderConfigurations(BaseModel): + """ + Model class for provider configuration dict. + """ + configurations: dict[str, ProviderConfiguration] = {} + + def __init__(self): + super().__init__() + + def get_models(self, + provider: Optional[str] = None, + model_type: Optional[ModelType] = None, + only_active: bool = False) \ + -> list[ModelWithProviderEntity]: + """ + Get available models. + + If preferred provider type is `system`: + Get the current **system mode** if provider supported, + if all system modes are not available (no quota), it is considered to be the **custom credential mode**. + If there is no model configured in custom mode, it is treated as no_configure. + system > custom > no_configure + + If preferred provider type is `custom`: + If custom credentials are configured, it is treated as custom mode. + Otherwise, get the current **system mode** if supported, + If all system modes are not available (no quota), it is treated as no_configure. + custom > system > no_configure + + If real mode is `system`, use system credentials to get models, + paid quotas > provider free quotas > system free quotas + include pre-defined models (exclude GPT-4, status marked as `no_permission`). + If real mode is `custom`, use workspace custom credentials to get models, + include pre-defined models, custom models(manual append). + If real mode is `no_configure`, only return pre-defined models from `model runtime`. + (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`) + model status marked as `active` is available. + + :param provider: provider name + :param model_type: model type + :param only_active: only active models + :return: + """ + all_models = [] + for provider_configuration in self.values(): + if provider and provider_configuration.provider.provider != provider: + continue + + all_models.extend(provider_configuration.get_provider_models(model_type, only_active)) + + return all_models + + def to_list(self) -> list[ProviderConfiguration]: + """ + Convert to list. + + :return: + """ + return list(self.values()) + + def __getitem__(self, key): + return self.configurations[key] + + def __setitem__(self, key, value): + self.configurations[key] = value + + def __iter__(self): + return iter(self.configurations) + + def values(self) -> Iterator[ProviderConfiguration]: + return self.configurations.values() + + def get(self, key, default=None): + return self.configurations.get(key, default) + + +class ProviderModelBundle(BaseModel): + """ + Provider model bundle. + """ + configuration: ProviderConfiguration + provider_instance: ModelProvider + model_type_instance: AIModel + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True diff --git a/model-providers/chatchat_model_providers/extensions/provider_entities.py b/model-providers/chatchat_model_providers/extensions/provider_entities.py new file mode 100644 index 00000000..715cf899 --- /dev/null +++ b/model-providers/chatchat_model_providers/extensions/provider_entities.py @@ -0,0 +1,37 @@ +from enum import Enum +from typing import Optional + +from pydantic import BaseModel + +from model_providers.core.model_runtime.entities.model_entities import ModelType + + +class RestrictModel(BaseModel): + model: str + base_model_name: Optional[str] = None + model_type: ModelType + + + +class CustomProviderConfiguration(BaseModel): + """ + Model class for provider custom configuration. + """ + credentials: dict + + +class CustomModelConfiguration(BaseModel): + """ + Model class for provider custom model configuration. + """ + model: str + model_type: ModelType + credentials: dict + + +class CustomConfiguration(BaseModel): + """ + Model class for provider custom configuration. + """ + provider: Optional[CustomProviderConfiguration] = None + models: list[CustomModelConfiguration] = [] From 43a19d9b66825a340fc28322f0f7d90bf3fcd3b6 Mon Sep 17 00:00:00 2001 From: khazic Date: Tue, 26 Mar 2024 14:58:50 +0800 Subject: [PATCH 2/3] 20240326 --- .../core/entities/provider_configuration.py | 528 +----------------- .../core/entities/provider_entities.py | 39 +- .../extensions/provider_configuration.py | 320 ----------- .../extensions/provider_entities.py | 37 -- 4 files changed, 26 insertions(+), 898 deletions(-) delete mode 100644 model-providers/chatchat_model_providers/extensions/provider_configuration.py delete mode 100644 model-providers/chatchat_model_providers/extensions/provider_entities.py diff --git a/model-providers/chatchat_model_providers/core/entities/provider_configuration.py b/model-providers/chatchat_model_providers/core/entities/provider_configuration.py index 8f9c203b..0b05635c 100644 --- a/model-providers/chatchat_model_providers/core/entities/provider_configuration.py +++ b/model-providers/chatchat_model_providers/core/entities/provider_configuration.py @@ -7,53 +7,32 @@ from typing import Optional from pydantic import BaseModel -from chatchat_model_providers.core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity -from chatchat_model_providers.core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus -from chatchat_model_providers.core.helper import encrypter -from chatchat_model_providers.core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType -from chatchat_model_providers.core.model_runtime.entities.model_entities import FetchFrom, ModelType -from chatchat_model_providers.core.model_runtime.entities.provider_entities import ( +from model_providers.core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity +from model_providers.core.entities.provider_entities import CustomConfiguration +from model_providers.core.model_runtime.entities.model_entities import FetchFrom, ModelType +from model_providers.core.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, FormType, ProviderEntity, ) -from chatchat_model_providers.core.model_runtime.model_providers import model_provider_factory -from chatchat_model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel -from chatchat_model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider -from chatchat_model_providers.extensions.ext_database import db -from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider +from model_providers.core.model_runtime.model_providers import model_provider_factory +from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel +from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider logger = logging.getLogger(__name__) -original_provider_configurate_methods = {} - class ProviderConfiguration(BaseModel): """ Model class for provider configuration. """ - tenant_id: str provider: ProviderEntity - preferred_provider_type: ProviderType - using_provider_type: ProviderType - system_configuration: SystemConfiguration custom_configuration: CustomConfiguration def __init__(self, **data): super().__init__(**data) - if self.provider.provider not in original_provider_configurate_methods: - original_provider_configurate_methods[self.provider.provider] = [] - for configurate_method in self.provider.configurate_methods: - original_provider_configurate_methods[self.provider.provider].append(configurate_method) - - if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: - if (any([len(quota_configuration.restrict_models) > 0 - for quota_configuration in self.system_configuration.quota_configurations]) - and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods): - self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) - def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: """ Get current credentials. @@ -62,58 +41,15 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - if self.using_provider_type == ProviderType.SYSTEM: - restrict_models = [] - for quota_configuration in self.system_configuration.quota_configurations: - if self.system_configuration.current_quota_type != quota_configuration.quota_type: - continue + if self.custom_configuration.models: + for model_configuration in self.custom_configuration.models: + if model_configuration.model_type == model_type and model_configuration.model == model: + return model_configuration.credentials - restrict_models = quota_configuration.restrict_models - - copy_credentials = self.system_configuration.credentials.copy() - if restrict_models: - for restrict_model in restrict_models: - if (restrict_model.model_type == model_type - and restrict_model.model == model - and restrict_model.base_model_name): - copy_credentials['base_model_name'] = restrict_model.base_model_name - - return copy_credentials + if self.custom_configuration.provider: + return self.custom_configuration.provider.credentials else: - if self.custom_configuration.models: - for model_configuration in self.custom_configuration.models: - if model_configuration.model_type == model_type and model_configuration.model == model: - return model_configuration.credentials - - if self.custom_configuration.provider: - return self.custom_configuration.provider.credentials - else: - return None - - def get_system_configuration_status(self) -> SystemConfigurationStatus: - """ - Get system configuration status. - :return: - """ - if self.system_configuration.enabled is False: - return SystemConfigurationStatus.UNSUPPORTED - - current_quota_type = self.system_configuration.current_quota_type - current_quota_configuration = next( - (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), - None - ) - - return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \ - SystemConfigurationStatus.QUOTA_EXCEEDED - - def is_custom_configuration_available(self) -> bool: - """ - Check custom configuration available. - :return: - """ - return (self.custom_configuration.provider is not None - or len(self.custom_configuration.models) > 0) + return None def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: """ @@ -129,130 +65,9 @@ class ProviderConfiguration(BaseModel): if not obfuscated: return credentials - # Obfuscate credentials - return self._obfuscated_credentials( - credentials=credentials, - credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema else [] - ) - - def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]: - """ - Validate custom credentials. - :param credentials: provider credentials - :return: - """ - # get provider - provider_record = db.session.query(Provider) \ - .filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_name == self.provider.provider, - Provider.provider_type == ProviderType.CUSTOM.value - ).first() - - # Get provider credential secret variables - provider_credential_secret_variables = self._extract_secret_variables( - self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema else [] - ) - - if provider_record: - try: - # fix origin data - if provider_record.encrypted_config: - if not provider_record.encrypted_config.startswith("{"): - original_credentials = { - "openai_api_key": provider_record.encrypted_config - } - else: - original_credentials = json.loads(provider_record.encrypted_config) - else: - original_credentials = {} - except JSONDecodeError: - original_credentials = {} - - # encrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == '[__HIDDEN__]' and key in original_credentials: - credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) - - credentials = model_provider_factory.provider_credentials_validate( - self.provider.provider, - credentials - ) - - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - credentials[key] = encrypter.encrypt_token(self.tenant_id, value) - - return provider_record, credentials - - def add_or_update_custom_credentials(self, credentials: dict) -> None: - """ - Add or update custom provider credentials. - :param credentials: - :return: - """ - # validate custom provider config - provider_record, credentials = self.custom_credentials_validate(credentials) - - # save provider - # Note: Do not switch the preferred provider, which allows users to use quotas first - if provider_record: - provider_record.encrypted_config = json.dumps(credentials) - provider_record.is_valid = True - provider_record.updated_at = datetime.datetime.utcnow() - db.session.commit() - else: - provider_record = Provider( - tenant_id=self.tenant_id, - provider_name=self.provider.provider, - provider_type=ProviderType.CUSTOM.value, - encrypted_config=json.dumps(credentials), - is_valid=True - ) - db.session.add(provider_record) - db.session.commit() - - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER - ) - - provider_model_credentials_cache.delete() - - self.switch_preferred_provider_type(ProviderType.CUSTOM) - - def delete_custom_credentials(self) -> None: - """ - Delete custom provider credentials. - :return: - """ - # get provider - provider_record = db.session.query(Provider) \ - .filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_name == self.provider.provider, - Provider.provider_type == ProviderType.CUSTOM.value - ).first() - - # delete provider - if provider_record: - self.switch_preferred_provider_type(ProviderType.SYSTEM) - - db.session.delete(provider_record) - db.session.commit() - - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER - ) - - provider_model_credentials_cache.delete() + # Obfuscate provider credentials + copy_credentials = credentials.copy() + return copy_credentials def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \ -> Optional[dict]: @@ -272,136 +87,12 @@ class ProviderConfiguration(BaseModel): credentials = model_configuration.credentials if not obfuscated: return credentials - + copy_credentials = credentials.copy() # Obfuscate credentials - return self._obfuscated_credentials( - credentials=credentials, - credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema else [] - ) + return copy_credentials return None - def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \ - -> tuple[ProviderModel, dict]: - """ - Validate custom model credentials. - - :param model_type: model type - :param model: model name - :param credentials: model credentials - :return: - """ - # get provider model - provider_model_record = db.session.query(ProviderModel) \ - .filter( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name == self.provider.provider, - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type() - ).first() - - # Get provider credential secret variables - provider_credential_secret_variables = self._extract_secret_variables( - self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema else [] - ) - - if provider_model_record: - try: - original_credentials = json.loads( - provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} - except JSONDecodeError: - original_credentials = {} - - # decrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == '[__HIDDEN__]' and key in original_credentials: - credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) - - credentials = model_provider_factory.model_credentials_validate( - provider=self.provider.provider, - model_type=model_type, - model=model, - credentials=credentials - ) - - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - credentials[key] = encrypter.encrypt_token(self.tenant_id, value) - - return provider_model_record, credentials - - def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None: - """ - Add or update custom model credentials. - - :param model_type: model type - :param model: model name - :param credentials: model credentials - :return: - """ - # validate custom model config - provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials) - - # save provider model - # Note: Do not switch the preferred provider, which allows users to use quotas first - if provider_model_record: - provider_model_record.encrypted_config = json.dumps(credentials) - provider_model_record.is_valid = True - provider_model_record.updated_at = datetime.datetime.utcnow() - db.session.commit() - else: - provider_model_record = ProviderModel( - tenant_id=self.tenant_id, - provider_name=self.provider.provider, - model_name=model, - model_type=model_type.to_origin_model_type(), - encrypted_config=json.dumps(credentials), - is_valid=True - ) - db.session.add(provider_model_record) - db.session.commit() - - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL - ) - - provider_model_credentials_cache.delete() - - def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None: - """ - Delete custom model credentials. - :param model_type: model type - :param model: model name - :return: - """ - # get provider model - provider_model_record = db.session.query(ProviderModel) \ - .filter( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name == self.provider.provider, - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type() - ).first() - - # delete provider model - if provider_model_record: - db.session.delete(provider_model_record) - db.session.commit() - - provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL - ) - - provider_model_credentials_cache.delete() - def get_provider_instance(self) -> ModelProvider: """ Get provider instance. @@ -422,72 +113,6 @@ class ProviderConfiguration(BaseModel): # Get model instance of LLM return provider_instance.get_model_instance(model_type) - def switch_preferred_provider_type(self, provider_type: ProviderType) -> None: - """ - Switch preferred provider type. - :param provider_type: - :return: - """ - if provider_type == self.preferred_provider_type: - return - - if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled: - return - - # get preferred provider - preferred_model_provider = db.session.query(TenantPreferredModelProvider) \ - .filter( - TenantPreferredModelProvider.tenant_id == self.tenant_id, - TenantPreferredModelProvider.provider_name == self.provider.provider - ).first() - - if preferred_model_provider: - preferred_model_provider.preferred_provider_type = provider_type.value - else: - preferred_model_provider = TenantPreferredModelProvider( - tenant_id=self.tenant_id, - provider_name=self.provider.provider, - preferred_provider_type=provider_type.value - ) - db.session.add(preferred_model_provider) - - db.session.commit() - - def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: - """ - Extract secret input form variables. - - :param credential_form_schemas: - :return: - """ - secret_input_form_variables = [] - for credential_form_schema in credential_form_schemas: - if credential_form_schema.type == FormType.SECRET_INPUT: - secret_input_form_variables.append(credential_form_schema.variable) - - return secret_input_form_variables - - def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict: - """ - Obfuscated credentials. - - :param credentials: credentials - :param credential_form_schemas: credential form schemas - :return: - """ - # Get provider credential secret variables - credential_secret_variables = self._extract_secret_variables( - credential_form_schemas - ) - - # Obfuscate provider credentials - copy_credentials = credentials.copy() - for key, value in copy_credentials.items(): - if key in credential_secret_variables: - copy_credentials[key] = encrypter.obfuscated_token(value) - - return copy_credentials - def get_provider_model(self, model_type: ModelType, model: str, only_active: bool = False) -> Optional[ModelWithProviderEntity]: @@ -522,118 +147,16 @@ class ProviderConfiguration(BaseModel): else: model_types = provider_instance.get_provider_schema().supported_model_types - if self.using_provider_type == ProviderType.SYSTEM: - provider_models = self._get_system_provider_models( - model_types=model_types, - provider_instance=provider_instance - ) - else: - provider_models = self._get_custom_provider_models( - model_types=model_types, - provider_instance=provider_instance - ) - + provider_models = self._get_custom_provider_models( + model_types=model_types, + provider_instance=provider_instance + ) if only_active: provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE] # resort provider_models return sorted(provider_models, key=lambda x: x.model_type.value) - def _get_system_provider_models(self, - model_types: list[ModelType], - provider_instance: ModelProvider) -> list[ModelWithProviderEntity]: - """ - Get system provider models. - - :param model_types: model types - :param provider_instance: provider instance - :return: - """ - provider_models = [] - for model_type in model_types: - provider_models.extend( - [ - ModelWithProviderEntity( - model=m.model, - label=m.label, - model_type=m.model_type, - features=m.features, - fetch_from=m.fetch_from, - model_properties=m.model_properties, - deprecated=m.deprecated, - provider=SimpleModelProviderEntity(self.provider), - status=ModelStatus.ACTIVE - ) - for m in provider_instance.models(model_type) - ] - ) - - if self.provider.provider not in original_provider_configurate_methods: - original_provider_configurate_methods[self.provider.provider] = [] - for configurate_method in provider_instance.get_provider_schema().configurate_methods: - original_provider_configurate_methods[self.provider.provider].append(configurate_method) - - should_use_custom_model = False - if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: - should_use_custom_model = True - - for quota_configuration in self.system_configuration.quota_configurations: - if self.system_configuration.current_quota_type != quota_configuration.quota_type: - continue - - restrict_models = quota_configuration.restrict_models - if len(restrict_models) == 0: - break - - if should_use_custom_model: - if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: - # only customizable model - for restrict_model in restrict_models: - copy_credentials = self.system_configuration.credentials.copy() - if restrict_model.base_model_name: - copy_credentials['base_model_name'] = restrict_model.base_model_name - - try: - custom_model_schema = ( - provider_instance.get_model_instance(restrict_model.model_type) - .get_customizable_model_schema_from_credentials( - restrict_model.model, - copy_credentials - ) - ) - except Exception as ex: - logger.warning(f'get custom model schema failed, {ex}') - continue - - if not custom_model_schema: - continue - - if custom_model_schema.model_type not in model_types: - continue - - provider_models.append( - ModelWithProviderEntity( - model=custom_model_schema.model, - label=custom_model_schema.label, - model_type=custom_model_schema.model_type, - features=custom_model_schema.features, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties=custom_model_schema.model_properties, - deprecated=custom_model_schema.deprecated, - provider=SimpleModelProviderEntity(self.provider), - status=ModelStatus.ACTIVE - ) - ) - - # if llm name not in restricted llm list, remove it - restrict_model_names = [rm.model for rm in restrict_models] - for m in provider_models: - if m.model_type == ModelType.LLM and m.model not in restrict_model_names: - m.status = ModelStatus.NO_PERMISSION - elif not quota_configuration.is_valid: - m.status = ModelStatus.QUOTA_EXCEEDED - return provider_models - def _get_custom_provider_models(self, model_types: list[ModelType], provider_instance: ModelProvider) -> list[ModelWithProviderEntity]: @@ -711,11 +234,10 @@ class ProviderConfigurations(BaseModel): """ Model class for provider configuration dict. """ - tenant_id: str configurations: dict[str, ProviderConfiguration] = {} - def __init__(self, tenant_id: str): - super().__init__(tenant_id=tenant_id) + def __init__(self): + super().__init__() def get_models(self, provider: Optional[str] = None, diff --git a/model-providers/chatchat_model_providers/core/entities/provider_entities.py b/model-providers/chatchat_model_providers/core/entities/provider_entities.py index fb39255a..715cf899 100644 --- a/model-providers/chatchat_model_providers/core/entities/provider_entities.py +++ b/model-providers/chatchat_model_providers/core/entities/provider_entities.py @@ -3,23 +3,7 @@ from typing import Optional from pydantic import BaseModel -from chatchat_model_providers.core.model_runtime.entities.model_entities import ModelType -from models.provider import ProviderQuotaType - - -class QuotaUnit(Enum): - TIMES = 'times' - TOKENS = 'tokens' - CREDITS = 'credits' - - -class SystemConfigurationStatus(Enum): - """ - Enum class for system configuration status. - """ - ACTIVE = 'active' - QUOTA_EXCEEDED = 'quota-exceeded' - UNSUPPORTED = 'unsupported' +from model_providers.core.model_runtime.entities.model_entities import ModelType class RestrictModel(BaseModel): @@ -28,27 +12,6 @@ class RestrictModel(BaseModel): model_type: ModelType -class QuotaConfiguration(BaseModel): - """ - Model class for provider quota configuration. - """ - quota_type: ProviderQuotaType - quota_unit: QuotaUnit - quota_limit: int - quota_used: int - is_valid: bool - restrict_models: list[RestrictModel] = [] - - -class SystemConfiguration(BaseModel): - """ - Model class for provider system configuration. - """ - enabled: bool - current_quota_type: Optional[ProviderQuotaType] = None - quota_configurations: list[QuotaConfiguration] = [] - credentials: Optional[dict] = None - class CustomProviderConfiguration(BaseModel): """ diff --git a/model-providers/chatchat_model_providers/extensions/provider_configuration.py b/model-providers/chatchat_model_providers/extensions/provider_configuration.py deleted file mode 100644 index 0b05635c..00000000 --- a/model-providers/chatchat_model_providers/extensions/provider_configuration.py +++ /dev/null @@ -1,320 +0,0 @@ -import datetime -import json -import logging -from collections.abc import Iterator -from json import JSONDecodeError -from typing import Optional - -from pydantic import BaseModel - -from model_providers.core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity -from model_providers.core.entities.provider_entities import CustomConfiguration -from model_providers.core.model_runtime.entities.model_entities import FetchFrom, ModelType -from model_providers.core.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from model_providers.core.model_runtime.model_providers import model_provider_factory -from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider - -logger = logging.getLogger(__name__) - - -class ProviderConfiguration(BaseModel): - """ - Model class for provider configuration. - """ - provider: ProviderEntity - custom_configuration: CustomConfiguration - - def __init__(self, **data): - super().__init__(**data) - - def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: - """ - Get current credentials. - - :param model_type: model type - :param model: model name - :return: - """ - if self.custom_configuration.models: - for model_configuration in self.custom_configuration.models: - if model_configuration.model_type == model_type and model_configuration.model == model: - return model_configuration.credentials - - if self.custom_configuration.provider: - return self.custom_configuration.provider.credentials - else: - return None - - def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: - """ - Get custom credentials. - - :param obfuscated: obfuscated secret data in credentials - :return: - """ - if self.custom_configuration.provider is None: - return None - - credentials = self.custom_configuration.provider.credentials - if not obfuscated: - return credentials - - # Obfuscate provider credentials - copy_credentials = credentials.copy() - return copy_credentials - - def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \ - -> Optional[dict]: - """ - Get custom model credentials. - - :param model_type: model type - :param model: model name - :param obfuscated: obfuscated secret data in credentials - :return: - """ - if not self.custom_configuration.models: - return None - - for model_configuration in self.custom_configuration.models: - if model_configuration.model_type == model_type and model_configuration.model == model: - credentials = model_configuration.credentials - if not obfuscated: - return credentials - copy_credentials = credentials.copy() - # Obfuscate credentials - return copy_credentials - - return None - - def get_provider_instance(self) -> ModelProvider: - """ - Get provider instance. - :return: - """ - return model_provider_factory.get_provider_instance(self.provider.provider) - - def get_model_type_instance(self, model_type: ModelType) -> AIModel: - """ - Get current model type instance. - - :param model_type: model type - :return: - """ - # Get provider instance - provider_instance = self.get_provider_instance() - - # Get model instance of LLM - return provider_instance.get_model_instance(model_type) - - def get_provider_model(self, model_type: ModelType, - model: str, - only_active: bool = False) -> Optional[ModelWithProviderEntity]: - """ - Get provider model. - :param model_type: model type - :param model: model name - :param only_active: return active model only - :return: - """ - provider_models = self.get_provider_models(model_type, only_active) - - for provider_model in provider_models: - if provider_model.model == model: - return provider_model - - return None - - def get_provider_models(self, model_type: Optional[ModelType] = None, - only_active: bool = False) -> list[ModelWithProviderEntity]: - """ - Get provider models. - :param model_type: model type - :param only_active: only active models - :return: - """ - provider_instance = self.get_provider_instance() - - model_types = [] - if model_type: - model_types.append(model_type) - else: - model_types = provider_instance.get_provider_schema().supported_model_types - - provider_models = self._get_custom_provider_models( - model_types=model_types, - provider_instance=provider_instance - ) - if only_active: - provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE] - - # resort provider_models - return sorted(provider_models, key=lambda x: x.model_type.value) - - def _get_custom_provider_models(self, - model_types: list[ModelType], - provider_instance: ModelProvider) -> list[ModelWithProviderEntity]: - """ - Get custom provider models. - - :param model_types: model types - :param provider_instance: provider instance - :return: - """ - provider_models = [] - - credentials = None - if self.custom_configuration.provider: - credentials = self.custom_configuration.provider.credentials - - for model_type in model_types: - if model_type not in self.provider.supported_model_types: - continue - - models = provider_instance.models(model_type) - for m in models: - provider_models.append( - ModelWithProviderEntity( - model=m.model, - label=m.label, - model_type=m.model_type, - features=m.features, - fetch_from=m.fetch_from, - model_properties=m.model_properties, - deprecated=m.deprecated, - provider=SimpleModelProviderEntity(self.provider), - status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE - ) - ) - - # custom models - for model_configuration in self.custom_configuration.models: - if model_configuration.model_type not in model_types: - continue - - try: - custom_model_schema = ( - provider_instance.get_model_instance(model_configuration.model_type) - .get_customizable_model_schema_from_credentials( - model_configuration.model, - model_configuration.credentials - ) - ) - except Exception as ex: - logger.warning(f'get custom model schema failed, {ex}') - continue - - if not custom_model_schema: - continue - - provider_models.append( - ModelWithProviderEntity( - model=custom_model_schema.model, - label=custom_model_schema.label, - model_type=custom_model_schema.model_type, - features=custom_model_schema.features, - fetch_from=custom_model_schema.fetch_from, - model_properties=custom_model_schema.model_properties, - deprecated=custom_model_schema.deprecated, - provider=SimpleModelProviderEntity(self.provider), - status=ModelStatus.ACTIVE - ) - ) - - return provider_models - - -class ProviderConfigurations(BaseModel): - """ - Model class for provider configuration dict. - """ - configurations: dict[str, ProviderConfiguration] = {} - - def __init__(self): - super().__init__() - - def get_models(self, - provider: Optional[str] = None, - model_type: Optional[ModelType] = None, - only_active: bool = False) \ - -> list[ModelWithProviderEntity]: - """ - Get available models. - - If preferred provider type is `system`: - Get the current **system mode** if provider supported, - if all system modes are not available (no quota), it is considered to be the **custom credential mode**. - If there is no model configured in custom mode, it is treated as no_configure. - system > custom > no_configure - - If preferred provider type is `custom`: - If custom credentials are configured, it is treated as custom mode. - Otherwise, get the current **system mode** if supported, - If all system modes are not available (no quota), it is treated as no_configure. - custom > system > no_configure - - If real mode is `system`, use system credentials to get models, - paid quotas > provider free quotas > system free quotas - include pre-defined models (exclude GPT-4, status marked as `no_permission`). - If real mode is `custom`, use workspace custom credentials to get models, - include pre-defined models, custom models(manual append). - If real mode is `no_configure`, only return pre-defined models from `model runtime`. - (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`) - model status marked as `active` is available. - - :param provider: provider name - :param model_type: model type - :param only_active: only active models - :return: - """ - all_models = [] - for provider_configuration in self.values(): - if provider and provider_configuration.provider.provider != provider: - continue - - all_models.extend(provider_configuration.get_provider_models(model_type, only_active)) - - return all_models - - def to_list(self) -> list[ProviderConfiguration]: - """ - Convert to list. - - :return: - """ - return list(self.values()) - - def __getitem__(self, key): - return self.configurations[key] - - def __setitem__(self, key, value): - self.configurations[key] = value - - def __iter__(self): - return iter(self.configurations) - - def values(self) -> Iterator[ProviderConfiguration]: - return self.configurations.values() - - def get(self, key, default=None): - return self.configurations.get(key, default) - - -class ProviderModelBundle(BaseModel): - """ - Provider model bundle. - """ - configuration: ProviderConfiguration - provider_instance: ModelProvider - model_type_instance: AIModel - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True diff --git a/model-providers/chatchat_model_providers/extensions/provider_entities.py b/model-providers/chatchat_model_providers/extensions/provider_entities.py deleted file mode 100644 index 715cf899..00000000 --- a/model-providers/chatchat_model_providers/extensions/provider_entities.py +++ /dev/null @@ -1,37 +0,0 @@ -from enum import Enum -from typing import Optional - -from pydantic import BaseModel - -from model_providers.core.model_runtime.entities.model_entities import ModelType - - -class RestrictModel(BaseModel): - model: str - base_model_name: Optional[str] = None - model_type: ModelType - - - -class CustomProviderConfiguration(BaseModel): - """ - Model class for provider custom configuration. - """ - credentials: dict - - -class CustomModelConfiguration(BaseModel): - """ - Model class for provider custom model configuration. - """ - model: str - model_type: ModelType - credentials: dict - - -class CustomConfiguration(BaseModel): - """ - Model class for provider custom configuration. - """ - provider: Optional[CustomProviderConfiguration] = None - models: list[CustomModelConfiguration] = [] From 5b9028684362317378b5d798a526d55e05d7a6c7 Mon Sep 17 00:00:00 2001 From: khazic Date: Tue, 26 Mar 2024 16:00:40 +0800 Subject: [PATCH 3/3] qqqq --- .../chatchat_model_providers/__main__.py | 2 +- .../core/entities/provider_configuration.py | 14 ++++----- .../core/entities/provider_entities.py | 2 +- .../core/model_manager.py | 30 +++++++++---------- .../core/provider_manager.py | 12 ++++---- 5 files changed, 30 insertions(+), 30 deletions(-) diff --git a/model-providers/chatchat_model_providers/__main__.py b/model-providers/chatchat_model_providers/__main__.py index 318b7e5f..a4ad57e7 100644 --- a/model-providers/chatchat_model_providers/__main__.py +++ b/model-providers/chatchat_model_providers/__main__.py @@ -11,7 +11,7 @@ if __name__ == '__main__': provider_manager = ModelManager( provider_name_to_provider_records_dict={ 'openai': { - 'openai_api_key': "sk- ", + 'openai_api_key': "sk-4M9LYF", } }, provider_name_to_provider_model_records_dict={} diff --git a/model-providers/chatchat_model_providers/core/entities/provider_configuration.py b/model-providers/chatchat_model_providers/core/entities/provider_configuration.py index 0b05635c..ad8326bf 100644 --- a/model-providers/chatchat_model_providers/core/entities/provider_configuration.py +++ b/model-providers/chatchat_model_providers/core/entities/provider_configuration.py @@ -7,18 +7,18 @@ from typing import Optional from pydantic import BaseModel -from model_providers.core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity -from model_providers.core.entities.provider_entities import CustomConfiguration -from model_providers.core.model_runtime.entities.model_entities import FetchFrom, ModelType -from model_providers.core.model_runtime.entities.provider_entities import ( +from chatchat_model_providers.core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity +from chatchat_model_providers.core.entities.provider_entities import CustomConfiguration +from chatchat_model_providers.core.model_runtime.entities.model_entities import FetchFrom, ModelType +from chatchat_model_providers.core.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, FormType, ProviderEntity, ) -from model_providers.core.model_runtime.model_providers import model_provider_factory -from model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel -from model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider +from chatchat_model_providers.core.model_runtime.model_providers import model_provider_factory +from chatchat_model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel +from chatchat_model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider logger = logging.getLogger(__name__) diff --git a/model-providers/chatchat_model_providers/core/entities/provider_entities.py b/model-providers/chatchat_model_providers/core/entities/provider_entities.py index 715cf899..4b2dabb1 100644 --- a/model-providers/chatchat_model_providers/core/entities/provider_entities.py +++ b/model-providers/chatchat_model_providers/core/entities/provider_entities.py @@ -3,7 +3,7 @@ from typing import Optional from pydantic import BaseModel -from model_providers.core.model_runtime.entities.model_entities import ModelType +from chatchat_model_providers.core.model_runtime.entities.model_entities import ModelType class RestrictModel(BaseModel): diff --git a/model-providers/chatchat_model_providers/core/model_manager.py b/model-providers/chatchat_model_providers/core/model_manager.py index f7e9be10..d61d2916 100644 --- a/model-providers/chatchat_model_providers/core/model_manager.py +++ b/model-providers/chatchat_model_providers/core/model_manager.py @@ -1,21 +1,21 @@ from collections.abc import Generator from typing import IO, Optional, Union, cast -from model_providers.core.entities.provider_configuration import ProviderModelBundle -from model_providers.core.errors.error import ProviderTokenNotInitError -from model_providers.core.model_runtime.callbacks.base_callback import Callback -from model_providers.core.model_runtime.entities.llm_entities import LLMResult -from model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.entities.rerank_entities import RerankResult -from model_providers.core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult -from model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from model_providers.core.model_runtime.model_providers.__base.moderation_model import ModerationModel -from model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel -from model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel -from model_providers.core.provider_manager import ProviderManager +from chatchat_model_providers.core.entities.provider_configuration import ProviderModelBundle +from chatchat_model_providers.errors.error import ProviderTokenNotInitError +from chatchat_model_providers.core.model_runtime.callbacks.base_callback import Callback +from chatchat_model_providers.core.model_runtime.entities.llm_entities import LLMResult +from chatchat_model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from chatchat_model_providers.core.model_runtime.entities.model_entities import ModelType +from chatchat_model_providers.core.model_runtime.entities.rerank_entities import RerankResult +from chatchat_model_providers.core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from chatchat_model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from chatchat_model_providers.core.model_runtime.model_providers.__base.moderation_model import ModerationModel +from chatchat_model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel +from chatchat_model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from chatchat_model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from chatchat_model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel +from chatchat_model_providers.core.provider_manager import ProviderManager def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict: diff --git a/model-providers/chatchat_model_providers/core/provider_manager.py b/model-providers/chatchat_model_providers/core/provider_manager.py index a2016963..d18e9950 100644 --- a/model-providers/chatchat_model_providers/core/provider_manager.py +++ b/model-providers/chatchat_model_providers/core/provider_manager.py @@ -5,21 +5,21 @@ from typing import Optional from sqlalchemy.exc import IntegrityError -from model_providers.core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity -from model_providers.core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, \ +from chatchat_model_providers.core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity +from chatchat_model_providers.core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, \ ProviderModelBundle -from model_providers.core.entities.provider_entities import ( +from chatchat_model_providers.core.entities.provider_entities import ( CustomConfiguration, CustomModelConfiguration, CustomProviderConfiguration, ) -from model_providers.core.model_runtime.entities.model_entities import ModelType -from model_providers.core.model_runtime.entities.provider_entities import ( +from chatchat_model_providers.core.model_runtime.entities.model_entities import ModelType +from chatchat_model_providers.core.model_runtime.entities.provider_entities import ( CredentialFormSchema, FormType, ProviderEntity, ) -from model_providers.core.model_runtime.model_providers import model_provider_factory +from chatchat_model_providers.core.model_runtime.model_providers import model_provider_factory class ProviderManager: