# Released under the MIT License. See LICENSE for details.
#
"""Functionality for dataclassio related to exporting data from dataclasses."""

# Note: We do lots of comparing of exact types here which is normally
# frowned upon (stuff like isinstance() is usually encouraged).
# pylint: disable=unidiomatic-typecheck

from __future__ import annotations

from enum import Enum
import dataclasses
import typing
import datetime
from typing import TYPE_CHECKING

from efro.util import check_utc
from efro.dataclassio._base import (Codec, _parse_annotated,
                                    EXTRA_ATTRS_ATTR, _is_valid_for_codec,
                                    _get_origin, SIMPLE_TYPES,
                                    _raise_type_error, IOExtendedData)
from efro.dataclassio._prep import PrepSession

if TYPE_CHECKING:
    from typing import Any, Optional
    from efro.dataclassio._base import IOAttrs


class _Outputter:
    """Validates or exports data contained in a dataclass instance."""

    def __init__(self, obj: Any, create: bool, codec: Codec,
                 coerce_to_float: bool) -> None:
        self._obj = obj
        self._create = create
        self._codec = codec
        self._coerce_to_float = coerce_to_float

    def run(self) -> Any:
        """Do the thing."""

        # For special extended data types, call their 'will_output' callback.
        if isinstance(self._obj, IOExtendedData):
            self._obj.will_output()

        return self._process_dataclass(type(self._obj), self._obj, '')

    def _process_dataclass(self, cls: type, obj: Any, fieldpath: str) -> Any:
        # pylint: disable=too-many-locals
        # pylint: disable=too-many-branches
        prep = PrepSession(explicit=False).prep_dataclass(type(obj),
                                                          recursion_level=0)
        assert prep is not None
        fields = dataclasses.fields(obj)
        out: Optional[dict[str, Any]] = {} if self._create else None
        for field in fields:
            fieldname = field.name
            if fieldpath:
                subfieldpath = f'{fieldpath}.{fieldname}'
            else:
                subfieldpath = fieldname
            anntype = prep.annotations[fieldname]
            value = getattr(obj, fieldname)

            anntype, ioattrs = _parse_annotated(anntype)

            # If we're not storing default values for this fella,
            # we can skip all output processing if we've got a default value.
            if ioattrs is not None and not ioattrs.store_default:
                default_factory: Any = field.default_factory
                if default_factory is not dataclasses.MISSING:
                    if default_factory() == value:
                        continue
                elif field.default is not dataclasses.MISSING:
                    if field.default == value:
                        continue
                else:
                    raise RuntimeError(
                        f'Field {fieldname} of {cls.__name__} has'
                        f' neither a default nor a default_factory;'
                        f' store_default=False cannot be set for it.'
                        f' (AND THIS SHOULD HAVE BEEN CAUGHT IN PREP!)')

            outvalue = self._process_value(cls, subfieldpath, anntype, value,
                                           ioattrs)
            if self._create:
                assert out is not None
                storagename = (fieldname if
                               (ioattrs is None or ioattrs.storagename is None)
                               else ioattrs.storagename)
                out[storagename] = outvalue
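        # An illustrative (hypothetical) case for the store_default logic
        # above, using this package's Annotated/IOAttrs scheme: a field
        # declared as
        #   val: Annotated[int, IOAttrs('v', store_default=False)] = 0
        # is skipped by the loop whenever the instance still holds 0, so
        # the created dict omits its 'v' key entirely.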
        # If there's extra-attrs stored on us, check/include them.
        extra_attrs = getattr(obj, EXTRA_ATTRS_ATTR, None)
        if isinstance(extra_attrs, dict):
            if not _is_valid_for_codec(extra_attrs, self._codec):
                raise TypeError(
                    f'Extra attrs on {fieldpath} contain data type(s)'
                    f' not supported by the specified codec'
                    f' ({self._codec.name}).')
            if self._create:
                assert out is not None
                out.update(extra_attrs)
        return out

    def _process_value(self, cls: type, fieldpath: str, anntype: Any,
                       value: Any, ioattrs: Optional[IOAttrs]) -> Any:
        # pylint: disable=too-many-return-statements
        # pylint: disable=too-many-branches
        # pylint: disable=too-many-statements

        origin = _get_origin(anntype)

        if origin is typing.Any:
            if not _is_valid_for_codec(value, self._codec):
                raise TypeError(
                    f'Invalid value type for \'{fieldpath}\';'
                    f" 'Any' typed values must contain types directly"
                    f' supported by the specified codec ({self._codec.name});'
                    f' found \'{type(value).__name__}\' which is not.')
            return value if self._create else None

        if origin is typing.Union:
            # Currently, the only unions we support are None/Value
            # (translated from Optional), which we verified on prep.
            # So let's treat this as a simple optional case.
            if value is None:
                return None
            childanntypes_l = [
                c for c in typing.get_args(anntype) if c is not type(None)
            ]  # noqa (pycodestyle complains about *is* with type)
            assert len(childanntypes_l) == 1
            return self._process_value(cls, fieldpath, childanntypes_l[0],
                                       value, ioattrs)

        # Everything below this point assumes the annotation type resolves
        # to a concrete type. (This should have been verified at prep time.)
        assert isinstance(origin, type)

        # For simple flat types, look for exact matches:
        if origin in SIMPLE_TYPES:
            if type(value) is not origin:
                # Special case: if they want to coerce ints to floats, do so.
                if (self._coerce_to_float and origin is float
                        and type(value) is int):
                    return float(value) if self._create else None
                _raise_type_error(fieldpath, type(value), (origin, ))
            return value if self._create else None

        if origin is tuple:
            if not isinstance(value, tuple):
                raise TypeError(f'Expected a tuple for {fieldpath};'
                                f' found a {type(value)}.')
            childanntypes = typing.get_args(anntype)

            # We should have verified this was non-zero at prep-time.
            assert childanntypes
            if len(value) != len(childanntypes):
                raise TypeError(f'Tuple at {fieldpath} contains'
                                f' {len(value)} values; type specifies'
                                f' {len(childanntypes)}.')
            if self._create:
                return [
                    self._process_value(cls, fieldpath, childanntypes[i], x,
                                        ioattrs) for i, x in enumerate(value)
                ]
            for i, x in enumerate(value):
                self._process_value(cls, fieldpath, childanntypes[i], x,
                                    ioattrs)
            return None

        if origin is list:
            if not isinstance(value, list):
                raise TypeError(f'Expected a list for {fieldpath};'
                                f' found a {type(value)}.')
            childanntypes = typing.get_args(anntype)

            # 'Any' type children; make sure they are valid values for
            # the specified codec.
            if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
                for i, child in enumerate(value):
                    if not _is_valid_for_codec(child, self._codec):
                        raise TypeError(
                            f'Item {i} of {fieldpath} contains'
                            f' data type(s) not supported by the specified'
                            f' codec ({self._codec.name}).')
                # Hmm; should we do a copy here?
                return value if self._create else None
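            # Illustrative (hypothetical) values for the Any-typed-list
            # branch above: with the JSON codec, a field annotated
            # list[Any] holding [1, 'a', None] passes through unchanged,
            # while something like [1, {2, 3}] raises TypeError since a
            # raw set isn't representable in json.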
            # We contain elements of some specified type.
            assert len(childanntypes) == 1
            if self._create:
                return [
                    self._process_value(cls, fieldpath, childanntypes[0], x,
                                        ioattrs) for x in value
                ]
            for x in value:
                self._process_value(cls, fieldpath, childanntypes[0], x,
                                    ioattrs)
            return None

        if origin is set:
            if not isinstance(value, set):
                raise TypeError(f'Expected a set for {fieldpath};'
                                f' found a {type(value)}.')
            childanntypes = typing.get_args(anntype)

            # 'Any' type children; make sure they are valid Any values.
            if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
                for child in value:
                    if not _is_valid_for_codec(child, self._codec):
                        raise TypeError(
                            f'Set at {fieldpath} contains'
                            f' data type(s) not supported by the'
                            f' specified codec ({self._codec.name}).')
                return list(value) if self._create else None

            # We contain elements of some specified type.
            assert len(childanntypes) == 1
            if self._create:
                # Note: we output json-friendly values so this becomes
                # a list.
                return [
                    self._process_value(cls, fieldpath, childanntypes[0], x,
                                        ioattrs) for x in value
                ]
            for x in value:
                self._process_value(cls, fieldpath, childanntypes[0], x,
                                    ioattrs)
            return None

        if origin is dict:
            return self._process_dict(cls, fieldpath, anntype, value, ioattrs)

        if dataclasses.is_dataclass(origin):
            if not isinstance(value, origin):
                raise TypeError(f'Expected a {origin} for {fieldpath};'
                                f' found a {type(value)}.')
            return self._process_dataclass(cls, value, fieldpath)

        if issubclass(origin, Enum):
            if not isinstance(value, origin):
                raise TypeError(f'Expected a {origin} for {fieldpath};'
                                f' found a {type(value)}.')
            # At prep-time we verified that these enums had valid value
            # types, so we can blindly return it here.
            return value.value if self._create else None

        if issubclass(origin, datetime.datetime):
            if not isinstance(value, origin):
                raise TypeError(f'Expected a {origin} for {fieldpath};'
                                f' found a {type(value)}.')
            check_utc(value)
            if ioattrs is not None:
                ioattrs.validate_datetime(value, fieldpath)
            if self._codec is Codec.FIRESTORE:
                return value
            assert self._codec is Codec.JSON
            return [
                value.year, value.month, value.day, value.hour, value.minute,
                value.second, value.microsecond
            ] if self._create else None

        if origin is bytes:
            return self._process_bytes(cls, fieldpath, value)

        raise TypeError(
            f"Field '{fieldpath}' of type '{anntype}' is unsupported here.")

    def _process_bytes(self, cls: type, fieldpath: str, value: bytes) -> Any:
        import base64
        if not isinstance(value, bytes):
            raise TypeError(
                f'Expected bytes for {fieldpath} on {cls.__name__};'
                f' found a {type(value)}.')

        if not self._create:
            return None

        # In JSON we convert to base64, but firestore directly supports bytes.
        if self._codec is Codec.JSON:
            return base64.b64encode(value).decode()
        assert self._codec is Codec.FIRESTORE
        return value

    def _process_dict(self, cls: type, fieldpath: str, anntype: Any,
                      value: dict, ioattrs: Optional[IOAttrs]) -> Any:
        # pylint: disable=too-many-branches
        if not isinstance(value, dict):
            raise TypeError(f'Expected a dict for {fieldpath};'
                            f' found a {type(value)}.')
        childtypes = typing.get_args(anntype)
        assert len(childtypes) in (0, 2)
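        # (A bare 'dict' annotation yields no type args, hence the 0 case
        # above; a dict[K, V] annotation always yields exactly two.)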
        # We treat 'Any' dicts simply as json; we don't do any translating.
        if not childtypes or childtypes[0] is typing.Any:
            if not isinstance(value, dict) or not _is_valid_for_codec(
                    value, self._codec):
                raise TypeError(
                    f'Invalid value for Dict[Any, Any]'
                    f' at \'{fieldpath}\' on {cls.__name__};'
                    f' all keys and values must be directly compatible'
                    f' with the specified codec ({self._codec.name})'
                    f' when dict type is Any.')
            return value if self._create else None

        # Ok; we've got a definite key type (which we verified as valid
        # during prep). Make sure all keys match it.
        out: Optional[dict] = {} if self._create else None
        keyanntype, valanntype = childtypes

        # str keys we just export directly since that's supported by json.
        if keyanntype is str:
            for key, val in value.items():
                if not isinstance(key, str):
                    raise TypeError(
                        f'Got invalid key type {type(key)} for'
                        f' dict key at \'{fieldpath}\' on {cls.__name__};'
                        f' expected {keyanntype}.')
                outval = self._process_value(cls, fieldpath, valanntype, val,
                                             ioattrs)
                if self._create:
                    assert out is not None
                    out[key] = outval

        # int keys are stored as str versions of themselves.
        elif keyanntype is int:
            for key, val in value.items():
                if not isinstance(key, int):
                    raise TypeError(
                        f'Got invalid key type {type(key)} for'
                        f' dict key at \'{fieldpath}\' on {cls.__name__};'
                        f' expected an int.')
                outval = self._process_value(cls, fieldpath, valanntype, val,
                                             ioattrs)
                if self._create:
                    assert out is not None
                    out[str(key)] = outval

        elif issubclass(keyanntype, Enum):
            for key, val in value.items():
                if not isinstance(key, keyanntype):
                    raise TypeError(
                        f'Got invalid key type {type(key)} for'
                        f' dict key at \'{fieldpath}\' on {cls.__name__};'
                        f' expected a {keyanntype}.')
                outval = self._process_value(cls, fieldpath, valanntype, val,
                                             ioattrs)
                if self._create:
                    assert out is not None
                    out[str(key.value)] = outval
        else:
            raise RuntimeError(f'Unhandled dict out-key-type {keyanntype}')

        return out
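
# A rough usage sketch (the 'Point' class here is hypothetical; the
# package's public helpers such as dataclass_to_dict are the usual entry
# points, but _Outputter can be driven directly):
#
#   @dataclasses.dataclass
#   class Point:
#       x: int = 0
#       y: int = 0
#
#   # Validate and export to json-friendly types:
#   out = _Outputter(Point(x=1, y=2), create=True, codec=Codec.JSON,
#                    coerce_to_float=True).run()
#   # -> {'x': 1, 'y': 2}
#
#   # Validate only (no output dict is built):
#   _Outputter(Point(x=1, y=2), create=False, codec=Codec.JSON,
#              coerce_to_float=True).run()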