Merge pull request #3523 from khazic/dev_model_providers

模型默认参数从配置文件加载
This commit is contained in:
glide-the 2024-03-26 16:15:34 +08:00 committed by GitHub
commit 1e96d69945
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 583 additions and 556 deletions

View File

@ -1,37 +1,28 @@
import os
from typing import cast, Generator
from chatchat_model_providers.core.model_manager import ModelManager
from chatchat_model_providers.core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta
from chatchat_model_providers.core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
from chatchat_model_providers.core.model_runtime.entities.model_entities import ModelType
if __name__ == '__main__':
# 基于配置管理器创建的模型实例
# provider_manager = ProviderManager()
provider_configurations = ProviderConfigurations(
tenant_id=tenant_id
provider_manager = ModelManager(
provider_name_to_provider_records_dict={
'openai': {
'openai_api_key': "sk-4M9LYF",
}
},
provider_name_to_provider_model_records_dict={}
)
#
# model_instance = ModelInstance(
# provider_model_bundle=provider_model_bundle,
# model=model_config.model,
# )
# 直接通过模型加载器创建的模型实例
from chatchat_model_providers.core.model_runtime.model_providers import model_provider_factory
model_provider_factory.get_providers(provider_name='openai')
provider_instance = model_provider_factory.get_provider_instance('openai')
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
print(model_type_instance)
response = model_type_instance.invoke(
model='gpt-4',
credentials={
'openai_api_key': "sk-",
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
},
# Invoke model
model_instance = provider_manager.get_model_instance(provider='openai', model_type=ModelType.LLM, model='gpt-4')
response = model_instance.invoke_llm(
prompt_messages=[
UserPromptMessage(
content='北京今天的天气怎么样'

View File

@ -8,9 +8,7 @@ from typing import Optional
from pydantic import BaseModel
from chatchat_model_providers.core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from chatchat_model_providers.core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus
from chatchat_model_providers.core.helper import encrypter
from chatchat_model_providers.core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from chatchat_model_providers.core.entities.provider_entities import CustomConfiguration
from chatchat_model_providers.core.model_runtime.entities.model_entities import FetchFrom, ModelType
from chatchat_model_providers.core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
@ -21,39 +19,20 @@ from chatchat_model_providers.core.model_runtime.entities.provider_entities impo
from chatchat_model_providers.core.model_runtime.model_providers import model_provider_factory
from chatchat_model_providers.core.model_runtime.model_providers.__base.ai_model import AIModel
from chatchat_model_providers.core.model_runtime.model_providers.__base.model_provider import ModelProvider
from chatchat_model_providers.extensions.ext_database import db
from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider
logger = logging.getLogger(__name__)
original_provider_configurate_methods = {}
class ProviderConfiguration(BaseModel):
"""
Model class for provider configuration.
"""
tenant_id: str
provider: ProviderEntity
preferred_provider_type: ProviderType
using_provider_type: ProviderType
system_configuration: SystemConfiguration
custom_configuration: CustomConfiguration
def __init__(self, **data):
super().__init__(**data)
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in self.provider.configurate_methods:
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
if (any([len(quota_configuration.restrict_models) > 0
for quota_configuration in self.system_configuration.quota_configurations])
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods):
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
"""
Get current credentials.
@ -62,58 +41,15 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
if self.using_provider_type == ProviderType.SYSTEM:
restrict_models = []
for quota_configuration in self.system_configuration.quota_configurations:
if self.system_configuration.current_quota_type != quota_configuration.quota_type:
continue
if self.custom_configuration.models:
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model:
return model_configuration.credentials
restrict_models = quota_configuration.restrict_models
copy_credentials = self.system_configuration.credentials.copy()
if restrict_models:
for restrict_model in restrict_models:
if (restrict_model.model_type == model_type
and restrict_model.model == model
and restrict_model.base_model_name):
copy_credentials['base_model_name'] = restrict_model.base_model_name
return copy_credentials
if self.custom_configuration.provider:
return self.custom_configuration.provider.credentials
else:
if self.custom_configuration.models:
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model:
return model_configuration.credentials
if self.custom_configuration.provider:
return self.custom_configuration.provider.credentials
else:
return None
def get_system_configuration_status(self) -> SystemConfigurationStatus:
"""
Get system configuration status.
:return:
"""
if self.system_configuration.enabled is False:
return SystemConfigurationStatus.UNSUPPORTED
current_quota_type = self.system_configuration.current_quota_type
current_quota_configuration = next(
(q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
None
)
return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \
SystemConfigurationStatus.QUOTA_EXCEEDED
def is_custom_configuration_available(self) -> bool:
"""
Check custom configuration available.
:return:
"""
return (self.custom_configuration.provider is not None
or len(self.custom_configuration.models) > 0)
return None
def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
"""
@ -129,130 +65,9 @@ class ProviderConfiguration(BaseModel):
if not obfuscated:
return credentials
# Obfuscate credentials
return self._obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else []
)
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
"""
Validate custom credentials.
:param credentials: provider credentials
:return:
"""
# get provider
provider_record = db.session.query(Provider) \
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.provider.provider,
Provider.provider_type == ProviderType.CUSTOM.value
).first()
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else []
)
if provider_record:
try:
# fix origin data
if provider_record.encrypted_config:
if not provider_record.encrypted_config.startswith("{"):
original_credentials = {
"openai_api_key": provider_record.encrypted_config
}
else:
original_credentials = json.loads(provider_record.encrypted_config)
else:
original_credentials = {}
except JSONDecodeError:
original_credentials = {}
# encrypt credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if value == '[__HIDDEN__]' and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
credentials = model_provider_factory.provider_credentials_validate(
self.provider.provider,
credentials
)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return provider_record, credentials
def add_or_update_custom_credentials(self, credentials: dict) -> None:
"""
Add or update custom provider credentials.
:param credentials:
:return:
"""
# validate custom provider config
provider_record, credentials = self.custom_credentials_validate(credentials)
# save provider
# Note: Do not switch the preferred provider, which allows users to use quotas first
if provider_record:
provider_record.encrypted_config = json.dumps(credentials)
provider_record.is_valid = True
provider_record.updated_at = datetime.datetime.utcnow()
db.session.commit()
else:
provider_record = Provider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(credentials),
is_valid=True
)
db.session.add(provider_record)
db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
)
provider_model_credentials_cache.delete()
self.switch_preferred_provider_type(ProviderType.CUSTOM)
def delete_custom_credentials(self) -> None:
"""
Delete custom provider credentials.
:return:
"""
# get provider
provider_record = db.session.query(Provider) \
.filter(
Provider.tenant_id == self.tenant_id,
Provider.provider_name == self.provider.provider,
Provider.provider_type == ProviderType.CUSTOM.value
).first()
# delete provider
if provider_record:
self.switch_preferred_provider_type(ProviderType.SYSTEM)
db.session.delete(provider_record)
db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
)
provider_model_credentials_cache.delete()
# Obfuscate provider credentials
copy_credentials = credentials.copy()
return copy_credentials
def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
-> Optional[dict]:
@ -272,136 +87,12 @@ class ProviderConfiguration(BaseModel):
credentials = model_configuration.credentials
if not obfuscated:
return credentials
copy_credentials = credentials.copy()
# Obfuscate credentials
return self._obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else []
)
return copy_credentials
return None
def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
-> tuple[ProviderModel, dict]:
"""
Validate custom model credentials.
:param model_type: model type
:param model: model name
:param credentials: model credentials
:return:
"""
# get provider model
provider_model_record = db.session.query(ProviderModel) \
.filter(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type()
).first()
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else []
)
if provider_model_record:
try:
original_credentials = json.loads(
provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
except JSONDecodeError:
original_credentials = {}
# decrypt credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if value == '[__HIDDEN__]' and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider,
model_type=model_type,
model=model,
credentials=credentials
)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return provider_model_record, credentials
def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
"""
Add or update custom model credentials.
:param model_type: model type
:param model: model name
:param credentials: model credentials
:return:
"""
# validate custom model config
provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
# save provider model
# Note: Do not switch the preferred provider, which allows users to use quotas first
if provider_model_record:
provider_model_record.encrypted_config = json.dumps(credentials)
provider_model_record.is_valid = True
provider_model_record.updated_at = datetime.datetime.utcnow()
db.session.commit()
else:
provider_model_record = ProviderModel(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
encrypted_config=json.dumps(credentials),
is_valid=True
)
db.session.add(provider_model_record)
db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL
)
provider_model_credentials_cache.delete()
def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
"""
Delete custom model credentials.
:param model_type: model type
:param model: model name
:return:
"""
# get provider model
provider_model_record = db.session.query(ProviderModel) \
.filter(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name == self.provider.provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type()
).first()
# delete provider model
if provider_model_record:
db.session.delete(provider_model_record)
db.session.commit()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL
)
provider_model_credentials_cache.delete()
def get_provider_instance(self) -> ModelProvider:
"""
Get provider instance.
@ -422,72 +113,6 @@ class ProviderConfiguration(BaseModel):
# Get model instance of LLM
return provider_instance.get_model_instance(model_type)
def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
"""
Switch preferred provider type.
:param provider_type:
:return:
"""
if provider_type == self.preferred_provider_type:
return
if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
return
# get preferred provider
preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
.filter(
TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name == self.provider.provider
).first()
if preferred_model_provider:
preferred_model_provider.preferred_provider_type = provider_type.value
else:
preferred_model_provider = TenantPreferredModelProvider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
preferred_provider_type=provider_type.value
)
db.session.add(preferred_model_provider)
db.session.commit()
def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
"""
Extract secret input form variables.
:param credential_form_schemas:
:return:
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables
def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
"""
Obfuscated credentials.
:param credentials: credentials
:param credential_form_schemas: credential form schemas
:return:
"""
# Get provider credential secret variables
credential_secret_variables = self._extract_secret_variables(
credential_form_schemas
)
# Obfuscate provider credentials
copy_credentials = credentials.copy()
for key, value in copy_credentials.items():
if key in credential_secret_variables:
copy_credentials[key] = encrypter.obfuscated_token(value)
return copy_credentials
def get_provider_model(self, model_type: ModelType,
model: str,
only_active: bool = False) -> Optional[ModelWithProviderEntity]:
@ -522,118 +147,16 @@ class ProviderConfiguration(BaseModel):
else:
model_types = provider_instance.get_provider_schema().supported_model_types
if self.using_provider_type == ProviderType.SYSTEM:
provider_models = self._get_system_provider_models(
model_types=model_types,
provider_instance=provider_instance
)
else:
provider_models = self._get_custom_provider_models(
model_types=model_types,
provider_instance=provider_instance
)
provider_models = self._get_custom_provider_models(
model_types=model_types,
provider_instance=provider_instance
)
if only_active:
provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
# resort provider_models
return sorted(provider_models, key=lambda x: x.model_type.value)
def _get_system_provider_models(self,
model_types: list[ModelType],
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
"""
Get system provider models.
:param model_types: model types
:param provider_instance: provider instance
:return:
"""
provider_models = []
for model_type in model_types:
provider_models.extend(
[
ModelWithProviderEntity(
model=m.model,
label=m.label,
model_type=m.model_type,
features=m.features,
fetch_from=m.fetch_from,
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)
for m in provider_instance.models(model_type)
]
)
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in provider_instance.get_provider_schema().configurate_methods:
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
should_use_custom_model = False
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
should_use_custom_model = True
for quota_configuration in self.system_configuration.quota_configurations:
if self.system_configuration.current_quota_type != quota_configuration.quota_type:
continue
restrict_models = quota_configuration.restrict_models
if len(restrict_models) == 0:
break
if should_use_custom_model:
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
# only customizable model
for restrict_model in restrict_models:
copy_credentials = self.system_configuration.credentials.copy()
if restrict_model.base_model_name:
copy_credentials['base_model_name'] = restrict_model.base_model_name
try:
custom_model_schema = (
provider_instance.get_model_instance(restrict_model.model_type)
.get_customizable_model_schema_from_credentials(
restrict_model.model,
copy_credentials
)
)
except Exception as ex:
logger.warning(f'get custom model schema failed, {ex}')
continue
if not custom_model_schema:
continue
if custom_model_schema.model_type not in model_types:
continue
provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)
)
# if llm name not in restricted llm list, remove it
restrict_model_names = [rm.model for rm in restrict_models]
for m in provider_models:
if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
m.status = ModelStatus.NO_PERMISSION
elif not quota_configuration.is_valid:
m.status = ModelStatus.QUOTA_EXCEEDED
return provider_models
def _get_custom_provider_models(self,
model_types: list[ModelType],
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
@ -711,11 +234,10 @@ class ProviderConfigurations(BaseModel):
"""
Model class for provider configuration dict.
"""
tenant_id: str
configurations: dict[str, ProviderConfiguration] = {}
def __init__(self, tenant_id: str):
super().__init__(tenant_id=tenant_id)
def __init__(self):
super().__init__()
def get_models(self,
provider: Optional[str] = None,

View File

@ -4,22 +4,6 @@ from typing import Optional
from pydantic import BaseModel
from chatchat_model_providers.core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType
class QuotaUnit(Enum):
TIMES = 'times'
TOKENS = 'tokens'
CREDITS = 'credits'
class SystemConfigurationStatus(Enum):
"""
Enum class for system configuration status.
"""
ACTIVE = 'active'
QUOTA_EXCEEDED = 'quota-exceeded'
UNSUPPORTED = 'unsupported'
class RestrictModel(BaseModel):
@ -28,27 +12,6 @@ class RestrictModel(BaseModel):
model_type: ModelType
class QuotaConfiguration(BaseModel):
"""
Model class for provider quota configuration.
"""
quota_type: ProviderQuotaType
quota_unit: QuotaUnit
quota_limit: int
quota_used: int
is_valid: bool
restrict_models: list[RestrictModel] = []
class SystemConfiguration(BaseModel):
"""
Model class for provider system configuration.
"""
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
quota_configurations: list[QuotaConfiguration] = []
credentials: Optional[dict] = None
class CustomProviderConfiguration(BaseModel):
"""

View File

@ -0,0 +1,257 @@
from collections.abc import Generator
from typing import IO, Optional, Union, cast
from chatchat_model_providers.core.entities.provider_configuration import ProviderModelBundle
from chatchat_model_providers.errors.error import ProviderTokenNotInitError
from chatchat_model_providers.core.model_runtime.callbacks.base_callback import Callback
from chatchat_model_providers.core.model_runtime.entities.llm_entities import LLMResult
from chatchat_model_providers.core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from chatchat_model_providers.core.model_runtime.entities.model_entities import ModelType
from chatchat_model_providers.core.model_runtime.entities.rerank_entities import RerankResult
from chatchat_model_providers.core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from chatchat_model_providers.core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from chatchat_model_providers.core.model_runtime.model_providers.__base.moderation_model import ModerationModel
from chatchat_model_providers.core.model_runtime.model_providers.__base.rerank_model import RerankModel
from chatchat_model_providers.core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from chatchat_model_providers.core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from chatchat_model_providers.core.model_runtime.model_providers.__base.tts_model import TTSModel
from chatchat_model_providers.core.provider_manager import ProviderManager
def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict:
"""
Fetch credentials from provider model bundle
:param provider_model_bundle: provider model bundle
:param model: model name
:return:
"""
credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=provider_model_bundle.model_type_instance.model_type,
model=model
)
if credentials is None:
raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")
return credentials
class ModelInstance:
"""
Model instance class
"""
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
self._provider_model_bundle = provider_model_bundle
self.model = model
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = _fetch_credentials_from_bundle(provider_model_bundle, model)
self.model_type_instance = self._provider_model_bundle.model_type_instance
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks
)
def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke large language model
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
texts=texts,
user=user
)
def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
Invoke rerank model
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
self.model_type_instance = cast(RerankModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
user=user
)
def invoke_moderation(self, text: str, user: Optional[str] = None) \
-> bool:
"""
Invoke moderation model
:param text: text to moderate
:param user: unique user id
:return: false if text is safe, true otherwise
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
text=text,
user=user
)
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
-> str:
"""
Invoke large language model
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
file=file,
user=user
)
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \
-> str:
"""
Invoke large language tts model
:param content_text: text content to be translated
:param tenant_id: user tenant id
:param user: unique user id
:param voice: model timbre
:param streaming: output is streaming
:return: text for given audio file
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.invoke(
model=self.model,
credentials=self.credentials,
content_text=content_text,
user=user,
tenant_id=tenant_id,
voice=voice,
streaming=streaming
)
def get_tts_voices(self, language: str) -> list:
"""
Invoke large language tts model voices
:param language: tts language
:return: tts model voices
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices(
model=self.model,
credentials=self.credentials,
language=language
)
class ModelManager:
def __init__(self,
provider_name_to_provider_records_dict: dict,
provider_name_to_provider_model_records_dict: dict) -> None:
self._provider_manager = ProviderManager(
provider_name_to_provider_records_dict=provider_name_to_provider_records_dict,
provider_name_to_provider_model_records_dict=provider_name_to_provider_model_records_dict)
def get_model_instance(self, provider: str, model_type: ModelType, model: str) -> ModelInstance:
"""
Get model instance
:param provider: provider name
:param model_type: model type
:param model: model name
:return:
"""
if not provider:
return self.get_default_model_instance(model_type)
provider_model_bundle = self._provider_manager.get_provider_model_bundle(
provider=provider,
model_type=model_type
)
return ModelInstance(provider_model_bundle, model)
def get_default_model_instance(self, model_type: ModelType) -> ModelInstance:
"""
Get default model instance
:param model_type: model type
:return:
"""
default_model_entity = self._provider_manager.get_default_model(
model_type=model_type
)
if not default_model_entity:
raise ProviderTokenNotInitError(f"Default model not found for {model_type}")
return self.get_model_instance(
provider=default_model_entity.provider.provider,
model_type=model_type,
model=default_model_entity.model
)

View File

@ -0,0 +1,256 @@
import json
from collections import defaultdict
from json import JSONDecodeError
from typing import Optional
from sqlalchemy.exc import IntegrityError
from chatchat_model_providers.core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from chatchat_model_providers.core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, \
ProviderModelBundle
from chatchat_model_providers.core.entities.provider_entities import (
CustomConfiguration,
CustomModelConfiguration,
CustomProviderConfiguration,
)
from chatchat_model_providers.core.model_runtime.entities.model_entities import ModelType
from chatchat_model_providers.core.model_runtime.entities.provider_entities import (
CredentialFormSchema,
FormType,
ProviderEntity,
)
from chatchat_model_providers.core.model_runtime.model_providers import model_provider_factory
class ProviderManager:
"""
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
"""
def __init__(self,
provider_name_to_provider_records_dict: dict,
provider_name_to_provider_model_records_dict: dict) -> None:
self.provider_name_to_provider_records_dict = provider_name_to_provider_records_dict
self.provider_name_to_provider_model_records_dict = provider_name_to_provider_model_records_dict
def get_configurations(self, provider: str) -> ProviderConfigurations:
"""
Get model provider configurations.
Construct ProviderConfiguration objects for each provider
Including:
1. Basic information of the provider
2. Hosting configuration information, including:
(1. Whether to enable (support) hosting type, if enabled, the following information exists
(2. List of hosting type provider configurations
(including quota type, quota limit, current remaining quota, etc.)
(3. The current hosting type in use (whether there is a quota or not)
paid quotas > provider free quotas > hosting trial quotas
(4. Unified credentials for hosting providers
3. Custom configuration information, including:
(1. Whether to enable (support) custom type, if enabled, the following information exists
(2. Custom provider configuration (including credentials)
(3. List of custom provider model configurations (including credentials)
4. Hosting/custom preferred provider type.
Provide methods:
- Get the current configuration (including credentials)
- Get the availability and status of the hosting configuration: active available,
quota_exceeded insufficient quota, unsupported hosting
- Get the availability of custom configuration
Custom provider available conditions:
(1. custom provider credentials available
(2. at least one custom model credentials available
- Verify, update, and delete custom provider configuration
- Verify, update, and delete custom provider model configuration
- Get the list of available models (optional provider filtering, model type filtering)
Append custom provider models to the list
- Get provider instance
- Switch selection priority
:param: provider
:return:
"""
# Get all provider entities
provider_entities = model_provider_factory.get_providers(provider_name=provider)
provider_configurations = ProviderConfigurations()
# Construct ProviderConfiguration objects for each provider
for provider_entity in provider_entities:
provider_name = provider_entity.provider
provider_credentials = self.provider_name_to_provider_records_dict.get(provider_entity.provider)
if not provider_credentials:
provider_credentials = {}
provider_model_records = self.provider_name_to_provider_model_records_dict.get(provider_entity.provider)
if not provider_model_records:
provider_model_records = []
# Convert to custom configuration
custom_configuration = self._to_custom_configuration(
provider_entity,
provider_credentials,
provider_model_records
)
provider_configuration = ProviderConfiguration(
provider=provider_entity,
custom_configuration=custom_configuration
)
provider_configurations[provider_name] = provider_configuration
# Return the encapsulated object
return provider_configurations
def get_provider_model_bundle(self, provider: str, model_type: ModelType) -> ProviderModelBundle:
"""
Get provider model bundle.
:param provider: provider name
:param model_type: model type
:return:
"""
provider_configurations = self.get_configurations(provider=provider)
# get provider instance
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
provider_instance = provider_configuration.get_provider_instance()
model_type_instance = provider_instance.get_model_instance(model_type)
return ProviderModelBundle(
configuration=provider_configuration,
provider_instance=provider_instance,
model_type_instance=model_type_instance
)
def get_default_model(self, model_type: ModelType) -> Optional[DefaultModelEntity]:
"""
Get default model.
:param model_type: model type
:return:
"""
default_model = {}
# Get provider configurations
provider_configurations = self.get_configurations()
# get available models from provider_configurations
available_models = provider_configurations.get_models(
model_type=model_type,
only_active=True
)
if available_models:
found = False
for available_model in available_models:
if available_model.model == "gpt-3.5-turbo-1106":
default_model = {
'provider_name': available_model.provider.provider,
'model_name': available_model.model
}
found = True
break
if not found:
available_model = available_models[0]
default_model = {
'provider_name': available_model.provider.provider,
'model_name': available_model.model
}
provider_instance = model_provider_factory.get_provider_instance(default_model.get('provider_name'))
provider_schema = provider_instance.get_provider_schema()
return DefaultModelEntity(
model=default_model.get("model_name"),
model_type=model_type,
provider=DefaultModelProviderEntity(
provider=provider_schema.provider,
label=provider_schema.label,
icon_small=provider_schema.icon_small,
icon_large=provider_schema.icon_large,
supported_model_types=provider_schema.supported_model_types
)
)
def _to_custom_configuration(self,
provider_entity: ProviderEntity,
provider_credentials: dict,
provider_model_records: list[dict]) -> CustomConfiguration:
"""
Convert to custom configuration.
:param provider_entity: provider entity
:param provider_credentials: provider records_credentials
:param provider_model_records: provider model records_credentials
:return:
"""
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
provider_entity.provider_credential_schema.credential_form_schemas
if provider_entity.provider_credential_schema else []
)
for variable in provider_credential_secret_variables:
if variable in provider_credentials:
try:
provider_credentials[variable] = provider_credentials.get(variable)
except ValueError:
pass
custom_provider_configuration = CustomProviderConfiguration(
credentials=provider_credentials
)
# Get provider model credential secret variables
model_credential_secret_variables = self._extract_secret_variables(
provider_entity.model_credential_schema.credential_form_schemas
if provider_entity.model_credential_schema else []
)
# Get custom provider model credentials
custom_model_configurations = []
for provider_model_record in provider_model_records:
if not provider_model_record.get('model_credentials'):
continue
provider_model_credentials = {}
for variable in model_credential_secret_variables:
if variable in provider_model_record.get('model_credentials'):
try:
provider_model_credentials[variable] = provider_model_record.get('model_credentials').get(
variable)
except ValueError:
pass
custom_model_configurations.append(
CustomModelConfiguration(
model=provider_model_record.get('model_name'),
model_type=ModelType.value_of(provider_model_record.get('model_type')),
credentials=provider_model_credentials
)
)
return CustomConfiguration(
provider=custom_provider_configuration,
models=custom_model_configurations
)
def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
"""
Extract secret input form variables.
:param credential_form_schemas:
:return:
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables

View File

@ -0,0 +1,38 @@
from typing import Optional
class LLMError(Exception):
"""Base class for all LLM exceptions."""
description: Optional[str] = None
def __init__(self, description: Optional[str] = None) -> None:
self.description = description
class LLMBadRequestError(LLMError):
"""Raised when the LLM returns bad request."""
description = "Bad Request"
class ProviderTokenNotInitError(Exception):
"""
Custom exception raised when the provider token is not initialized.
"""
description = "Provider Token Not Init"
def __init__(self, *args, **kwargs):
self.description = args[0] if args else self.description
class QuotaExceededError(Exception):
"""
Custom exception raised when the quota for a provider has been exceeded.
"""
description = "Quota Exceeded"
class ModelCurrentlyNotSupportError(Exception):
"""
Custom exception raised when the model not support
"""
description = "Model Currently Not Support"