Added day and hour value limits to dataclassio

This commit is contained in:
Eric Froemling 2021-05-27 12:36:35 -05:00
parent dc7d7fab1d
commit 2f2ae11d4f
No known key found for this signature in database
GPG Key ID: 89C93F0F8D6D5A98
6 changed files with 138 additions and 40 deletions

View File

@ -2095,6 +2095,7 @@
<w>stickman</w>
<w>storable</w>
<w>storagename</w>
<w>storagenames</w>
<w>storecmd</w>
<w>storedhash</w>
<w>storeitemui</w>

View File

@ -941,6 +941,7 @@
<w>stephane</w>
<w>stepnum</w>
<w>stepsize</w>
<w>storagenames</w>
<w>storecmd</w>
<w>strcasecmp</w>
<w>strchr</w>

View File

@ -1,5 +1,5 @@
<!-- THIS FILE IS AUTO GENERATED; DO NOT EDIT BY HAND -->
<h4><em>last updated on 2021-05-26 for Ballistica version 1.6.4 build 20369</em></h4>
<h4><em>last updated on 2021-05-27 for Ballistica version 1.6.4 build 20369</em></h4>
<p>This page documents the Python classes and functions in the 'ba' module,
which are the ones most relevant to modding in Ballistica. If you come across something you feel should be included here or could be better explained, please <a href="mailto:support@froemling.net">let me know</a>. Happy modding!</p>
<hr>

View File

@ -683,3 +683,44 @@ def test_field_storage_path_capture() -> None:
with pytest.raises(AttributeError):
assert namecap.nonexistent.barf == 's.barf'
def test_datetime_limits() -> None:
"""Test limiting datetime values in various ways."""
from efro.util import utc_today, utc_this_hour
@ioprepped
@dataclass
class _TestClass:
tval: Annotated[datetime.datetime, IOAttrs(whole_hours=True)]
# Check whole-hour limit when validating/exporting.
obj = _TestClass(tval=utc_this_hour() + datetime.timedelta(minutes=1))
with pytest.raises(ValueError):
dataclass_validate(obj)
obj.tval = utc_this_hour()
dataclass_validate(obj)
# Check whole-days limit when importing.
out = dataclass_to_dict(obj)
out['tval'][-1] += 1
with pytest.raises(ValueError):
dataclass_from_dict(_TestClass, out)
# Check whole-days limit when validating/exporting.
@ioprepped
@dataclass
class _TestClass2:
tval: Annotated[datetime.datetime, IOAttrs(whole_days=True)]
obj2 = _TestClass2(tval=utc_today() + datetime.timedelta(hours=1))
with pytest.raises(ValueError):
dataclass_validate(obj2)
obj2.tval = utc_today()
dataclass_validate(obj2)
# Check whole-days limit when importing.
out = dataclass_to_dict(obj2)
out['tval'][-1] += 1
with pytest.raises(ValueError):
dataclass_from_dict(_TestClass2, out)

View File

