2021-09-29 11:19:17 -05:00

349 lines
15 KiB
Python

# Released under the MIT License. See LICENSE for details.
#
"""Functionality for dataclassio related to exporting data from dataclasses."""
# Note: We do lots of comparing of exact types here which is normally
# frowned upon (stuff like isinstance() is usually encouraged).
# pylint: disable=unidiomatic-typecheck
from __future__ import annotations
from enum import Enum
import dataclasses
import typing
import datetime
from typing import TYPE_CHECKING
from efro.dataclassio._base import (Codec, _parse_annotated, EXTRA_ATTRS_ATTR,
_is_valid_for_codec, _get_origin,
SIMPLE_TYPES, _raise_type_error,
_ensure_datetime_is_timezone_aware)
from efro.dataclassio._prep import PrepSession
if TYPE_CHECKING:
from typing import Any, Dict, Type, Tuple, Optional, List, Set
from efro.dataclassio._base import IOAttrs
class _Outputter:
"""Validates or exports data contained in a dataclass instance."""
def __init__(self, obj: Any, create: bool, codec: Codec,
coerce_to_float: bool) -> None:
self._obj = obj
self._create = create
self._codec = codec
self._coerce_to_float = coerce_to_float
def run(self) -> Any:
"""Do the thing."""
return self._process_dataclass(type(self._obj), self._obj, '')
def _process_dataclass(self, cls: Type, obj: Any, fieldpath: str) -> Any:
# pylint: disable=too-many-locals
# pylint: disable=too-many-branches
prep = PrepSession(explicit=False).prep_dataclass(type(obj),
recursion_level=0)
fields = dataclasses.fields(obj)
out: Optional[Dict[str, Any]] = {} if self._create else None
for field in fields:
fieldname = field.name
if fieldpath:
subfieldpath = f'{fieldpath}.{fieldname}'
else:
subfieldpath = fieldname
anntype = prep.annotations[fieldname]
value = getattr(obj, fieldname)
anntype, ioattrs = _parse_annotated(anntype)
# If we're not storing default values for this fella,
# we can skip all output processing if we've got a default value.
if ioattrs is not None and not ioattrs.store_default:
default_factory: Any = field.default_factory # type: ignore
if default_factory is not dataclasses.MISSING:
if default_factory() == value:
continue
elif field.default is not dataclasses.MISSING:
if field.default == value:
continue
else:
raise RuntimeError(
f'Field {fieldname} of {cls.__name__} has'
f' neither a default nor a default_factory;'
f' store_default=False cannot be set for it.'
f' (AND THIS SHOULD HAVE BEEN CAUGHT IN PREP!)')
outvalue = self._process_value(cls, subfieldpath, anntype, value,
ioattrs)
if self._create:
assert out is not None
storagename = (fieldname if
(ioattrs is None or ioattrs.storagename is None)
else ioattrs.storagename)
out[storagename] = outvalue
# If there's extra-attrs stored on us, check/include them.
extra_attrs = getattr(obj, EXTRA_ATTRS_ATTR, None)
if isinstance(extra_attrs, dict):
if not _is_valid_for_codec(extra_attrs, self._codec):
raise TypeError(
f'Extra attrs on {fieldpath} contains data type(s)'
f' not supported by json.')
if self._create:
assert out is not None
out.update(extra_attrs)
return out
def _process_value(self, cls: Type, fieldpath: str, anntype: Any,
value: Any, ioattrs: Optional[IOAttrs]) -> Any:
# pylint: disable=too-many-return-statements
# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
origin = _get_origin(anntype)
if origin is typing.Any:
if not _is_valid_for_codec(value, self._codec):
raise TypeError(
f'Invalid value type for \'{fieldpath}\';'
f" 'Any' typed values must contain types directly"
f' supported by the specified codec ({self._codec.name});'
f' found \'{type(value).__name__}\' which is not.')
return value if self._create else None
if origin is typing.Union:
# Currently the only unions we support are None/Value
# (translated from Optional), which we verified on prep.
# So let's treat this as a simple optional case.
if value is None:
return None
childanntypes_l = [
c for c in typing.get_args(anntype) if c is not type(None)
]
assert len(childanntypes_l) == 1
return self._process_value(cls, fieldpath, childanntypes_l[0],
value, ioattrs)
# Everything below this point assumes the annotation type resolves
# to a concrete type. (This should have been verified at prep time).
assert isinstance(origin, type)
# For simple flat types, look for exact matches:
if origin in SIMPLE_TYPES:
if type(value) is not origin:
# Special case: if they want to coerce ints to floats, do so.
if (self._coerce_to_float and origin is float
and type(value) is int):
return float(value) if self._create else None
_raise_type_error(fieldpath, type(value), (origin, ))
return value if self._create else None
if origin is tuple:
if not isinstance(value, tuple):
raise TypeError(f'Expected a tuple for {fieldpath};'
f' found a {type(value)}')
childanntypes = typing.get_args(anntype)
# We should have verified this was non-zero at prep-time
assert childanntypes
if len(value) != len(childanntypes):
raise TypeError(f'Tuple at {fieldpath} contains'
f' {len(value)} values; type specifies'
f' {len(childanntypes)}.')
if self._create:
return [
self._process_value(cls, fieldpath, childanntypes[i], x,
ioattrs) for i, x in enumerate(value)
]
for i, x in enumerate(value):
self._process_value(cls, fieldpath, childanntypes[i], x,
ioattrs)
return None
if origin is list:
if not isinstance(value, list):
raise TypeError(f'Expected a list for {fieldpath};'
f' found a {type(value)}')
childanntypes = typing.get_args(anntype)
# 'Any' type children; make sure they are valid values for
# the specified codec.
if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
for i, child in enumerate(value):
if not _is_valid_for_codec(child, self._codec):
raise TypeError(
f'Item {i} of {fieldpath} contains'
f' data type(s) not supported by the specified'
f' codec ({self._codec.name}).')
# Hmm; should we do a copy here?
return value if self._create else None
# We contain elements of some specified type.
assert len(childanntypes) == 1
if self._create:
return [
self._process_value(cls, fieldpath, childanntypes[0], x,
ioattrs) for x in value
]
for x in value:
self._process_value(cls, fieldpath, childanntypes[0], x,
ioattrs)
return None
if origin is set:
if not isinstance(value, set):
raise TypeError(f'Expected a set for {fieldpath};'
f' found a {type(value)}')
childanntypes = typing.get_args(anntype)
# 'Any' type children; make sure they are valid Any values.
if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
for child in value:
if not _is_valid_for_codec(child, self._codec):
raise TypeError(
f'Set at {fieldpath} contains'
f' data type(s) not supported by the'
f' specified codec ({self._codec.name}).')
return list(value) if self._create else None
# We contain elements of some specified type.
assert len(childanntypes) == 1
if self._create:
# Note: we output json-friendly values so this becomes
# a list.
return [
self._process_value(cls, fieldpath, childanntypes[0], x,
ioattrs) for x in value
]
for x in value:
self._process_value(cls, fieldpath, childanntypes[0], x,
ioattrs)
return None
if origin is dict:
return self._process_dict(cls, fieldpath, anntype, value, ioattrs)
if dataclasses.is_dataclass(origin):
if not isinstance(value, origin):
raise TypeError(f'Expected a {origin} for {fieldpath};'
f' found a {type(value)}.')
return self._process_dataclass(cls, value, fieldpath)
if issubclass(origin, Enum):
if not isinstance(value, origin):
raise TypeError(f'Expected a {origin} for {fieldpath};'
f' found a {type(value)}.')
# At prep-time we verified that these enums had valid value
# types, so we can blindly return it here.
return value.value if self._create else None
if issubclass(origin, datetime.datetime):
if not isinstance(value, origin):
raise TypeError(f'Expected a {origin} for {fieldpath};'
f' found a {type(value)}.')
_ensure_datetime_is_timezone_aware(value)
if ioattrs is not None:
ioattrs.validate_datetime(value, fieldpath)
if self._codec is Codec.FIRESTORE:
return value
assert self._codec is Codec.JSON
return [
value.year, value.month, value.day, value.hour, value.minute,
value.second, value.microsecond
] if self._create else None
if origin is bytes:
return self._process_bytes(cls, fieldpath, value)
raise TypeError(
f"Field '{fieldpath}' of type '{anntype}' is unsupported here.")
def _process_bytes(self, cls: Type, fieldpath: str, value: bytes) -> Any:
import base64
if not isinstance(value, bytes):
raise TypeError(
f'Expected bytes for {fieldpath} on {cls.__name__};'
f' found a {type(value)}.')
if not self._create:
return None
# In JSON we convert to base64, but firestore directly supports bytes.
if self._codec is Codec.JSON:
return base64.b64encode(value).decode()
assert self._codec is Codec.FIRESTORE
return value
def _process_dict(self, cls: Type, fieldpath: str, anntype: Any,
value: dict, ioattrs: Optional[IOAttrs]) -> Any:
# pylint: disable=too-many-branches
if not isinstance(value, dict):
raise TypeError(f'Expected a dict for {fieldpath};'
f' found a {type(value)}.')
childtypes = typing.get_args(anntype)
assert len(childtypes) in (0, 2)
# We treat 'Any' dicts simply as json; we don't do any translating.
if not childtypes or childtypes[0] is typing.Any:
if not isinstance(value, dict) or not _is_valid_for_codec(
value, self._codec):
raise TypeError(
f'Invalid value for Dict[Any, Any]'
f' at \'{fieldpath}\' on {cls.__name__};'
f' all keys and values must be directly compatible'
f' with the specified codec ({self._codec.name})'
f' when dict type is Any.')
return value if self._create else None
# Ok; we've got a definite key type (which we verified as valid
# during prep). Make sure all keys match it.
out: Optional[Dict] = {} if self._create else None
keyanntype, valanntype = childtypes
# str keys we just export directly since that's supported by json.
if keyanntype is str:
for key, val in value.items():
if not isinstance(key, str):
raise TypeError(
f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected {keyanntype}.')
outval = self._process_value(cls, fieldpath, valanntype, val,
ioattrs)
if self._create:
assert out is not None
out[key] = outval
# int keys are stored as str versions of themselves.
elif keyanntype is int:
for key, val in value.items():
if not isinstance(key, int):
raise TypeError(
f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected an int.')
outval = self._process_value(cls, fieldpath, valanntype, val,
ioattrs)
if self._create:
assert out is not None
out[str(key)] = outval
elif issubclass(keyanntype, Enum):
for key, val in value.items():
if not isinstance(key, keyanntype):
raise TypeError(
f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected a {keyanntype}.')
outval = self._process_value(cls, fieldpath, valanntype, val,
ioattrs)
if self._create:
assert out is not None
out[str(key.value)] = outval
else:
raise RuntimeError(f'Unhandled dict out-key-type {keyanntype}')
return out