mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-28 01:33:17 +08:00
commit
1e96d69945
@ -1,37 +1,28 @@
|
||||
import os
|
||||
from typing import cast, Generator
|
||||
|
||||
from chatchat_model_providers.core.model_manager import ModelManager
|
||||
from chatchat_model_providers.core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
|
||||
from chatchat_model_providers.core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
|
||||
from chatchat_model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 基于配置管理器创建的模型实例
|
||||
# provider_manager = ProviderManager()
|
||||
|
||||
provider_configurations = ProviderConfigurations(
|
||||
tenant_id=tenant_id
|
||||
provider_manager = ModelManager(
|
||||
provider_name_to_provider_records_dict={
|
||||
'openai': {
|
||||
'openai_api_key': "sk-4M9LYF",
|
||||
}
|
||||
},
|
||||
provider_name_to_provider_model_records_dict={}
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# model_instance = ModelInstance(
|
||||
# provider_model_bundle=provider_model_bundle,
|
||||
# model=model_config.model,
|
||||
# )
|
||||
# 直接通过模型加载器创建的模型实例
|
||||
from chatchat_model_providers.core.model_runtime.model_providers import model_provider_factory
|
||||
model_provider_factory.get_providers(provider_name='openai')
|
||||
provider_instance = model_provider_factory.get_provider_instance('openai')
|
||||
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
|
||||
print(model_type_instance)
|
||||
response = model_type_instance.invoke(
|
||||
model='gpt-4',
|
||||
credentials={
|
||||
'openai_api_key': "sk-",
|
||||
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
|
||||
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
|
||||
},
|
||||
# Invoke model
|
||||
model_instance = provider_manager.get_model_instance(provider='openai', model_type=ModelType.LLM, model='gpt-4')
|
||||
|
||||
response = model_instance.invoke_llm(
|
||||
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='北京今天的天气怎么样'
|
||||
|
||||
@ -8,9 +8,7 @@ from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from chatchat_model_providers.core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||
from chatchat_model_providers.core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus
|
||||
from chatchat_model_providers.core.helper import encrypter
|
||||
from chatchat_model_providers.core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from chatchat_model_providers.core.entities.provider_entities import CustomConfiguration
|
||||
from chatchat_model_providers.core.model_runtime.entities.model_entities import FetchFrom, ModelType
|
||||
from chatchat_model_providers.core.model_runtime.entities.provider_entities import (
|
||||
ConfigurateMethod,
|
||||
@ -21,39 +19,20 @@ from chatchat_model_providers.core.model_runtime.entities.provider_entities impo
|
||||
from chatchat_model_providers.core.model_runtime.model_providers import model_provider_factory
|
||||
from chatchat_model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from chatchat_model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
from chatchat_model_providers.extensions.ext_database import db
|
||||
from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
original_provider_configurate_methods = {}
|
||||
|
||||
|
||||
class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider configuration.
|
||||
"""
|
||||
tenant_id: str
|
||||
provider: ProviderEntity
|
||||
preferred_provider_type: ProviderType
|
||||
using_provider_type: ProviderType
|
||||
system_configuration: SystemConfiguration
|
||||
custom_configuration: CustomConfiguration
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
if self.provider.provider not in original_provider_configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider] = []
|
||||
for configurate_method in self.provider.configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
|
||||
|
||||
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
||||
if (any([len(quota_configuration.restrict_models) > 0
|
||||
for quota_configuration in self.system_configuration.quota_configurations])
|
||||
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
|
||||
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
|
||||
|
||||
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
|
||||
"""
|
||||
Get current credentials.
|
||||
@ -62,58 +41,15 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
if self.using_provider_type == ProviderType.SYSTEM:
|
||||
restrict_models = []
|
||||
for quota_configuration in self.system_configuration.quota_configurations:
|
||||
if self.system_configuration.current_quota_type != quota_configuration.quota_type:
|
||||
continue
|
||||
if self.custom_configuration.models:
|
||||
for model_configuration in self.custom_configuration.models:
|
||||
if model_configuration.model_type == model_type and model_configuration.model == model:
|
||||
return model_configuration.credentials
|
||||
|
||||
restrict_models = quota_configuration.restrict_models
|
||||
|
||||
copy_credentials = self.system_configuration.credentials.copy()
|
||||
if restrict_models:
|
||||
for restrict_model in restrict_models:
|
||||
if (restrict_model.model_type == model_type
|
||||
and restrict_model.model == model
|
||||
and restrict_model.base_model_name):
|
||||
copy_credentials['base_model_name'] = restrict_model.base_model_name
|
||||
|
||||
return copy_credentials
|
||||
if self.custom_configuration.provider:
|
||||
return self.custom_configuration.provider.credentials
|
||||
else:
|
||||
if self.custom_configuration.models:
|
||||
for model_configuration in self.custom_configuration.models:
|
||||
if model_configuration.model_type == model_type and model_configuration.model == model:
|
||||
return model_configuration.credentials
|
||||
|
||||
if self.custom_configuration.provider:
|
||||
return self.custom_configuration.provider.credentials
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_system_configuration_status(self) -> SystemConfigurationStatus:
|
||||
"""
|
||||
Get system configuration status.
|
||||
:return:
|
||||
"""
|
||||
if self.system_configuration.enabled is False:
|
||||
return SystemConfigurationStatus.UNSUPPORTED
|
||||
|
||||
current_quota_type = self.system_configuration.current_quota_type
|
||||
current_quota_configuration = next(
|
||||
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
|
||||
None
|
||||
)
|
||||
|
||||
return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
|
||||
SystemConfigurationStatus.QUOTA_EXCEEDED
|
||||
|
||||
def is_custom_configuration_available(self) -> bool:
|
||||
"""
|
||||
Check custom configuration available.
|
||||
:return:
|
||||
"""
|
||||
return (self.custom_configuration.provider is not None
|
||||
or len(self.custom_configuration.models) > 0)
|
||||
return None
|
||||
|
||||
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
|
||||
"""
|
||||
@ -129,130 +65,9 @@ class ProviderConfiguration(BaseModel):
|
||||
if not obfuscated:
|
||||
return credentials
|
||||
|
||||
# Obfuscate credentials
|
||||
return self._obfuscated_credentials(
|
||||
credentials=credentials,
|
||||
credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
|
||||
if self.provider.provider_credential_schema else []
|
||||
)
|
||||
|
||||
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param credentials: provider credentials
|
||||
:return:
|
||||
"""
|
||||
# get provider
|
||||
provider_record = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).first()
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self._extract_secret_variables(
|
||||
self.provider.provider_credential_schema.credential_form_schemas
|
||||
if self.provider.provider_credential_schema else []
|
||||
)
|
||||
|
||||
if provider_record:
|
||||
try:
|
||||
# fix origin data
|
||||
if provider_record.encrypted_config:
|
||||
if not provider_record.encrypted_config.startswith("{"):
|
||||
original_credentials = {
|
||||
"openai_api_key": provider_record.encrypted_config
|
||||
}
|
||||
else:
|
||||
original_credentials = json.loads(provider_record.encrypted_config)
|
||||
else:
|
||||
original_credentials = {}
|
||||
except JSONDecodeError:
|
||||
original_credentials = {}
|
||||
|
||||
# encrypt credentials
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
if value == '[__HIDDEN__]' and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
||||
|
||||
credentials = model_provider_factory.provider_credentials_validate(
|
||||
self.provider.provider,
|
||||
credentials
|
||||
)
|
||||
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return provider_record, credentials
|
||||
|
||||
def add_or_update_custom_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Add or update custom provider credentials.
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
# validate custom provider config
|
||||
provider_record, credentials = self.custom_credentials_validate(credentials)
|
||||
|
||||
# save provider
|
||||
# Note: Do not switch the preferred provider, which allows users to use quotas first
|
||||
if provider_record:
|
||||
provider_record.encrypted_config = json.dumps(credentials)
|
||||
provider_record.is_valid = True
|
||||
provider_record.updated_at = datetime.datetime.utcnow()
|
||||
db.session.commit()
|
||||
else:
|
||||
provider_record = Provider(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(credentials),
|
||||
is_valid=True
|
||||
)
|
||||
db.session.add(provider_record)
|
||||
db.session.commit()
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER
|
||||
)
|
||||
|
||||
provider_model_credentials_cache.delete()
|
||||
|
||||
self.switch_preferred_provider_type(ProviderType.CUSTOM)
|
||||
|
||||
def delete_custom_credentials(self) -> None:
|
||||
"""
|
||||
Delete custom provider credentials.
|
||||
:return:
|
||||
"""
|
||||
# get provider
|
||||
provider_record = db.session.query(Provider) \
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value
|
||||
).first()
|
||||
|
||||
# delete provider
|
||||
if provider_record:
|
||||
self.switch_preferred_provider_type(ProviderType.SYSTEM)
|
||||
|
||||
db.session.delete(provider_record)
|
||||
db.session.commit()
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER
|
||||
)
|
||||
|
||||
provider_model_credentials_cache.delete()
|
||||
# Obfuscate provider credentials
|
||||
copy_credentials = credentials.copy()
|
||||
return copy_credentials
|
||||
|
||||
def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
|
||||
-> Optional[dict]:
|
||||
@ -272,136 +87,12 @@ class ProviderConfiguration(BaseModel):
|
||||
credentials = model_configuration.credentials
|
||||
if not obfuscated:
|
||||
return credentials
|
||||
|
||||
copy_credentials = credentials.copy()
|
||||
# Obfuscate credentials
|
||||
return self._obfuscated_credentials(
|
||||
credentials=credentials,
|
||||
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
|
||||
if self.provider.model_credential_schema else []
|
||||
)
|
||||
return copy_credentials
|
||||
|
||||
return None
|
||||
|
||||
def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
|
||||
-> tuple[ProviderModel, dict]:
|
||||
"""
|
||||
Validate custom model credentials.
|
||||
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
# get provider model
|
||||
provider_model_record = db.session.query(ProviderModel) \
|
||||
.filter(
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name == self.provider.provider,
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type.to_origin_model_type()
|
||||
).first()
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self._extract_secret_variables(
|
||||
self.provider.model_credential_schema.credential_form_schemas
|
||||
if self.provider.model_credential_schema else []
|
||||
)
|
||||
|
||||
if provider_model_record:
|
||||
try:
|
||||
original_credentials = json.loads(
|
||||
provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
|
||||
except JSONDecodeError:
|
||||
original_credentials = {}
|
||||
|
||||
# decrypt credentials
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
if value == '[__HIDDEN__]' and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
|
||||
|
||||
credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=self.provider.provider,
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials
|
||||
)
|
||||
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return provider_model_record, credentials
|
||||
|
||||
def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Add or update custom model credentials.
|
||||
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
# validate custom model config
|
||||
provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
|
||||
|
||||
# save provider model
|
||||
# Note: Do not switch the preferred provider, which allows users to use quotas first
|
||||
if provider_model_record:
|
||||
provider_model_record.encrypted_config = json.dumps(credentials)
|
||||
provider_model_record.is_valid = True
|
||||
provider_model_record.updated_at = datetime.datetime.utcnow()
|
||||
db.session.commit()
|
||||
else:
|
||||
provider_model_record = ProviderModel(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
encrypted_config=json.dumps(credentials),
|
||||
is_valid=True
|
||||
)
|
||||
db.session.add(provider_model_record)
|
||||
db.session.commit()
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_model_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.MODEL
|
||||
)
|
||||
|
||||
provider_model_credentials_cache.delete()
|
||||
|
||||
def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
|
||||
"""
|
||||
Delete custom model credentials.
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
# get provider model
|
||||
provider_model_record = db.session.query(ProviderModel) \
|
||||
.filter(
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name == self.provider.provider,
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type.to_origin_model_type()
|
||||
).first()
|
||||
|
||||
# delete provider model
|
||||
if provider_model_record:
|
||||
db.session.delete(provider_model_record)
|
||||
db.session.commit()
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_model_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.MODEL
|
||||
)
|
||||
|
||||
provider_model_credentials_cache.delete()
|
||||
|
||||
def get_provider_instance(self) -> ModelProvider:
|
||||
"""
|
||||
Get provider instance.
|
||||
@ -422,72 +113,6 @@ class ProviderConfiguration(BaseModel):
|
||||
# Get model instance of LLM
|
||||
return provider_instance.get_model_instance(model_type)
|
||||
|
||||
def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
|
||||
"""
|
||||
Switch preferred provider type.
|
||||
:param provider_type:
|
||||
:return:
|
||||
"""
|
||||
if provider_type == self.preferred_provider_type:
|
||||
return
|
||||
|
||||
if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
|
||||
return
|
||||
|
||||
# get preferred provider
|
||||
preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
|
||||
.filter(
|
||||
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
||||
TenantPreferredModelProvider.provider_name == self.provider.provider
|
||||
).first()
|
||||
|
||||
if preferred_model_provider:
|
||||
preferred_model_provider.preferred_provider_type = provider_type.value
|
||||
else:
|
||||
preferred_model_provider = TenantPreferredModelProvider(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
preferred_provider_type=provider_type.value
|
||||
)
|
||||
db.session.add(preferred_model_provider)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
|
||||
"""
|
||||
Extract secret input form variables.
|
||||
|
||||
:param credential_form_schemas:
|
||||
:return:
|
||||
"""
|
||||
secret_input_form_variables = []
|
||||
for credential_form_schema in credential_form_schemas:
|
||||
if credential_form_schema.type == FormType.SECRET_INPUT:
|
||||
secret_input_form_variables.append(credential_form_schema.variable)
|
||||
|
||||
return secret_input_form_variables
|
||||
|
||||
def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
|
||||
"""
|
||||
Obfuscated credentials.
|
||||
|
||||
:param credentials: credentials
|
||||
:param credential_form_schemas: credential form schemas
|
||||
:return:
|
||||
"""
|
||||
# Get provider credential secret variables
|
||||
credential_secret_variables = self._extract_secret_variables(
|
||||
credential_form_schemas
|
||||
)
|
||||
|
||||
# Obfuscate provider credentials
|
||||
copy_credentials = credentials.copy()
|
||||
for key, value in copy_credentials.items():
|
||||
if key in credential_secret_variables:
|
||||
copy_credentials[key] = encrypter.obfuscated_token(value)
|
||||
|
||||
return copy_credentials
|
||||
|
||||
def get_provider_model(self, model_type: ModelType,
|
||||
model: str,
|
||||
only_active: bool = False) -> Optional[ModelWithProviderEntity]:
|
||||
@ -522,118 +147,16 @@ class ProviderConfiguration(BaseModel):
|
||||
else:
|
||||
model_types = provider_instance.get_provider_schema().supported_model_types
|
||||
|
||||
if self.using_provider_type == ProviderType.SYSTEM:
|
||||
provider_models = self._get_system_provider_models(
|
||||
model_types=model_types,
|
||||
provider_instance=provider_instance
|
||||
)
|
||||
else:
|
||||
provider_models = self._get_custom_provider_models(
|
||||
model_types=model_types,
|
||||
provider_instance=provider_instance
|
||||
)
|
||||
|
||||
provider_models = self._get_custom_provider_models(
|
||||
model_types=model_types,
|
||||
provider_instance=provider_instance
|
||||
)
|
||||
if only_active:
|
||||
provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
|
||||
|
||||
# resort provider_models
|
||||
return sorted(provider_models, key=lambda x: x.model_type.value)
|
||||
|
||||
def _get_system_provider_models(self,
|
||||
model_types: list[ModelType],
|
||||
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get system provider models.
|
||||
|
||||
:param model_types: model types
|
||||
:param provider_instance: provider instance
|
||||
:return:
|
||||
"""
|
||||
provider_models = []
|
||||
for model_type in model_types:
|
||||
provider_models.extend(
|
||||
[
|
||||
ModelWithProviderEntity(
|
||||
model=m.model,
|
||||
label=m.label,
|
||||
model_type=m.model_type,
|
||||
features=m.features,
|
||||
fetch_from=m.fetch_from,
|
||||
model_properties=m.model_properties,
|
||||
deprecated=m.deprecated,
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
status=ModelStatus.ACTIVE
|
||||
)
|
||||
for m in provider_instance.models(model_type)
|
||||
]
|
||||
)
|
||||
|
||||
if self.provider.provider not in original_provider_configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider] = []
|
||||
for configurate_method in provider_instance.get_provider_schema().configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
|
||||
|
||||
should_use_custom_model = False
|
||||
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
||||
should_use_custom_model = True
|
||||
|
||||
for quota_configuration in self.system_configuration.quota_configurations:
|
||||
if self.system_configuration.current_quota_type != quota_configuration.quota_type:
|
||||
continue
|
||||
|
||||
restrict_models = quota_configuration.restrict_models
|
||||
if len(restrict_models) == 0:
|
||||
break
|
||||
|
||||
if should_use_custom_model:
|
||||
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
||||
# only customizable model
|
||||
for restrict_model in restrict_models:
|
||||
copy_credentials = self.system_configuration.credentials.copy()
|
||||
if restrict_model.base_model_name:
|
||||
copy_credentials['base_model_name'] = restrict_model.base_model_name
|
||||
|
||||
try:
|
||||
custom_model_schema = (
|
||||
provider_instance.get_model_instance(restrict_model.model_type)
|
||||
.get_customizable_model_schema_from_credentials(
|
||||
restrict_model.model,
|
||||
copy_credentials
|
||||
)
|
||||
)
|
||||
except Exception as ex:
|
||||
logger.warning(f'get custom model schema failed, {ex}')
|
||||
continue
|
||||
|
||||
if not custom_model_schema:
|
||||
continue
|
||||
|
||||
if custom_model_schema.model_type not in model_types:
|
||||
continue
|
||||
|
||||
provider_models.append(
|
||||
ModelWithProviderEntity(
|
||||
model=custom_model_schema.model,
|
||||
label=custom_model_schema.label,
|
||||
model_type=custom_model_schema.model_type,
|
||||
features=custom_model_schema.features,
|
||||
fetch_from=FetchFrom.PREDEFINED_MODEL,
|
||||
model_properties=custom_model_schema.model_properties,
|
||||
deprecated=custom_model_schema.deprecated,
|
||||
provider=SimpleModelProviderEntity(self.provider),
|
||||
status=ModelStatus.ACTIVE
|
||||
)
|
||||
)
|
||||
|
||||
# if llm name not in restricted llm list, remove it
|
||||
restrict_model_names = [rm.model for rm in restrict_models]
|
||||
for m in provider_models:
|
||||
if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
|
||||
m.status = ModelStatus.NO_PERMISSION
|
||||
elif not quota_configuration.is_valid:
|
||||
m.status = ModelStatus.QUOTA_EXCEEDED
|
||||
return provider_models
|
||||
|
||||
def _get_custom_provider_models(self,
|
||||
model_types: list[ModelType],
|
||||
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
|
||||
@ -711,11 +234,10 @@ class ProviderConfigurations(BaseModel):
|
||||
"""
|
||||
Model class for provider configuration dict.
|
||||
"""
|
||||
tenant_id: str
|
||||
configurations: dict[str, ProviderConfiguration] = {}
|
||||
|
||||
def __init__(self, tenant_id: str):
|
||||
super().__init__(tenant_id=tenant_id)
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def get_models(self,
|
||||
provider: Optional[str] = None,
|
||||
|
||||
@ -4,22 +4,6 @@ from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from chatchat_model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from models.provider import ProviderQuotaType
|
||||
|
||||
|
||||
class QuotaUnit(Enum):
|
||||
TIMES = 'times'
|
||||
TOKENS = 'tokens'
|
||||
CREDITS = 'credits'
|
||||
|
||||
|
||||
class SystemConfigurationStatus(Enum):
|
||||
"""
|
||||
Enum class for system configuration status.
|
||||
"""
|
||||
ACTIVE = 'active'
|
||||
QUOTA_EXCEEDED = 'quota-exceeded'
|
||||
UNSUPPORTED = 'unsupported'
|
||||
|
||||
|
||||
class RestrictModel(BaseModel):
|
||||
@ -28,27 +12,6 @@ class RestrictModel(BaseModel):
|
||||
model_type: ModelType
|
||||
|
||||
|
||||
class QuotaConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider quota configuration.
|
||||
"""
|
||||
quota_type: ProviderQuotaType
|
||||
quota_unit: QuotaUnit
|
||||
quota_limit: int
|
||||
quota_used: int
|
||||
is_valid: bool
|
||||
restrict_models: list[RestrictModel] = []
|
||||
|
||||
|
||||
class SystemConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider system configuration.
|
||||
"""
|
||||
enabled: bool
|
||||
current_quota_type: Optional[ProviderQuotaType] = None
|
||||
quota_configurations: list[QuotaConfiguration] = []
|
||||
credentials: Optional[dict] = None
|
||||
|
||||
|
||||
class CustomProviderConfiguration(BaseModel):
|
||||
"""
|
||||
|
||||
257
model-providers/chatchat_model_providers/core/model_manager.py
Normal file
257
model-providers/chatchat_model_providers/core/model_manager.py
Normal file
@ -0,0 +1,257 @@
|
||||
from collections.abc import Generator
|
||||
from typing import IO, Optional, Union, cast
|
||||
|
||||
from chatchat_model_providers.core.entities.provider_configuration import ProviderModelBundle
|
||||
from chatchat_model_providers.errors.error import ProviderTokenNotInitError
|
||||
from chatchat_model_providers.core.model_runtime.callbacks.base_callback import Callback
|
||||
from chatchat_model_providers.core.model_runtime.entities.llm_entities import LLMResult
|
||||
from chatchat_model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from chatchat_model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from chatchat_model_providers.core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from chatchat_model_providers.core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from chatchat_model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from chatchat_model_providers.core.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from chatchat_model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from chatchat_model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from chatchat_model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from chatchat_model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from chatchat_model_providers.core.provider_manager import ProviderManager
|
||||
|
||||
|
||||
def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict:
|
||||
"""
|
||||
Fetch credentials from provider model bundle
|
||||
:param provider_model_bundle: provider model bundle
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=provider_model_bundle.model_type_instance.model_type,
|
||||
model=model
|
||||
)
|
||||
|
||||
if credentials is None:
|
||||
raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")
|
||||
|
||||
return credentials
|
||||
|
||||
|
||||
class ModelInstance:
|
||||
"""
|
||||
Model instance class
|
||||
"""
|
||||
|
||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
|
||||
self._provider_model_bundle = provider_model_bundle
|
||||
self.model = model
|
||||
self.provider = provider_model_bundle.configuration.provider.provider
|
||||
self.credentials = _fetch_credentials_from_bundle(provider_model_bundle, model)
|
||||
self.model_type_instance = self._provider_model_bundle.model_type_instance
|
||||
|
||||
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
||||
raise Exception("Model type instance is not LargeLanguageModel")
|
||||
|
||||
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
|
||||
return self.model_type_instance.invoke(
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
|
||||
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
|
||||
return self.model_type_instance.invoke(
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
user=user
|
||||
)
|
||||
|
||||
def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) \
|
||||
-> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
|
||||
self.model_type_instance = cast(RerankModel, self.model_type_instance)
|
||||
return self.model_type_instance.invoke(
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
user=user
|
||||
)
|
||||
|
||||
def invoke_moderation(self, text: str, user: Optional[str] = None) \
|
||||
-> bool:
|
||||
"""
|
||||
Invoke moderation model
|
||||
|
||||
:param text: text to moderate
|
||||
:param user: unique user id
|
||||
:return: false if text is safe, true otherwise
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, ModerationModel):
|
||||
raise Exception("Model type instance is not ModerationModel")
|
||||
|
||||
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
|
||||
return self.model_type_instance.invoke(
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
user=user
|
||||
)
|
||||
|
||||
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
|
||||
-> str:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param file: audio file
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, Speech2TextModel):
|
||||
raise Exception("Model type instance is not Speech2TextModel")
|
||||
|
||||
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
|
||||
return self.model_type_instance.invoke(
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
user=user
|
||||
)
|
||||
|
||||
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \
|
||||
-> str:
|
||||
"""
|
||||
Invoke large language tts model
|
||||
|
||||
:param content_text: text content to be translated
|
||||
:param tenant_id: user tenant id
|
||||
:param user: unique user id
|
||||
:param voice: model timbre
|
||||
:param streaming: output is streaming
|
||||
:return: text for given audio file
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
raise Exception("Model type instance is not TTSModel")
|
||||
|
||||
self.model_type_instance = cast(TTSModel, self.model_type_instance)
|
||||
return self.model_type_instance.invoke(
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
user=user,
|
||||
tenant_id=tenant_id,
|
||||
voice=voice,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
def get_tts_voices(self, language: str) -> list:
|
||||
"""
|
||||
Invoke large language tts model voices
|
||||
|
||||
:param language: tts language
|
||||
:return: tts model voices
|
||||
"""
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
raise Exception("Model type instance is not TTSModel")
|
||||
|
||||
self.model_type_instance = cast(TTSModel, self.model_type_instance)
|
||||
return self.model_type_instance.get_tts_model_voices(
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
language=language
|
||||
)
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self,
|
||||
provider_name_to_provider_records_dict: dict,
|
||||
provider_name_to_provider_model_records_dict: dict) -> None:
|
||||
self._provider_manager = ProviderManager(
|
||||
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
|
||||
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict)
|
||||
|
||||
def get_model_instance(self, provider: str, model_type: ModelType, model: str) -> ModelInstance:
|
||||
"""
|
||||
Get model instance
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
if not provider:
|
||||
return self.get_default_model_instance(model_type)
|
||||
provider_model_bundle = self._provider_manager.get_provider_model_bundle(
|
||||
provider=provider,
|
||||
model_type=model_type
|
||||
)
|
||||
|
||||
return ModelInstance(provider_model_bundle, model)
|
||||
|
||||
def get_default_model_instance(self, model_type: ModelType) -> ModelInstance:
|
||||
"""
|
||||
Get default model instance
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
default_model_entity = self._provider_manager.get_default_model(
|
||||
model_type=model_type
|
||||
)
|
||||
|
||||
if not default_model_entity:
|
||||
raise ProviderTokenNotInitError(f"Default model not found for {model_type}")
|
||||
|
||||
return self.get_model_instance(
|
||||
provider=default_model_entity.provider.provider,
|
||||
model_type=model_type,
|
||||
model=default_model_entity.model
|
||||
)
|
||||
@ -0,0 +1,256 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from chatchat_model_providers.core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
||||
from chatchat_model_providers.core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, \
|
||||
ProviderModelBundle
|
||||
from chatchat_model_providers.core.entities.provider_entities import (
|
||||
CustomConfiguration,
|
||||
CustomModelConfiguration,
|
||||
CustomProviderConfiguration,
|
||||
)
|
||||
from chatchat_model_providers.core.model_runtime.entities.model_entities import ModelType
|
||||
from chatchat_model_providers.core.model_runtime.entities.provider_entities import (
|
||||
CredentialFormSchema,
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from chatchat_model_providers.core.model_runtime.model_providers import model_provider_factory
|
||||
|
||||
|
||||
class ProviderManager:
|
||||
"""
|
||||
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
provider_name_to_provider_records_dict: dict,
|
||||
provider_name_to_provider_model_records_dict: dict) -> None:
|
||||
self.provider_name_to_provider_records_dict = provider_name_to_provider_records_dict
|
||||
self.provider_name_to_provider_model_records_dict = provider_name_to_provider_model_records_dict
|
||||
|
||||
def get_configurations(self, provider: str) -> ProviderConfigurations:
|
||||
"""
|
||||
Get model provider configurations.
|
||||
|
||||
Construct ProviderConfiguration objects for each provider
|
||||
Including:
|
||||
1. Basic information of the provider
|
||||
2. Hosting configuration information, including:
|
||||
(1. Whether to enable (support) hosting type, if enabled, the following information exists
|
||||
(2. List of hosting type provider configurations
|
||||
(including quota type, quota limit, current remaining quota, etc.)
|
||||
(3. The current hosting type in use (whether there is a quota or not)
|
||||
paid quotas > provider free quotas > hosting trial quotas
|
||||
(4. Unified credentials for hosting providers
|
||||
3. Custom configuration information, including:
|
||||
(1. Whether to enable (support) custom type, if enabled, the following information exists
|
||||
(2. Custom provider configuration (including credentials)
|
||||
(3. List of custom provider model configurations (including credentials)
|
||||
4. Hosting/custom preferred provider type.
|
||||
Provide methods:
|
||||
- Get the current configuration (including credentials)
|
||||
- Get the availability and status of the hosting configuration: active available,
|
||||
quota_exceeded insufficient quota, unsupported hosting
|
||||
- Get the availability of custom configuration
|
||||
Custom provider available conditions:
|
||||
(1. custom provider credentials available
|
||||
(2. at least one custom model credentials available
|
||||
- Verify, update, and delete custom provider configuration
|
||||
- Verify, update, and delete custom provider model configuration
|
||||
- Get the list of available models (optional provider filtering, model type filtering)
|
||||
Append custom provider models to the list
|
||||
- Get provider instance
|
||||
- Switch selection priority
|
||||
|
||||
:param: provider
|
||||
:return:
|
||||
"""
|
||||
|
||||
# Get all provider entities
|
||||
provider_entities = model_provider_factory.get_providers(provider_name=provider)
|
||||
|
||||
provider_configurations = ProviderConfigurations()
|
||||
|
||||
# Construct ProviderConfiguration objects for each provider
|
||||
for provider_entity in provider_entities:
|
||||
provider_name = provider_entity.provider
|
||||
|
||||
provider_credentials = self.provider_name_to_provider_records_dict.get(provider_entity.provider)
|
||||
if not provider_credentials:
|
||||
provider_credentials = {}
|
||||
|
||||
provider_model_records = self.provider_name_to_provider_model_records_dict.get(provider_entity.provider)
|
||||
if not provider_model_records:
|
||||
provider_model_records = []
|
||||
|
||||
# Convert to custom configuration
|
||||
custom_configuration = self._to_custom_configuration(
|
||||
provider_entity,
|
||||
provider_credentials,
|
||||
provider_model_records
|
||||
)
|
||||
|
||||
provider_configuration = ProviderConfiguration(
|
||||
provider=provider_entity,
|
||||
custom_configuration=custom_configuration
|
||||
)
|
||||
|
||||
provider_configurations[provider_name] = provider_configuration
|
||||
|
||||
# Return the encapsulated object
|
||||
return provider_configurations
|
||||
|
||||
def get_provider_model_bundle(self, provider: str, model_type: ModelType) -> ProviderModelBundle:
|
||||
"""
|
||||
Get provider model bundle.
|
||||
:param provider: provider name
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
provider_configurations = self.get_configurations(provider=provider)
|
||||
|
||||
# get provider instance
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
provider_instance = provider_configuration.get_provider_instance()
|
||||
model_type_instance = provider_instance.get_model_instance(model_type)
|
||||
|
||||
return ProviderModelBundle(
|
||||
configuration=provider_configuration,
|
||||
provider_instance=provider_instance,
|
||||
model_type_instance=model_type_instance
|
||||
)
|
||||
|
||||
def get_default_model(self, model_type: ModelType) -> Optional[DefaultModelEntity]:
|
||||
"""
|
||||
Get default model.
|
||||
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
|
||||
default_model = {}
|
||||
# Get provider configurations
|
||||
provider_configurations = self.get_configurations()
|
||||
|
||||
# get available models from provider_configurations
|
||||
available_models = provider_configurations.get_models(
|
||||
model_type=model_type,
|
||||
only_active=True
|
||||
)
|
||||
|
||||
if available_models:
|
||||
found = False
|
||||
for available_model in available_models:
|
||||
if available_model.model == "gpt-3.5-turbo-1106":
|
||||
default_model = {
|
||||
'provider_name': available_model.provider.provider,
|
||||
'model_name': available_model.model
|
||||
}
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
available_model = available_models[0]
|
||||
default_model = {
|
||||
'provider_name': available_model.provider.provider,
|
||||
'model_name': available_model.model
|
||||
}
|
||||
|
||||
provider_instance = model_provider_factory.get_provider_instance(default_model.get('provider_name'))
|
||||
provider_schema = provider_instance.get_provider_schema()
|
||||
|
||||
return DefaultModelEntity(
|
||||
model=default_model.get("model_name"),
|
||||
model_type=model_type,
|
||||
provider=DefaultModelProviderEntity(
|
||||
provider=provider_schema.provider,
|
||||
label=provider_schema.label,
|
||||
icon_small=provider_schema.icon_small,
|
||||
icon_large=provider_schema.icon_large,
|
||||
supported_model_types=provider_schema.supported_model_types
|
||||
)
|
||||
)
|
||||
|
||||
def _to_custom_configuration(self,
|
||||
provider_entity: ProviderEntity,
|
||||
provider_credentials: dict,
|
||||
provider_model_records: list[dict]) -> CustomConfiguration:
|
||||
"""
|
||||
Convert to custom configuration.
|
||||
|
||||
:param provider_entity: provider entity
|
||||
:param provider_credentials: provider records_credentials
|
||||
:param provider_model_records: provider model records_credentials
|
||||
:return:
|
||||
"""
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self._extract_secret_variables(
|
||||
provider_entity.provider_credential_schema.credential_form_schemas
|
||||
if provider_entity.provider_credential_schema else []
|
||||
)
|
||||
|
||||
for variable in provider_credential_secret_variables:
|
||||
if variable in provider_credentials:
|
||||
try:
|
||||
provider_credentials[variable] = provider_credentials.get(variable)
|
||||
except ValueError:
|
||||
pass
|
||||
custom_provider_configuration = CustomProviderConfiguration(
|
||||
credentials=provider_credentials
|
||||
)
|
||||
|
||||
# Get provider model credential secret variables
|
||||
model_credential_secret_variables = self._extract_secret_variables(
|
||||
provider_entity.model_credential_schema.credential_form_schemas
|
||||
if provider_entity.model_credential_schema else []
|
||||
)
|
||||
|
||||
# Get custom provider model credentials
|
||||
custom_model_configurations = []
|
||||
for provider_model_record in provider_model_records:
|
||||
if not provider_model_record.get('model_credentials'):
|
||||
continue
|
||||
|
||||
provider_model_credentials = {}
|
||||
for variable in model_credential_secret_variables:
|
||||
if variable in provider_model_record.get('model_credentials'):
|
||||
try:
|
||||
provider_model_credentials[variable] = provider_model_record.get('model_credentials').get(
|
||||
variable)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
custom_model_configurations.append(
|
||||
CustomModelConfiguration(
|
||||
model=provider_model_record.get('model_name'),
|
||||
model_type=ModelType.value_of(provider_model_record.get('model_type')),
|
||||
credentials=provider_model_credentials
|
||||
)
|
||||
)
|
||||
|
||||
return CustomConfiguration(
|
||||
provider=custom_provider_configuration,
|
||||
models=custom_model_configurations
|
||||
)
|
||||
|
||||
def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
|
||||
"""
|
||||
Extract secret input form variables.
|
||||
|
||||
:param credential_form_schemas:
|
||||
:return:
|
||||
"""
|
||||
secret_input_form_variables = []
|
||||
for credential_form_schema in credential_form_schemas:
|
||||
if credential_form_schema.type == FormType.SECRET_INPUT:
|
||||
secret_input_form_variables.append(credential_form_schema.variable)
|
||||
|
||||
return secret_input_form_variables
|
||||
38
model-providers/chatchat_model_providers/errors/error.py
Normal file
38
model-providers/chatchat_model_providers/errors/error.py
Normal file
@ -0,0 +1,38 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class LLMError(Exception):
|
||||
"""Base class for all LLM exceptions."""
|
||||
description: Optional[str] = None
|
||||
|
||||
def __init__(self, description: Optional[str] = None) -> None:
|
||||
self.description = description
|
||||
|
||||
|
||||
class LLMBadRequestError(LLMError):
|
||||
"""Raised when the LLM returns bad request."""
|
||||
description = "Bad Request"
|
||||
|
||||
|
||||
class ProviderTokenNotInitError(Exception):
|
||||
"""
|
||||
Custom exception raised when the provider token is not initialized.
|
||||
"""
|
||||
description = "Provider Token Not Init"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.description = args[0] if args else self.description
|
||||
|
||||
|
||||
class QuotaExceededError(Exception):
|
||||
"""
|
||||
Custom exception raised when the quota for a provider has been exceeded.
|
||||
"""
|
||||
description = "Quota Exceeded"
|
||||
|
||||
|
||||
class ModelCurrentlyNotSupportError(Exception):
|
||||
"""
|
||||
Custom exception raised when the model not support
|
||||
"""
|
||||
description = "Model Currently Not Support"
|
||||
Loading…
x
Reference in New Issue
Block a user