@ -72,9 +72,15 @@ class Codec(Enum):
class IOAttrs:
"""For specifying io behavior in annotations."""
def __init__(self, storagename: str = None, store_default: bool = True):
def __init__(self,
storagename: str = None,
store_default: bool = True,
whole_days: bool = False,
whole_hours: bool = False):
self.storagename = storagename
self.store_default = store_default
self.whole_days = whole_days
self.whole_hours = whole_hours
def validate_for_field(self, cls: Type, field: dataclasses.Field) -> None:
"""Ensure the IOAttrs instance is ok to use with the provided field."""
@ -89,13 +95,27 @@ class IOAttrs:
f' neither a default nor a default_factory;'
f' store_default=False cannot be set for it.')
def validate_datetime(self, value: datetime.datetime,
fieldpath: str) -> None:
"""Ensure a datetime value meets our value requirements."""
if self.whole_days:
if any(x != 0 for x in (value.hour, value.minute, value.second,
value.microsecond)):
raise ValueError(
f'Value {value} at {fieldpath} is not a whole day.')
if self.whole_hours:
if any(x != 0
for x in (value.minute, value.second, value.microsecond)):
raise ValueError(f'Value {value} at {fieldpath}'
f' is not a whole hour.')
class FieldStoragePathCapture:
"""Utility for obtaining dataclass storage paths in a type safe way.
Given dataclass instance foo, FieldStoragePathCapture(foo).bar.eep
will return 'bar.eep' (or something like 'b.e' if storagenames are
overrridden). This can be combined with type-checking tricks that
overridden). This can be combined with type-checking tricks that
return foo in the type-checker's eyes while returning
FieldStoragePathCapture(foo) at runtime in order to grant a measure
of type safety to specifying field paths for things such as db
@ -627,7 +647,8 @@ class _Outputter:
f' store_default=False cannot be set for it.'
f' (AND THIS SHOULD HAVE BEEN CAUGHT IN PREP!)')
outvalue = self._process_value(cls, subfieldpath, anntype, value)
outvalue = self._process_value(cls, subfieldpath, anntype, value,
ioattrs)
if self._create:
assert out is not None
storagename = (fieldname if
@ -648,7 +669,7 @@ class _Outputter:
return out
def _process_value(self, cls: Type, fieldpath: str, anntype: Any,
value: Any) -> Any:
value: Any, ioattrs: Optional[IOAttrs]) -> Any:
# pylint: disable=too-many-return-statements
# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
@ -674,7 +695,7 @@ class _Outputter:
]
assert len(childanntypes_l) == 1
return self._process_value(cls, fieldpath, childanntypes_l[0],
value)
value, ioattrs)
# Everything below this point assumes the annotation type resolves
# to a concrete type. (This should have been verified at prep time).
@ -704,11 +725,12 @@ class _Outputter:
f' {len(childanntypes)}.')
if self._create:
return [
self._process_value(cls, fieldpath, childanntypes[i], x)
for i, x in enumerate(value)
self._process_value(cls, fieldpath, childanntypes[i], x,
ioattrs) for i, x in enumerate(value)
]
for i, x in enumerate(value):
self._process_value(cls, fieldpath, childanntypes[i], x)
self._process_value(cls, fieldpath, childanntypes[i], x,
ioattrs)
return None
if origin is list:
@ -731,11 +753,12 @@ class _Outputter:
assert len(childanntypes) == 1
if self._create:
return [
self._process_value(cls, fieldpath, childanntypes[0], x)
for x in value
self._process_value(cls, fieldpath, childanntypes[0], x,
ioattrs) for x in value
]
for x in value:
self._process_value(cls, fieldpath, childanntypes[0], x)
self._process_value(cls, fieldpath, childanntypes[0], x,
ioattrs)
return None
if origin is set:
@ -759,15 +782,16 @@ class _Outputter:
# Note: we output json-friendly values so this becomes
# a list.
return [
self._process_value(cls, fieldpath, childanntypes[0], x)
for x in value
self._process_value(cls, fieldpath, childanntypes[0], x,
ioattrs) for x in value
]
for x in value:
self._process_value(cls, fieldpath, childanntypes[0], x)
self._process_value(cls, fieldpath, childanntypes[0], x,
ioattrs)
return None
if origin is dict:
return self._process_dict(cls, fieldpath, anntype, value)
return self._process_dict(cls, fieldpath, anntype, value, ioattrs)
if dataclasses.is_dataclass(origin):
if not isinstance(value, origin):
@ -788,6 +812,8 @@ class _Outputter:
raise TypeError(f'Expected a {origin} for {fieldpath};'
f' found a {type(value)}.')
_ensure_datetime_is_timezone_aware(value)
if ioattrs is not None:
ioattrs.validate_datetime(value, fieldpath)
if self._codec is Codec.FIRESTORE:
return value
assert self._codec is Codec.JSON
@ -819,7 +845,7 @@ class _Outputter:
return value
def _process_dict(self, cls: Type, fieldpath: str, anntype: Any,
value: dict) -> Any:
value: dict, ioattrs: Optional[IOAttrs]) -> Any:
# pylint: disable=too-many-branches
if not isinstance(value, dict):
raise TypeError(f'Expected a dict for {fieldpath};'
@ -848,7 +874,8 @@ class _Outputter:
raise TypeError(f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls};'
f' expected {keyanntype}.')
outval = self._process_value(cls, fieldpath, valanntype, val)
outval = self._process_value(cls, fieldpath, valanntype, val,
ioattrs)
if self._create:
assert out is not None
out[key] = outval
@ -860,7 +887,8 @@ class _Outputter:
raise TypeError(f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls};'
f' expected an int.')
outval = self._process_value(cls, fieldpath, valanntype, val)
outval = self._process_value(cls, fieldpath, valanntype, val,
ioattrs)
if self._create:
assert out is not None
out[str(key)] = outval
@ -871,7 +899,8 @@ class _Outputter:
raise TypeError(f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls};'
f' expected a {keyanntype}.')
outval = self._process_value(cls, fieldpath, valanntype, val)
outval = self._process_value(cls, fieldpath, valanntype, val,
ioattrs)
if self._create:
assert out is not None
out[str(key.value)] = outval
@ -906,7 +935,7 @@ class _Inputter(Generic[T]):
return out
def _value_from_input(self, cls: Type, fieldpath: str, anntype: Any,
value: Any) -> Any:
value: Any, ioattrs: Optional[IOAttrs]) -> Any:
"""Convert an assigned value to what a dataclass field expects."""
# pylint: disable=too-many-return-statements
# pylint: disable=too-many-branches
@ -932,7 +961,7 @@ class _Inputter(Generic[T]):
]
assert len(childanntypes_l) == 1
return self._value_from_input(cls, fieldpath, childanntypes_l[0],
value)
value, ioattrs)
# Everything below this point assumes the annotation type resolves
# to a concrete type. (This should have been verified at prep time).
@ -949,13 +978,15 @@ class _Inputter(Generic[T]):
if origin in {list, set}:
return self._sequence_from_input(cls, fieldpath, anntype, value,
origin)
origin, ioattrs)
if origin is tuple:
return self._tuple_from_input(cls, fieldpath, anntype, value)
return self._tuple_from_input(cls, fieldpath, anntype, value,
ioattrs)
if origin is dict:
return self._dict_from_input(cls, fieldpath, anntype, value)
return self._dict_from_input(cls, fieldpath, anntype, value,
ioattrs)
if dataclasses.is_dataclass(origin):
return self._dataclass_from_input(origin, fieldpath, value)
@ -964,7 +995,7 @@ class _Inputter(Generic[T]):
return enum_by_value(origin, value)
if issubclass(origin, datetime.datetime):
return self._datetime_from_input(cls, fieldpath, value)
return self._datetime_from_input(cls, fieldpath, value, ioattrs)
if origin is bytes:
return self._bytes_from_input(origin, fieldpath, value)
@ -1039,12 +1070,12 @@ class _Inputter(Generic[T]):
else:
fieldname = field.name
anntype = prep.annotations[fieldname]
anntype, _ioattrs = _parse_annotated(anntype)
anntype, ioattrs = _parse_annotated(anntype)
subfieldpath = (f'{fieldpath}.{fieldname}'
if fieldpath else fieldname)
args[key] = self._value_from_input(cls, subfieldpath, anntype,
value)
value, ioattrs)
try:
out = cls(**args)
except Exception as exc:
@ -1056,8 +1087,9 @@ class _Inputter(Generic[T]):
return out
def _dict_from_input(self, cls: Type, fieldpath: str, anntype: Any,
value: Any) -> Any:
value: Any, ioattrs: Optional[IOAttrs]) -> Any:
# pylint: disable=too-many-branches
# pylint: disable=too-many-locals
if not isinstance(value, dict):
raise TypeError(f'Expected a dict for \'{fieldpath}\' on {cls};'
@ -1092,7 +1124,7 @@ class _Inputter(Generic[T]):
f' dict key at \'{fieldpath}\' on {cls};'
f' expected a str.')
out[key] = self._value_from_input(cls, fieldpath,
valanntype, val)
valanntype, val, ioattrs)
# int keys are stored in json as str versions of themselves.
elif keyanntype is int:
@ -1110,7 +1142,7 @@ class _Inputter(Generic[T]):
f' dict key at \'{fieldpath}\' on {cls};'
f' expected an int in string form.') from exc
out[keyint] = self._value_from_input(
cls, fieldpath, valanntype, val)
cls, fieldpath, valanntype, val, ioattrs)
elif issubclass(keyanntype, Enum):
# In prep we verified that all these enums' values have
@ -1129,7 +1161,7 @@ class _Inputter(Generic[T]):
f' expected a value corresponding to'
f' a {keyanntype}.') from exc
out[enumval] = self._value_from_input(
cls, fieldpath, valanntype, val)
cls, fieldpath, valanntype, val, ioattrs)
else:
for key, val in value.items():
try:
@ -1141,7 +1173,7 @@ class _Inputter(Generic[T]):
f' expected {keyanntype} value (though'
f' in string form).') from exc
out[enumval] = self._value_from_input(
cls, fieldpath, valanntype, val)
cls, fieldpath, valanntype, val, ioattrs)
else:
raise RuntimeError(f'Unhandled dict in-key-type {keyanntype}')
@ -1149,7 +1181,8 @@ class _Inputter(Generic[T]):
return out
def _sequence_from_input(self, cls: Type, fieldpath: str, anntype: Any,
value: Any, seqtype: Type) -> Any:
value: Any, seqtype: Type,
ioattrs: Optional[IOAttrs]) -> Any:
# Because we are json-centric, we expect a list for all sequences.
if type(value) is not list:
@ -1171,11 +1204,11 @@ class _Inputter(Generic[T]):
assert len(childanntypes) == 1
childanntype = childanntypes[0]
return seqtype(
self._value_from_input(cls, fieldpath, childanntype, i)
self._value_from_input(cls, fieldpath, childanntype, i, ioattrs)
for i in value)
def _datetime_from_input(self, cls: Type, fieldpath: str,
value: Any) -> Any:
def _datetime_from_input(self, cls: Type, fieldpath: str, value: Any,
ioattrs: Optional[IOAttrs]) -> Any:
# For firestore we expect a datetime object.
if self._codec is Codec.FIRESTORE:
@ -1199,11 +1232,14 @@ class _Inputter(Generic[T]):
raise TypeError(
f'Invalid input value for "{fieldpath}" on "{cls}";'
f' expected a list of 7 ints.')
return datetime.datetime( # type: ignore
out = datetime.datetime( # type: ignore
*value, tzinfo=datetime.timezone.utc)
if ioattrs is not None:
ioattrs.validate_datetime(out, fieldpath)
return out
def _tuple_from_input(self, cls: Type, fieldpath: str, anntype: Any,
value: Any) -> Any:
value: Any, ioattrs: Optional[IOAttrs]) -> Any:
out: List = []
@ -1235,7 +1271,7 @@ class _Inputter(Generic[T]):
else:
out.append(
self._value_from_input(cls, fieldpath, childanntype,
childval))
childval, ioattrs))
assert len(out) == len(childanntypes)
return tuple(out)

View File

@ -71,6 +71,25 @@ def utc_now() -> datetime.datetime:
return datetime.datetime.now(datetime.timezone.utc)
def utc_today() -> datetime.datetime:
"""Get offset-aware midnight in the utc time zone."""
now = datetime.datetime.now(datetime.timezone.utc)
return datetime.datetime(year=now.year,
month=now.month,
day=now.day,
tzinfo=now.tzinfo)
def utc_this_hour() -> datetime.datetime:
"""Get offset-aware beginning of the current hour in the utc time zone."""
now = datetime.datetime.now(datetime.timezone.utc)
return datetime.datetime(year=now.year,
month=now.month,
day=now.day,
hour=now.hour,
tzinfo=now.tzinfo)
def empty_weakref(objtype: Type[T]) -> ReferenceType[T]:
"""Return an invalidated weak-reference for the specified type."""
# At runtime, all weakrefs are the same; our type arg is just