Added day and hour value limits to dataclassio

Eric Froemling 2021-05-27 12:36:35 -05:00
parent dc7d7fab1d
commit 2f2ae11d4f
6 changed files with 138 additions and 40 deletions
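
In practice, the new options hang off IOAttrs annotations. A rough usage sketch based on the diffs below (the class and field names are invented for illustration, and the efro.dataclassio / efro.util import paths are assumed from the files touched here):

import datetime
from dataclasses import dataclass
from typing import Annotated

from efro.dataclassio import IOAttrs, dataclass_validate, ioprepped
from efro.util import utc_this_hour


@ioprepped
@dataclass
class _Example:
    # Hypothetical field: values must land exactly on an hour boundary.
    start_time: Annotated[datetime.datetime, IOAttrs(whole_hours=True)]


# Passes: utc_this_hour() has zero minutes/seconds/microseconds.
dataclass_validate(_Example(start_time=utc_this_hour()))

# Rejected: 30 minutes past the hour is not a whole hour.
try:
    dataclass_validate(
        _Example(start_time=utc_this_hour() + datetime.timedelta(minutes=30)))
except ValueError as exc:
    print(f'rejected as expected: {exc}')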

View File

@@ -2095,6 +2095,7 @@
 <w>stickman</w>
 <w>storable</w>
 <w>storagename</w>
+<w>storagenames</w>
 <w>storecmd</w>
 <w>storedhash</w>
 <w>storeitemui</w>

View File

@@ -941,6 +941,7 @@
 <w>stephane</w>
 <w>stepnum</w>
 <w>stepsize</w>
+<w>storagenames</w>
 <w>storecmd</w>
 <w>strcasecmp</w>
 <w>strchr</w>

View File

@@ -1,5 +1,5 @@
 <!-- THIS FILE IS AUTO GENERATED; DO NOT EDIT BY HAND -->
-<h4><em>last updated on 2021-05-26 for Ballistica version 1.6.4 build 20369</em></h4>
+<h4><em>last updated on 2021-05-27 for Ballistica version 1.6.4 build 20369</em></h4>
 <p>This page documents the Python classes and functions in the 'ba' module,
 which are the ones most relevant to modding in Ballistica. If you come across something you feel should be included here or could be better explained, please <a href="mailto:support@froemling.net">let me know</a>. Happy modding!</p>
 <hr>

View File

@@ -683,3 +683,44 @@ def test_field_storage_path_capture() -> None:
     with pytest.raises(AttributeError):
         assert namecap.nonexistent.barf == 's.barf'
+
+
+def test_datetime_limits() -> None:
+    """Test limiting datetime values in various ways."""
+    from efro.util import utc_today, utc_this_hour
+
+    @ioprepped
+    @dataclass
+    class _TestClass:
+        tval: Annotated[datetime.datetime, IOAttrs(whole_hours=True)]
+
+    # Check whole-hour limit when validating/exporting.
+    obj = _TestClass(tval=utc_this_hour() + datetime.timedelta(minutes=1))
+    with pytest.raises(ValueError):
+        dataclass_validate(obj)
+    obj.tval = utc_this_hour()
+    dataclass_validate(obj)
+
+    # Check whole-hour limit when importing.
+    out = dataclass_to_dict(obj)
+    out['tval'][-1] += 1
+    with pytest.raises(ValueError):
+        dataclass_from_dict(_TestClass, out)
+
+    # Check whole-days limit when validating/exporting.
+    @ioprepped
+    @dataclass
+    class _TestClass2:
+        tval: Annotated[datetime.datetime, IOAttrs(whole_days=True)]
+
+    obj2 = _TestClass2(tval=utc_today() + datetime.timedelta(hours=1))
+    with pytest.raises(ValueError):
+        dataclass_validate(obj2)
+    obj2.tval = utc_today()
+    dataclass_validate(obj2)
+
+    # Check whole-days limit when importing.
+    out = dataclass_to_dict(obj2)
+    out['tval'][-1] += 1
+    with pytest.raises(ValueError):
+        dataclass_from_dict(_TestClass2, out)
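
The out['tval'][-1] += 1 nudge in the test works because, under the JSON codec, a datetime is serialized as a list of seven ints that the importer splats back into datetime.datetime(*value, tzinfo=utc) (see the 'expected a list of 7 ints' check below), so the last element is the microseconds field. A hedged sketch of that round trip, with an invented single-field class and assumed import paths:

import datetime
from dataclasses import dataclass
from typing import Annotated

from efro.dataclassio import (IOAttrs, dataclass_from_dict, dataclass_to_dict,
                              ioprepped)
from efro.util import utc_this_hour


@ioprepped
@dataclass
class _HourStamp:
    tval: Annotated[datetime.datetime, IOAttrs(whole_hours=True)]


out = dataclass_to_dict(_HourStamp(tval=utc_this_hour()))
print(out)  # e.g. {'tval': [2021, 5, 27, 17, 0, 0, 0]}

out['tval'][-1] += 1  # one microsecond past the hour; no longer whole.
try:
    dataclass_from_dict(_HourStamp, out)
except ValueError as exc:
    print(f'import rejected: {exc}')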

View File

@@ -72,9 +72,15 @@ class Codec(Enum):
 class IOAttrs:
     """For specifying io behavior in annotations."""

-    def __init__(self, storagename: str = None, store_default: bool = True):
+    def __init__(self,
+                 storagename: str = None,
+                 store_default: bool = True,
+                 whole_days: bool = False,
+                 whole_hours: bool = False):
         self.storagename = storagename
         self.store_default = store_default
+        self.whole_days = whole_days
+        self.whole_hours = whole_hours

     def validate_for_field(self, cls: Type, field: dataclasses.Field) -> None:
         """Ensure the IOAttrs instance is ok to use with the provided field."""
@@ -89,13 +95,27 @@
                 f' neither a default nor a default_factory;'
                 f' store_default=False cannot be set for it.')

+    def validate_datetime(self, value: datetime.datetime,
+                          fieldpath: str) -> None:
+        """Ensure a datetime value meets our value requirements."""
+        if self.whole_days:
+            if any(x != 0 for x in (value.hour, value.minute, value.second,
+                                    value.microsecond)):
+                raise ValueError(
+                    f'Value {value} at {fieldpath} is not a whole day.')
+        if self.whole_hours:
+            if any(x != 0
+                   for x in (value.minute, value.second, value.microsecond)):
+                raise ValueError(f'Value {value} at {fieldpath}'
+                                 f' is not a whole hour.')
+

 class FieldStoragePathCapture:
     """Utility for obtaining dataclass storage paths in a type safe way.

     Given dataclass instance foo, FieldStoragePathCapture(foo).bar.eep
     will return 'bar.eep' (or something like 'b.e' if storagenames are
-    overrridden). This can be combined with type-checking tricks that
+    overridden). This can be combined with type-checking tricks that
     return foo in the type-checker's eyes while returning
     FieldStoragePathCapture(foo) at runtime in order to grant a measure
     of type safety to specifying field paths for things such as db
@@ -627,7 +647,8 @@ class _Outputter:
                     f' store_default=False cannot be set for it.'
                     f' (AND THIS SHOULD HAVE BEEN CAUGHT IN PREP!)')

-            outvalue = self._process_value(cls, subfieldpath, anntype, value)
+            outvalue = self._process_value(cls, subfieldpath, anntype, value,
+                                           ioattrs)
             if self._create:
                 assert out is not None
                 storagename = (fieldname if
@@ -648,7 +669,7 @@
         return out

     def _process_value(self, cls: Type, fieldpath: str, anntype: Any,
-                       value: Any) -> Any:
+                       value: Any, ioattrs: Optional[IOAttrs]) -> Any:
         # pylint: disable=too-many-return-statements
         # pylint: disable=too-many-branches
         # pylint: disable=too-many-statements
@@ -674,7 +695,7 @@
             ]
             assert len(childanntypes_l) == 1
             return self._process_value(cls, fieldpath, childanntypes_l[0],
-                                       value)
+                                       value, ioattrs)

         # Everything below this point assumes the annotation type resolves
         # to a concrete type. (This should have been verified at prep time).
@@ -704,11 +725,12 @@
                                 f' {len(childanntypes)}.')
             if self._create:
                 return [
-                    self._process_value(cls, fieldpath, childanntypes[i], x)
-                    for i, x in enumerate(value)
+                    self._process_value(cls, fieldpath, childanntypes[i], x,
+                                        ioattrs) for i, x in enumerate(value)
                 ]
             for i, x in enumerate(value):
-                self._process_value(cls, fieldpath, childanntypes[i], x)
+                self._process_value(cls, fieldpath, childanntypes[i], x,
+                                    ioattrs)
             return None

         if origin is list:
@@ -731,11 +753,12 @@
             assert len(childanntypes) == 1
             if self._create:
                 return [
-                    self._process_value(cls, fieldpath, childanntypes[0], x)
-                    for x in value
+                    self._process_value(cls, fieldpath, childanntypes[0], x,
+                                        ioattrs) for x in value
                 ]
             for x in value:
-                self._process_value(cls, fieldpath, childanntypes[0], x)
+                self._process_value(cls, fieldpath, childanntypes[0], x,
+                                    ioattrs)
             return None

         if origin is set:
@@ -759,15 +782,16 @@
                 # Note: we output json-friendly values so this becomes
                 # a list.
                 return [
-                    self._process_value(cls, fieldpath, childanntypes[0], x)
-                    for x in value
+                    self._process_value(cls, fieldpath, childanntypes[0], x,
+                                        ioattrs) for x in value
                 ]
             for x in value:
-                self._process_value(cls, fieldpath, childanntypes[0], x)
+                self._process_value(cls, fieldpath, childanntypes[0], x,
+                                    ioattrs)
             return None

         if origin is dict:
-            return self._process_dict(cls, fieldpath, anntype, value)
+            return self._process_dict(cls, fieldpath, anntype, value, ioattrs)

         if dataclasses.is_dataclass(origin):
             if not isinstance(value, origin):
@@ -788,6 +812,8 @@
                 raise TypeError(f'Expected a {origin} for {fieldpath};'
                                 f' found a {type(value)}.')
             _ensure_datetime_is_timezone_aware(value)
+            if ioattrs is not None:
+                ioattrs.validate_datetime(value, fieldpath)
             if self._codec is Codec.FIRESTORE:
                 return value
             assert self._codec is Codec.JSON
@@ -819,7 +845,7 @@
         return value

     def _process_dict(self, cls: Type, fieldpath: str, anntype: Any,
-                      value: dict) -> Any:
+                      value: dict, ioattrs: Optional[IOAttrs]) -> Any:
         # pylint: disable=too-many-branches
         if not isinstance(value, dict):
             raise TypeError(f'Expected a dict for {fieldpath};'
@@ -848,7 +874,8 @@
                     raise TypeError(f'Got invalid key type {type(key)} for'
                                     f' dict key at \'{fieldpath}\' on {cls};'
                                     f' expected {keyanntype}.')
-                outval = self._process_value(cls, fieldpath, valanntype, val)
+                outval = self._process_value(cls, fieldpath, valanntype, val,
+                                             ioattrs)
                 if self._create:
                     assert out is not None
                     out[key] = outval
@@ -860,7 +887,8 @@
                     raise TypeError(f'Got invalid key type {type(key)} for'
                                     f' dict key at \'{fieldpath}\' on {cls};'
                                     f' expected an int.')
-                outval = self._process_value(cls, fieldpath, valanntype, val)
+                outval = self._process_value(cls, fieldpath, valanntype, val,
+                                             ioattrs)
                 if self._create:
                     assert out is not None
                     out[str(key)] = outval
@@ -871,7 +899,8 @@
                     raise TypeError(f'Got invalid key type {type(key)} for'
                                     f' dict key at \'{fieldpath}\' on {cls};'
                                     f' expected a {keyanntype}.')
-                outval = self._process_value(cls, fieldpath, valanntype, val)
+                outval = self._process_value(cls, fieldpath, valanntype, val,
+                                             ioattrs)
                 if self._create:
                     assert out is not None
                     out[str(key.value)] = outval
@@ -906,7 +935,7 @@ class _Inputter(Generic[T]):
         return out

     def _value_from_input(self, cls: Type, fieldpath: str, anntype: Any,
-                          value: Any) -> Any:
+                          value: Any, ioattrs: Optional[IOAttrs]) -> Any:
         """Convert an assigned value to what a dataclass field expects."""
         # pylint: disable=too-many-return-statements
         # pylint: disable=too-many-branches
@@ -932,7 +961,7 @@
             ]
             assert len(childanntypes_l) == 1
             return self._value_from_input(cls, fieldpath, childanntypes_l[0],
-                                          value)
+                                          value, ioattrs)

         # Everything below this point assumes the annotation type resolves
         # to a concrete type. (This should have been verified at prep time).
@@ -949,13 +978,15 @@

         if origin in {list, set}:
             return self._sequence_from_input(cls, fieldpath, anntype, value,
-                                             origin)
+                                             origin, ioattrs)

         if origin is tuple:
-            return self._tuple_from_input(cls, fieldpath, anntype, value)
+            return self._tuple_from_input(cls, fieldpath, anntype, value,
+                                          ioattrs)

         if origin is dict:
-            return self._dict_from_input(cls, fieldpath, anntype, value)
+            return self._dict_from_input(cls, fieldpath, anntype, value,
+                                         ioattrs)

         if dataclasses.is_dataclass(origin):
             return self._dataclass_from_input(origin, fieldpath, value)
@@ -964,7 +995,7 @@
             return enum_by_value(origin, value)

         if issubclass(origin, datetime.datetime):
-            return self._datetime_from_input(cls, fieldpath, value)
+            return self._datetime_from_input(cls, fieldpath, value, ioattrs)

         if origin is bytes:
             return self._bytes_from_input(origin, fieldpath, value)
@@ -1039,12 +1070,12 @@
             else:
                 fieldname = field.name
             anntype = prep.annotations[fieldname]
-            anntype, _ioattrs = _parse_annotated(anntype)
+            anntype, ioattrs = _parse_annotated(anntype)
             subfieldpath = (f'{fieldpath}.{fieldname}'
                             if fieldpath else fieldname)
             args[key] = self._value_from_input(cls, subfieldpath, anntype,
-                                               value)
+                                               value, ioattrs)

         try:
             out = cls(**args)
         except Exception as exc:
@@ -1056,8 +1087,9 @@
         return out

     def _dict_from_input(self, cls: Type, fieldpath: str, anntype: Any,
-                         value: Any) -> Any:
+                         value: Any, ioattrs: Optional[IOAttrs]) -> Any:
         # pylint: disable=too-many-branches
+        # pylint: disable=too-many-locals

         if not isinstance(value, dict):
             raise TypeError(f'Expected a dict for \'{fieldpath}\' on {cls};'
@@ -1092,7 +1124,7 @@
                                     f' dict key at \'{fieldpath}\' on {cls};'
                                     f' expected a str.')
                 out[key] = self._value_from_input(cls, fieldpath,
-                                                  valanntype, val)
+                                                  valanntype, val, ioattrs)

         # int keys are stored in json as str versions of themselves.
         elif keyanntype is int:
@@ -1110,7 +1142,7 @@
                         f' dict key at \'{fieldpath}\' on {cls};'
                         f' expected an int in string form.') from exc
                 out[keyint] = self._value_from_input(
-                    cls, fieldpath, valanntype, val)
+                    cls, fieldpath, valanntype, val, ioattrs)

         elif issubclass(keyanntype, Enum):
             # In prep we verified that all these enums' values have
@@ -1129,7 +1161,7 @@
                                 f' expected a value corresponding to'
                                 f' a {keyanntype}.') from exc
                         out[enumval] = self._value_from_input(
-                            cls, fieldpath, valanntype, val)
+                            cls, fieldpath, valanntype, val, ioattrs)
                 else:
                     for key, val in value.items():
                         try:
@@ -1141,7 +1173,7 @@
                                 f' expected {keyanntype} value (though'
                                 f' in string form).') from exc
                         out[enumval] = self._value_from_input(
-                            cls, fieldpath, valanntype, val)
+                            cls, fieldpath, valanntype, val, ioattrs)
         else:
             raise RuntimeError(f'Unhandled dict in-key-type {keyanntype}')
@@ -1149,7 +1181,8 @@
         return out

     def _sequence_from_input(self, cls: Type, fieldpath: str, anntype: Any,
-                             value: Any, seqtype: Type) -> Any:
+                             value: Any, seqtype: Type,
+                             ioattrs: Optional[IOAttrs]) -> Any:

         # Because we are json-centric, we expect a list for all sequences.
         if type(value) is not list:
@@ -1171,11 +1204,11 @@
         assert len(childanntypes) == 1
         childanntype = childanntypes[0]
         return seqtype(
-            self._value_from_input(cls, fieldpath, childanntype, i)
+            self._value_from_input(cls, fieldpath, childanntype, i, ioattrs)
             for i in value)

-    def _datetime_from_input(self, cls: Type, fieldpath: str,
-                             value: Any) -> Any:
+    def _datetime_from_input(self, cls: Type, fieldpath: str, value: Any,
+                             ioattrs: Optional[IOAttrs]) -> Any:

         # For firestore we expect a datetime object.
         if self._codec is Codec.FIRESTORE:
@@ -1199,11 +1232,14 @@
             raise TypeError(
                 f'Invalid input value for "{fieldpath}" on "{cls}";'
                 f' expected a list of 7 ints.')
-        return datetime.datetime(  # type: ignore
+        out = datetime.datetime(  # type: ignore
             *value, tzinfo=datetime.timezone.utc)
+        if ioattrs is not None:
+            ioattrs.validate_datetime(out, fieldpath)
+        return out

     def _tuple_from_input(self, cls: Type, fieldpath: str, anntype: Any,
-                          value: Any) -> Any:
+                          value: Any, ioattrs: Optional[IOAttrs]) -> Any:

         out: List = []
@@ -1235,7 +1271,7 @@
             else:
                 out.append(
                     self._value_from_input(cls, fieldpath, childanntype,
-                                           childval))
+                                           childval, ioattrs))

         assert len(out) == len(childanntypes)
         return tuple(out)
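
Since ioattrs is now threaded through _process_value / _value_from_input and the container helpers above, the same limits should also apply to datetimes nested inside an annotated container, not just to a bare datetime field. A sketch of that assumption (class and field names invented; behavior inferred from the threading, not from the tests in this commit):

import datetime
from dataclasses import dataclass, field
from typing import Annotated, List

from efro.dataclassio import IOAttrs, dataclass_validate, ioprepped
from efro.util import utc_this_hour


@ioprepped
@dataclass
class _Schedule:
    # Hypothetical: every entry must land exactly on an hour boundary.
    slots: Annotated[List[datetime.datetime],
                     IOAttrs(whole_hours=True)] = field(default_factory=list)


sched = _Schedule(
    slots=[utc_this_hour(),
           utc_this_hour() + datetime.timedelta(minutes=5)])
try:
    dataclass_validate(sched)  # the second entry should trip the check.
except ValueError as exc:
    print(f'rejected: {exc}')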

View File

@@ -71,6 +71,25 @@ def utc_now() -> datetime.datetime:
     return datetime.datetime.now(datetime.timezone.utc)


+def utc_today() -> datetime.datetime:
+    """Get offset-aware midnight in the utc time zone."""
+    now = datetime.datetime.now(datetime.timezone.utc)
+    return datetime.datetime(year=now.year,
+                             month=now.month,
+                             day=now.day,
+                             tzinfo=now.tzinfo)
+
+
+def utc_this_hour() -> datetime.datetime:
+    """Get offset-aware beginning of the current hour in the utc time zone."""
+    now = datetime.datetime.now(datetime.timezone.utc)
+    return datetime.datetime(year=now.year,
+                             month=now.month,
+                             day=now.day,
+                             hour=now.hour,
+                             tzinfo=now.tzinfo)
+
+
 def empty_weakref(objtype: Type[T]) -> ReferenceType[T]:
     """Return an invalidated weak-reference for the specified type."""
     # At runtime, all weakrefs are the same; our type arg is just
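
The two helpers added above truncate (rather than round) the current UTC time, so their results always satisfy the new whole_days / whole_hours checks. A quick illustration, not part of the commit:

import datetime

from efro.util import utc_this_hour, utc_today

today = utc_today()          # e.g. 2021-05-27 00:00:00+00:00
this_hour = utc_this_hour()  # e.g. 2021-05-27 17:00:00+00:00

# Both are timezone-aware and sit exactly on their boundary.
assert today.tzinfo == datetime.timezone.utc
assert (today.hour, today.minute, today.second,
        today.microsecond) == (0, 0, 0, 0)
assert (this_hour.minute, this_hour.second, this_hour.microsecond) == (0, 0, 0)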