# Released under the MIT License. See LICENSE for details.
#
"""Functionality for dataclassio related to exporting data from dataclasses."""
# Note: We do lots of comparing of exact types here which is normally
# frowned upon (stuff like isinstance() is usually encouraged).
# pylint: disable=unidiomatic-typecheck

from __future__ import annotations

from enum import Enum
import dataclasses
import typing
import types
import json
import datetime
from typing import TYPE_CHECKING, cast, Any

from efro.util import check_utc
from efro.dataclassio._base import (
Codec,
_parse_annotated,
EXTRA_ATTRS_ATTR,
_is_valid_for_codec,
_get_origin,
SIMPLE_TYPES,
_raise_type_error,
IOExtendedData,
IOMultiType,
)
from efro.dataclassio._prep import PrepSession

if TYPE_CHECKING:
    from efro.dataclassio._base import IOAttrs


class _Outputter:
"""Validates or exports data contained in a dataclass instance."""
def __init__(
self, obj: Any, create: bool, codec: Codec, coerce_to_float: bool
) -> None:
self._obj = obj
self._create = create
self._codec = codec
        self._coerce_to_float = coerce_to_float

    def run(self) -> Any:
"""Do the thing."""
assert dataclasses.is_dataclass(self._obj)
# For special extended data types, call their 'will_output' callback.
# FIXME - should probably move this into _process_dataclass so it
# can work on nested values.
if isinstance(self._obj, IOExtendedData):
self._obj.will_output()
        return self._process_dataclass(type(self._obj), self._obj, '')

    def soft_default_check(
self, value: Any, anntype: Any, fieldpath: str
) -> None:
"""(internal)"""
self._process_value(
type(value),
fieldpath=fieldpath,
anntype=anntype,
value=value,
ioattrs=None,
        )

    def _process_dataclass(self, cls: type, obj: Any, fieldpath: str) -> Any:
# pylint: disable=too-many-locals
# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
prep = PrepSession(explicit=False).prep_dataclass(
type(obj), recursion_level=0
)
assert prep is not None
fields = dataclasses.fields(obj)
out: dict[str, Any] | None = {} if self._create else None
for field in fields:
fieldname = field.name
if fieldpath:
subfieldpath = f'{fieldpath}.{fieldname}'
else:
subfieldpath = fieldname
anntype = prep.annotations[fieldname]
value = getattr(obj, fieldname)
anntype, ioattrs = _parse_annotated(anntype)
# If we're not storing default values for this fella,
# we can skip all output processing if we've got a default value.
if ioattrs is not None and not ioattrs.store_default:
# If both soft_defaults and regular field defaults
# are present we want to go with soft_defaults since
# those same values would be re-injected when reading
# the same data back in if we've omitted the field.
default_factory: Any = field.default_factory
if ioattrs.soft_default is not ioattrs.MISSING:
if ioattrs.soft_default == value:
continue
elif ioattrs.soft_default_factory is not ioattrs.MISSING:
assert callable(ioattrs.soft_default_factory)
if ioattrs.soft_default_factory() == value:
continue
elif field.default is not dataclasses.MISSING:
if field.default == value:
continue
elif default_factory is not dataclasses.MISSING:
if default_factory() == value:
continue
else:
raise RuntimeError(
f'Field {fieldname} of {cls.__name__} has'
f' no source of default values; store_default=False'
f' cannot be set for it. (AND THIS SHOULD HAVE BEEN'
f' CAUGHT IN PREP!)'
)
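
            # For example (illustrative; names here are hypothetical):
            # given a field
            #   val: Annotated[int, IOAttrs('v', store_default=False,
            #                                soft_default=0)] = 0
            # an instance whose val == 0 emits no 'v' key at all, and
            # readers re-inject 0 when the key is absent.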
outvalue = self._process_value(
cls, subfieldpath, anntype, value, ioattrs
)
if self._create:
assert out is not None
storagename = (
fieldname
if (ioattrs is None or ioattrs.storagename is None)
else ioattrs.storagename
)
out[storagename] = outvalue
# If there's extra-attrs stored on us, check/include them.
extra_attrs = getattr(obj, EXTRA_ATTRS_ATTR, None)
if isinstance(extra_attrs, dict):
if not _is_valid_for_codec(extra_attrs, self._codec):
raise TypeError(
f'Extra attrs on \'{fieldpath}\' contains data type(s)'
f' not supported by \'{self._codec.value}\' codec:'
f' {extra_attrs}.'
)
if self._create:
assert out is not None
out.update(extra_attrs)
# If this obj inherits from multi-type, store its type id.
if isinstance(obj, IOMultiType):
type_id = obj.get_type_id()
# Sanity checks; make sure looking up this id gets us this
# type.
assert isinstance(type_id.value, str)
if obj.get_type(type_id) is not type(obj):
raise RuntimeError(
f'dataclassio: object of type {type(obj)}'
f' gives type-id {type_id} but that id gives type'
f' {obj.get_type(type_id)}. Something is out of sync.'
)
assert obj.get_type(type_id) is type(obj)
if self._create:
assert out is not None
storagename = obj.get_type_id_storage_name()
if any(f.name == storagename for f in fields):
raise RuntimeError(
f'dataclassio: {type(obj)} contains a'
f" '{storagename}' field which clashes with"
                        f' the type-id-storage-name of the IOMultiType'
f' it inherits from.'
)
out[storagename] = type_id.value
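
        # For example (illustrative; names here are hypothetical): a
        # multi-type child whose get_type_id() returns an id with value
        # 'spam' and whose get_type_id_storage_name() returns 'kind'
        # exports its regular fields plus {'kind': 'spam'}.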
        return out

    def _process_value(
self,
cls: type,
fieldpath: str,
anntype: Any,
value: Any,
ioattrs: IOAttrs | None,
) -> Any:
# pylint: disable=too-many-return-statements
# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
origin = _get_origin(anntype)
if origin is typing.Any:
if not _is_valid_for_codec(value, self._codec):
raise TypeError(
f'Invalid value type for \'{fieldpath}\';'
f" 'Any' typed values must contain types directly"
f' supported by the specified codec ({self._codec.name});'
f' found \'{type(value).__name__}\' which is not.'
)
return value if self._create else None
if origin is typing.Union or origin is types.UnionType:
# Currently, the only unions we support are None/Value
# (translated from Optional), which we verified on prep.
# So let's treat this as a simple optional case.
if value is None:
return None
childanntypes_l = [
c for c in typing.get_args(anntype) if c is not type(None)
] # noqa (pycodestyle complains about *is* with type)
assert len(childanntypes_l) == 1
return self._process_value(
cls, fieldpath, childanntypes_l[0], value, ioattrs
)
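
        # For example (illustrative): a field annotated 'int | None'
        # lands here; None passes through as-is, while 2 recurses with
        # the bare 'int' annotation and exports as 2.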
# Everything below this point assumes the annotation type resolves
# to a concrete type. (This should have been verified at prep time).
assert isinstance(origin, type)
# For simple flat types, look for exact matches:
if origin in SIMPLE_TYPES:
if type(value) is not origin:
# Special case: if they want to coerce ints to floats, do so.
if (
self._coerce_to_float
and origin is float
and type(value) is int
):
return float(value) if self._create else None
_raise_type_error(fieldpath, type(value), (origin,))
return value if self._create else None
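
        # For example (illustrative): with coerce_to_float=True, an int
        # value of 2 in a float-annotated field exports as 2.0; with it
        # False, the same value raises a type error.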
if origin is tuple:
if not isinstance(value, tuple):
raise TypeError(
f'Expected a tuple for {fieldpath};'
f' found a {type(value)}'
)
childanntypes = typing.get_args(anntype)
# We should have verified this was non-zero at prep-time
assert childanntypes
if len(value) != len(childanntypes):
raise TypeError(
f'Tuple at {fieldpath} contains'
f' {len(value)} values; type specifies'
f' {len(childanntypes)}.'
)
if self._create:
return [
self._process_value(
cls, fieldpath, childanntypes[i], x, ioattrs
)
for i, x in enumerate(value)
]
for i, x in enumerate(value):
self._process_value(
cls, fieldpath, childanntypes[i], x, ioattrs
)
return None
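
        # For example (illustrative): a tuple[int, str] field holding
        # (1, 'a') exports as the json-friendly list [1, 'a'].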
if origin is list:
if not isinstance(value, list):
raise TypeError(
f'Expected a list for {fieldpath};'
f' found a {type(value)}'
)
childanntypes = typing.get_args(anntype)
# 'Any' type children; make sure they are valid values for
# the specified codec.
if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
for i, child in enumerate(value):
if not _is_valid_for_codec(child, self._codec):
raise TypeError(
f'Item {i} of {fieldpath} contains'
f' data type(s) not supported by the specified'
f' codec ({self._codec.name}).'
)
# Hmm; should we do a copy here?
return value if self._create else None
# We contain elements of some single specified type.
assert len(childanntypes) == 1
childanntype = childanntypes[0]
# If that type is a multi-type, we determine our type per-object.
            if isinstance(childanntype, type) and issubclass(
                childanntype, IOMultiType
            ):
# In the multi-type case, we use each object's own type
                # to do its conversion, but let's at least make sure each
# of those types inherits from the annotated multi-type
# class.
for x in value:
if not isinstance(x, childanntype):
raise ValueError(
f"Found a {type(x)} value under '{fieldpath}'."
f' Everything must inherit from'
f' {childanntype}.'
)
if self._create:
out: list[Any] = []
for x in value:
# We know these are dataclasses so no need to do
# the generic _process_value.
out.append(self._process_dataclass(cls, x, fieldpath))
return out
for x in value:
# We know these are dataclasses so no need to do
# the generic _process_value.
self._process_dataclass(cls, x, fieldpath)
# Normal non-multitype case; everything's got the same type.
if self._create:
return [
self._process_value(
cls, fieldpath, childanntypes[0], x, ioattrs
)
for x in value
]
for x in value:
self._process_value(
cls, fieldpath, childanntypes[0], x, ioattrs
)
return None
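
        # For example (illustrative): a list[int] field [1, 2] exports
        # as a fresh list [1, 2]; elements annotated with an IOMultiType
        # base are each exported using their own concrete type.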
if origin is set:
if not isinstance(value, set):
raise TypeError(
                    f'Expected a set for {fieldpath}; found a {type(value)}'
)
childanntypes = typing.get_args(anntype)
            # 'Any' type children; make sure they are valid values for
            # the specified codec.
if len(childanntypes) == 0 or childanntypes[0] is typing.Any:
for child in value:
if not _is_valid_for_codec(child, self._codec):
raise TypeError(
f'Set at {fieldpath} contains'
f' data type(s) not supported by the'
f' specified codec ({self._codec.name}).'
)
# We output json-friendly values so this becomes a list.
# We need to sort the list so our output is
# deterministic and can be meaningfully compared with
# others, across processes, etc.
#
# Since we don't know what types we've got here, we
# guarantee sortability by dumping each value to a json
# string (itself with keys sorted) and using that as the
# value's sorting key. Not efficient but it works. A
# good reason to avoid set[Any] though. Perhaps we
# should just disallow it altogether.
return (
sorted(value, key=lambda v: json.dumps(v, sort_keys=True))
if self._create
else None
)
# We contain elements of some specified type.
assert len(childanntypes) == 1
if self._create:
# We output json-friendly values so this becomes a list.
# We need to sort the list so our output is
# deterministic and can be meaningfully compared with
# others, across processes, etc.
#
# In this case we have a single concrete type, and for
# most incarnations of that (str, int, etc.) we can just
# sort our final output. For more complex cases,
# however, such as optional values or dataclasses, we
# need to convert everything to a json string (itself
# with keys sorted) and sort based on those strings.
# This is probably a good reason to avoid sets
# containing dataclasses or optional values. Perhaps we
# should just disallow those.
return sorted(
(
self._process_value(
cls, fieldpath, childanntypes[0], x, ioattrs
)
for x in value
),
key=(
None
if childanntypes[0] in [str, int, float, bool]
else lambda v: json.dumps(v, sort_keys=True)
),
)
for x in value:
self._process_value(
cls, fieldpath, childanntypes[0], x, ioattrs
)
return None
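
        # For example (illustrative): a set[int] field holding {3, 1, 2}
        # exports as the sorted json-friendly list [1, 2, 3], so repeated
        # exports compare equal across runs.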
if origin is dict:
return self._process_dict(cls, fieldpath, anntype, value, ioattrs)
if dataclasses.is_dataclass(origin):
if not isinstance(value, cast(Any, origin)):
raise TypeError(
f'Expected a {origin} for {fieldpath};'
f' found a {type(value)}.'
)
return self._process_dataclass(cls, value, fieldpath)
# ONLY consider something as a multi-type when it's not a
# dataclass (all dataclasses inheriting from the multi-type should
# just be processed as dataclasses).
if issubclass(origin, IOMultiType):
# In the multi-type case, we use each object's own type to
            # do its conversion, but let's at least make sure each of
# those types inherits from the annotated multi-type class.
if not isinstance(value, origin):
raise ValueError(
f"Found a {type(value)} value at '{fieldpath}'."
f' It is expected to inherit from {origin}.'
)
return self._process_dataclass(cls, value, fieldpath)
if issubclass(origin, Enum):
if not isinstance(value, origin):
raise TypeError(
f'Expected a {origin} for {fieldpath};'
f' found a {type(value)}.'
)
# At prep-time we verified that these enums had valid value
# types, so we can blindly return it here.
return value.value if self._create else None
if issubclass(origin, datetime.datetime):
if not isinstance(value, origin):
raise TypeError(
f'Expected a {origin} for {fieldpath};'
f' found a {type(value)}.'
)
check_utc(value)
if ioattrs is not None:
ioattrs.validate_datetime(value, fieldpath)
if self._codec is Codec.FIRESTORE:
return value
assert self._codec is Codec.JSON
return (
[
value.year,
value.month,
value.day,
value.hour,
value.minute,
value.second,
value.microsecond,
]
if self._create
else None
)
if issubclass(origin, datetime.timedelta):
if not isinstance(value, origin):
raise TypeError(
f'Expected a {origin} for {fieldpath};'
f' found a {type(value)}.'
)
return (
[value.days, value.seconds, value.microseconds]
if self._create
else None
)
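
        # For example (illustrative): under the JSON codec,
        # datetime(2024, 1, 2, 3, 4, 5, 6, tzinfo=timezone.utc) exports
        # as [2024, 1, 2, 3, 4, 5, 6] and timedelta(days=1, seconds=2)
        # exports as [1, 2, 0].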
if origin is bytes:
return self._process_bytes(cls, fieldpath, value)
raise TypeError(
f"Field '{fieldpath}' of type '{anntype}' is unsupported here."
        )

    def _process_bytes(self, cls: type, fieldpath: str, value: bytes) -> Any:
        import base64

if not isinstance(value, bytes):
raise TypeError(
f'Expected bytes for {fieldpath} on {cls.__name__};'
f' found a {type(value)}.'
)
if not self._create:
return None
# In JSON we convert to base64, but firestore directly supports bytes.
if self._codec is Codec.JSON:
return base64.b64encode(value).decode()
assert self._codec is Codec.FIRESTORE
return value
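
    # For example (illustrative): in _process_bytes above, b'hi' exports
    # as the base64 string 'aGk=' under the JSON codec and passes through
    # unchanged under FIRESTORE.
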
def _process_dict(
self,
cls: type,
fieldpath: str,
anntype: Any,
value: dict,
ioattrs: IOAttrs | None,
) -> Any:
# pylint: disable=too-many-branches
if not isinstance(value, dict):
raise TypeError(
                f'Expected a dict for {fieldpath}; found a {type(value)}.'
)
childtypes = typing.get_args(anntype)
assert len(childtypes) in (0, 2)
# We treat 'Any' dicts simply as json; we don't do any translating.
value_any: Any = value
if not childtypes or childtypes[0] is typing.Any:
if not isinstance(value_any, dict) or not _is_valid_for_codec(
value, self._codec
):
raise TypeError(
f'Invalid value for Dict[Any, Any]'
f' at \'{fieldpath}\' on {cls.__name__};'
f' all keys and values must be directly compatible'
f' with the specified codec ({self._codec.name})'
f' when dict type is Any.'
)
return value if self._create else None
# Ok; we've got a definite key type (which we verified as valid
# during prep). Make sure all keys match it.
out: dict | None = {} if self._create else None
keyanntype, valanntype = childtypes
# str keys we just export directly since that's supported by json.
if keyanntype is str:
for key, val in value.items():
if not isinstance(key, str):
raise TypeError(
f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected {keyanntype}.'
)
outval = self._process_value(
cls, fieldpath, valanntype, val, ioattrs
)
if self._create:
assert out is not None
out[key] = outval
# int keys are stored as str versions of themselves.
elif keyanntype is int:
for key, val in value.items():
if not isinstance(key, int):
raise TypeError(
f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected an int.'
)
outval = self._process_value(
cls, fieldpath, valanntype, val, ioattrs
)
if self._create:
assert out is not None
out[str(key)] = outval
elif issubclass(keyanntype, Enum):
for key, val in value.items():
if not isinstance(key, keyanntype):
raise TypeError(
f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls.__name__};'
f' expected a {keyanntype}.'
)
outval = self._process_value(
cls, fieldpath, valanntype, val, ioattrs
)
if self._create:
assert out is not None
out[str(key.value)] = outval
else:
raise RuntimeError(f'Unhandled dict out-key-type {keyanntype}')
return out
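
    # For example (illustrative; MyEnum is hypothetical): a dict[int, str]
    # field {1: 'a'} exports as {'1': 'a'}, and a dict[MyEnum, int] field
    # {MyEnum.FOO: 2} (where MyEnum.FOO.value == 'foo') exports as
    # {'foo': 2}; readers convert keys back based on the annotation.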