diff --git a/server/schema_validators/__init__.py b/server/schema_validators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/schema_validators/common_validator.py b/server/schema_validators/common_validator.py new file mode 100644 index 00000000..fe705d69 --- /dev/null +++ b/server/schema_validators/common_validator.py @@ -0,0 +1,87 @@ +from typing import Optional + +from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType + + +class CommonValidator: + def _validate_and_filter_credential_form_schemas(self, + credential_form_schemas: list[CredentialFormSchema], + credentials: dict) -> dict: + need_validate_credential_form_schema_map = {} + for credential_form_schema in credential_form_schemas: + if not credential_form_schema.show_on: + need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema + continue + + all_show_on_match = True + for show_on_object in credential_form_schema.show_on: + if show_on_object.variable not in credentials: + all_show_on_match = False + break + + if credentials[show_on_object.variable] != show_on_object.value: + all_show_on_match = False + break + + if all_show_on_match: + need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema + + # Iterate over the remaining credential_form_schemas, verify each credential_form_schema + validated_credentials = {} + for credential_form_schema in need_validate_credential_form_schema_map.values(): + # add the value of the credential_form_schema corresponding to it to validated_credentials + result = self._validate_credential_form_schema(credential_form_schema, credentials) + if result: + validated_credentials[credential_form_schema.variable] = result + + return validated_credentials + + def _validate_credential_form_schema(self, credential_form_schema: CredentialFormSchema, credentials: dict) \ + -> Optional[str]: + """ + Validate credential form schema + + :param credential_form_schema: credential form schema + :param credentials: credentials + :return: validated credential form schema value + """ + # If the variable does not exist in credentials + if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: + # If required is True, an exception is thrown + if credential_form_schema.required: + raise ValueError(f'Variable {credential_form_schema.variable} is required') + else: + # Get the value of default + if credential_form_schema.default: + # If it exists, add it to validated_credentials + return credential_form_schema.default + else: + # If default does not exist, skip + return None + + # Get the value corresponding to the variable from credentials + value = credentials[credential_form_schema.variable] + + # If max_length=0, no validation is performed + if credential_form_schema.max_length: + if len(value) > credential_form_schema.max_length: + raise ValueError(f'Variable {credential_form_schema.variable} length should not greater than {credential_form_schema.max_length}') + + # check the type of value + if not isinstance(value, str): + raise ValueError(f'Variable {credential_form_schema.variable} should be string') + + if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]: + # If the value is in options, no validation is performed + if credential_form_schema.options: + if value not in [option.value for option in credential_form_schema.options]: + raise ValueError(f'Variable {credential_form_schema.variable} is not in options') + + if credential_form_schema.type == FormType.SWITCH: + # If the value is not in ['true', 'false'], an exception is thrown + if value.lower() not in ['true', 'false']: + raise ValueError(f'Variable {credential_form_schema.variable} should be true or false') + + value = True if value.lower() == 'true' else False + + return value diff --git a/server/schema_validators/model_credential_schema_validator.py b/server/schema_validators/model_credential_schema_validator.py new file mode 100644 index 00000000..c4786fad --- /dev/null +++ b/server/schema_validators/model_credential_schema_validator.py @@ -0,0 +1,28 @@ +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.provider_entities import ModelCredentialSchema +from core.model_runtime.schema_validators.common_validator import CommonValidator + + +class ModelCredentialSchemaValidator(CommonValidator): + + def __init__(self, model_type: ModelType, model_credential_schema: ModelCredentialSchema): + self.model_type = model_type + self.model_credential_schema = model_credential_schema + + def validate_and_filter(self, credentials: dict) -> dict: + """ + Validate model credentials + + :param credentials: model credentials + :return: filtered credentials + """ + + if self.model_credential_schema is None: + raise ValueError("Model credential schema is None") + + # get the credential_form_schemas in provider_credential_schema + credential_form_schemas = self.model_credential_schema.credential_form_schemas + + credentials["__model_type"] = self.model_type.value + + return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials) diff --git a/server/schema_validators/provider_credential_schema_validator.py b/server/schema_validators/provider_credential_schema_validator.py new file mode 100644 index 00000000..c9450165 --- /dev/null +++ b/server/schema_validators/provider_credential_schema_validator.py @@ -0,0 +1,20 @@ +from core.model_runtime.entities.provider_entities import ProviderCredentialSchema +from core.model_runtime.schema_validators.common_validator import CommonValidator + + +class ProviderCredentialSchemaValidator(CommonValidator): + + def __init__(self, provider_credential_schema: ProviderCredentialSchema): + self.provider_credential_schema = provider_credential_schema + + def validate_and_filter(self, credentials: dict) -> dict: + """ + Validate provider credentials + + :param credentials: provider credentials + :return: validated provider credentials + """ + # get the credential_form_schemas in provider_credential_schema + credential_form_schemas = self.provider_credential_schema.credential_form_schemas + + return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials) diff --git a/server/utils/__init__.py b/server/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/utils/_compat.py b/server/utils/_compat.py new file mode 100644 index 00000000..5c341527 --- /dev/null +++ b/server/utils/_compat.py @@ -0,0 +1,21 @@ +from typing import Any, Literal + +from pydantic import BaseModel +from pydantic.version import VERSION as PYDANTIC_VERSION + +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") + +if PYDANTIC_V2: + from pydantic_core import Url as Url + + def _model_dump( + model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any + ) -> Any: + return model.model_dump(mode=mode, **kwargs) +else: + from pydantic import AnyUrl as Url # noqa: F401 + + def _model_dump( + model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any + ) -> Any: + return model.dict(**kwargs) diff --git a/server/utils/encoders.py b/server/utils/encoders.py new file mode 100644 index 00000000..cf6c98e0 --- /dev/null +++ b/server/utils/encoders.py @@ -0,0 +1,228 @@ +import dataclasses +import datetime +from collections import defaultdict, deque +from collections.abc import Callable +from decimal import Decimal +from enum import Enum +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from pathlib import Path, PurePath +from re import Pattern +from types import GeneratorType +from typing import Any, Optional, Union +from uuid import UUID + +from pydantic import BaseModel +from pydantic.color import Color +from pydantic.networks import AnyUrl, NameEmail +from pydantic.types import SecretBytes, SecretStr + +from ._compat import PYDANTIC_V2, Url, _model_dump + + +# Taken from Pydantic v1 as is +def isoformat(o: Union[datetime.date, datetime.time]) -> str: + return o.isoformat() + + +# Taken from Pydantic v1 as is +# TODO: pv2 should this return strings instead? +def decimal_encoder(dec_value: Decimal) -> Union[int, float]: + """ + Encodes a Decimal as int of there's no exponent, otherwise float + + This is useful when we use ConstrainedDecimal to represent Numeric(x,0) + where a integer (but not int typed) is used. Encoding this as a float + results in failed round-tripping between encode and parse. + Our Id type is a prime example of this. + + >>> decimal_encoder(Decimal("1.0")) + 1.0 + + >>> decimal_encoder(Decimal("1")) + 1 + """ + if dec_value.as_tuple().exponent >= 0: # type: ignore[operator] + return int(dec_value) + else: + return float(dec_value) + + +ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = { + bytes: lambda o: o.decode(), + Color: str, + datetime.date: isoformat, + datetime.datetime: isoformat, + datetime.time: isoformat, + datetime.timedelta: lambda td: td.total_seconds(), + Decimal: decimal_encoder, + Enum: lambda o: o.value, + frozenset: list, + deque: list, + GeneratorType: list, + IPv4Address: str, + IPv4Interface: str, + IPv4Network: str, + IPv6Address: str, + IPv6Interface: str, + IPv6Network: str, + NameEmail: str, + Path: str, + Pattern: lambda o: o.pattern, + SecretBytes: str, + SecretStr: str, + set: list, + UUID: str, + Url: str, + AnyUrl: str, +} + + +def generate_encoders_by_class_tuples( + type_encoder_map: dict[Any, Callable[[Any], Any]] +) -> dict[Callable[[Any], Any], tuple[Any, ...]]: + encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict( + tuple + ) + for type_, encoder in type_encoder_map.items(): + encoders_by_class_tuples[encoder] += (type_,) + return encoders_by_class_tuples + + +encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) + + +def jsonable_encoder( + obj: Any, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None, + sqlalchemy_safe: bool = True, +) -> Any: + custom_encoder = custom_encoder or {} + if custom_encoder: + if type(obj) in custom_encoder: + return custom_encoder[type(obj)](obj) + else: + for encoder_type, encoder_instance in custom_encoder.items(): + if isinstance(obj, encoder_type): + return encoder_instance(obj) + if isinstance(obj, BaseModel): + # TODO: remove when deprecating Pydantic v1 + encoders: dict[Any, Any] = {} + if not PYDANTIC_V2: + encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined] + if custom_encoder: + encoders.update(custom_encoder) + obj_dict = _model_dump( + obj, + mode="json", + include=None, + exclude=None, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + exclude_defaults=exclude_defaults, + ) + if "__root__" in obj_dict: + obj_dict = obj_dict["__root__"] + return jsonable_encoder( + obj_dict, + exclude_none=exclude_none, + exclude_defaults=exclude_defaults, + # TODO: remove when deprecating Pydantic v1 + custom_encoder=encoders, + sqlalchemy_safe=sqlalchemy_safe, + ) + if dataclasses.is_dataclass(obj): + obj_dict = dataclasses.asdict(obj) + return jsonable_encoder( + obj_dict, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + if isinstance(obj, Enum): + return obj.value + if isinstance(obj, PurePath): + return str(obj) + if isinstance(obj, str | int | float | type(None)): + return obj + if isinstance(obj, Decimal): + return format(obj, 'f') + if isinstance(obj, dict): + encoded_dict = {} + allowed_keys = set(obj.keys()) + for key, value in obj.items(): + if ( + ( + not sqlalchemy_safe + or (not isinstance(key, str)) + or (not key.startswith("_sa")) + ) + and (value is not None or not exclude_none) + and key in allowed_keys + ): + encoded_key = jsonable_encoder( + key, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + encoded_value = jsonable_encoder( + value, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + encoded_dict[encoded_key] = encoded_value + return encoded_dict + if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque): + encoded_list = [] + for item in obj: + encoded_list.append( + jsonable_encoder( + item, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) + ) + return encoded_list + + if type(obj) in ENCODERS_BY_TYPE: + return ENCODERS_BY_TYPE[type(obj)](obj) + for encoder, classes_tuple in encoders_by_class_tuples.items(): + if isinstance(obj, classes_tuple): + return encoder(obj) + + try: + data = dict(obj) + except Exception as e: + errors: list[Exception] = [] + errors.append(e) + try: + data = vars(obj) + except Exception as e: + errors.append(e) + raise ValueError(errors) from e + return jsonable_encoder( + data, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) diff --git a/server/utils/helper.py b/server/utils/helper.py new file mode 100644 index 00000000..09d08fa3 --- /dev/null +++ b/server/utils/helper.py @@ -0,0 +1,9 @@ +import pydantic +from pydantic import BaseModel + + +def dump_model(model: BaseModel) -> dict: + if hasattr(pydantic, 'model_dump'): + return pydantic.model_dump(model) + else: + return model.dict()