diff --git a/.idea/dictionaries/ericf.xml b/.idea/dictionaries/ericf.xml
index 27aacacf..bca437a6 100644
--- a/.idea/dictionaries/ericf.xml
+++ b/.idea/dictionaries/ericf.xml
@@ -2095,6 +2095,7 @@
stickman
storable
storagename
+ storagenames
storecmd
storedhash
storeitemui
diff --git a/ballisticacore-cmake/.idea/dictionaries/ericf.xml b/ballisticacore-cmake/.idea/dictionaries/ericf.xml
index e6333108..9701912b 100644
--- a/ballisticacore-cmake/.idea/dictionaries/ericf.xml
+++ b/ballisticacore-cmake/.idea/dictionaries/ericf.xml
@@ -941,6 +941,7 @@
stephane
stepnum
stepsize
+ storagenames
storecmd
strcasecmp
strchr
diff --git a/docs/ba_module.md b/docs/ba_module.md
index 602bcd0c..8d78ebf4 100644
--- a/docs/ba_module.md
+++ b/docs/ba_module.md
@@ -1,5 +1,5 @@
-
last updated on 2021-05-26 for Ballistica version 1.6.4 build 20369
+last updated on 2021-05-27 for Ballistica version 1.6.4 build 20369
This page documents the Python classes and functions in the 'ba' module,
which are the ones most relevant to modding in Ballistica. If you come across something you feel should be included here or could be better explained, please let me know. Happy modding!
diff --git a/tests/test_efro/test_dataclassio.py b/tests/test_efro/test_dataclassio.py
index 8a5bbc9b..d1ce2780 100644
--- a/tests/test_efro/test_dataclassio.py
+++ b/tests/test_efro/test_dataclassio.py
@@ -683,3 +683,44 @@ def test_field_storage_path_capture() -> None:
with pytest.raises(AttributeError):
assert namecap.nonexistent.barf == 's.barf'
+
+
+def test_datetime_limits() -> None:
+ """Test limiting datetime values in various ways."""
+ from efro.util import utc_today, utc_this_hour
+
+ @ioprepped
+ @dataclass
+ class _TestClass:
+ tval: Annotated[datetime.datetime, IOAttrs(whole_hours=True)]
+
+ # Check whole-hour limit when validating/exporting.
+ obj = _TestClass(tval=utc_this_hour() + datetime.timedelta(minutes=1))
+ with pytest.raises(ValueError):
+ dataclass_validate(obj)
+ obj.tval = utc_this_hour()
+ dataclass_validate(obj)
+
+    # Check whole-hours limit when importing.
+ out = dataclass_to_dict(obj)
+ out['tval'][-1] += 1
+ with pytest.raises(ValueError):
+ dataclass_from_dict(_TestClass, out)
+
+ # Check whole-days limit when validating/exporting.
+ @ioprepped
+ @dataclass
+ class _TestClass2:
+ tval: Annotated[datetime.datetime, IOAttrs(whole_days=True)]
+
+ obj2 = _TestClass2(tval=utc_today() + datetime.timedelta(hours=1))
+ with pytest.raises(ValueError):
+ dataclass_validate(obj2)
+ obj2.tval = utc_today()
+ dataclass_validate(obj2)
+
+ # Check whole-days limit when importing.
+ out = dataclass_to_dict(obj2)
+ out['tval'][-1] += 1
+ with pytest.raises(ValueError):
+ dataclass_from_dict(_TestClass2, out)
diff --git a/tools/efro/dataclassio.py b/tools/efro/dataclassio.py
index a7d05337..fd829f58 100644
--- a/tools/efro/dataclassio.py
+++ b/tools/efro/dataclassio.py
@@ -72,9 +72,15 @@ class Codec(Enum):
class IOAttrs:
"""For specifying io behavior in annotations."""
- def __init__(self, storagename: str = None, store_default: bool = True):
+ def __init__(self,
+ storagename: str = None,
+ store_default: bool = True,
+ whole_days: bool = False,
+ whole_hours: bool = False):
self.storagename = storagename
self.store_default = store_default
+ self.whole_days = whole_days
+ self.whole_hours = whole_hours
def validate_for_field(self, cls: Type, field: dataclasses.Field) -> None:
"""Ensure the IOAttrs instance is ok to use with the provided field."""
@@ -89,13 +95,27 @@ class IOAttrs:
f' neither a default nor a default_factory;'
f' store_default=False cannot be set for it.')
+ def validate_datetime(self, value: datetime.datetime,
+ fieldpath: str) -> None:
+ """Ensure a datetime value meets our value requirements."""
+ if self.whole_days:
+ if any(x != 0 for x in (value.hour, value.minute, value.second,
+ value.microsecond)):
+ raise ValueError(
+ f'Value {value} at {fieldpath} is not a whole day.')
+ if self.whole_hours:
+ if any(x != 0
+ for x in (value.minute, value.second, value.microsecond)):
+ raise ValueError(f'Value {value} at {fieldpath}'
+ f' is not a whole hour.')
+
class FieldStoragePathCapture:
"""Utility for obtaining dataclass storage paths in a type safe way.
Given dataclass instance foo, FieldStoragePathCapture(foo).bar.eep
will return 'bar.eep' (or something like 'b.e' if storagenames are
- overrridden). This can be combined with type-checking tricks that
+ overridden). This can be combined with type-checking tricks that
return foo in the type-checker's eyes while returning
FieldStoragePathCapture(foo) at runtime in order to grant a measure
of type safety to specifying field paths for things such as db
@@ -627,7 +647,8 @@ class _Outputter:
f' store_default=False cannot be set for it.'
f' (AND THIS SHOULD HAVE BEEN CAUGHT IN PREP!)')
- outvalue = self._process_value(cls, subfieldpath, anntype, value)
+ outvalue = self._process_value(cls, subfieldpath, anntype, value,
+ ioattrs)
if self._create:
assert out is not None
storagename = (fieldname if
@@ -648,7 +669,7 @@ class _Outputter:
return out
def _process_value(self, cls: Type, fieldpath: str, anntype: Any,
- value: Any) -> Any:
+ value: Any, ioattrs: Optional[IOAttrs]) -> Any:
# pylint: disable=too-many-return-statements
# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
@@ -674,7 +695,7 @@ class _Outputter:
]
assert len(childanntypes_l) == 1
return self._process_value(cls, fieldpath, childanntypes_l[0],
- value)
+ value, ioattrs)
# Everything below this point assumes the annotation type resolves
# to a concrete type. (This should have been verified at prep time).
@@ -704,11 +725,12 @@ class _Outputter:
f' {len(childanntypes)}.')
if self._create:
return [
- self._process_value(cls, fieldpath, childanntypes[i], x)
- for i, x in enumerate(value)
+ self._process_value(cls, fieldpath, childanntypes[i], x,
+ ioattrs) for i, x in enumerate(value)
]
for i, x in enumerate(value):
- self._process_value(cls, fieldpath, childanntypes[i], x)
+ self._process_value(cls, fieldpath, childanntypes[i], x,
+ ioattrs)
return None
if origin is list:
@@ -731,11 +753,12 @@ class _Outputter:
assert len(childanntypes) == 1
if self._create:
return [
- self._process_value(cls, fieldpath, childanntypes[0], x)
- for x in value
+ self._process_value(cls, fieldpath, childanntypes[0], x,
+ ioattrs) for x in value
]
for x in value:
- self._process_value(cls, fieldpath, childanntypes[0], x)
+ self._process_value(cls, fieldpath, childanntypes[0], x,
+ ioattrs)
return None
if origin is set:
@@ -759,15 +782,16 @@ class _Outputter:
# Note: we output json-friendly values so this becomes
# a list.
return [
- self._process_value(cls, fieldpath, childanntypes[0], x)
- for x in value
+ self._process_value(cls, fieldpath, childanntypes[0], x,
+ ioattrs) for x in value
]
for x in value:
- self._process_value(cls, fieldpath, childanntypes[0], x)
+ self._process_value(cls, fieldpath, childanntypes[0], x,
+ ioattrs)
return None
if origin is dict:
- return self._process_dict(cls, fieldpath, anntype, value)
+ return self._process_dict(cls, fieldpath, anntype, value, ioattrs)
if dataclasses.is_dataclass(origin):
if not isinstance(value, origin):
@@ -788,6 +812,8 @@ class _Outputter:
raise TypeError(f'Expected a {origin} for {fieldpath};'
f' found a {type(value)}.')
_ensure_datetime_is_timezone_aware(value)
+ if ioattrs is not None:
+ ioattrs.validate_datetime(value, fieldpath)
if self._codec is Codec.FIRESTORE:
return value
assert self._codec is Codec.JSON
@@ -819,7 +845,7 @@ class _Outputter:
return value
def _process_dict(self, cls: Type, fieldpath: str, anntype: Any,
- value: dict) -> Any:
+ value: dict, ioattrs: Optional[IOAttrs]) -> Any:
# pylint: disable=too-many-branches
if not isinstance(value, dict):
raise TypeError(f'Expected a dict for {fieldpath};'
@@ -848,7 +874,8 @@ class _Outputter:
raise TypeError(f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls};'
f' expected {keyanntype}.')
- outval = self._process_value(cls, fieldpath, valanntype, val)
+ outval = self._process_value(cls, fieldpath, valanntype, val,
+ ioattrs)
if self._create:
assert out is not None
out[key] = outval
@@ -860,7 +887,8 @@ class _Outputter:
raise TypeError(f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls};'
f' expected an int.')
- outval = self._process_value(cls, fieldpath, valanntype, val)
+ outval = self._process_value(cls, fieldpath, valanntype, val,
+ ioattrs)
if self._create:
assert out is not None
out[str(key)] = outval
@@ -871,7 +899,8 @@ class _Outputter:
raise TypeError(f'Got invalid key type {type(key)} for'
f' dict key at \'{fieldpath}\' on {cls};'
f' expected a {keyanntype}.')
- outval = self._process_value(cls, fieldpath, valanntype, val)
+ outval = self._process_value(cls, fieldpath, valanntype, val,
+ ioattrs)
if self._create:
assert out is not None
out[str(key.value)] = outval
@@ -906,7 +935,7 @@ class _Inputter(Generic[T]):
return out
def _value_from_input(self, cls: Type, fieldpath: str, anntype: Any,
- value: Any) -> Any:
+ value: Any, ioattrs: Optional[IOAttrs]) -> Any:
"""Convert an assigned value to what a dataclass field expects."""
# pylint: disable=too-many-return-statements
# pylint: disable=too-many-branches
@@ -932,7 +961,7 @@ class _Inputter(Generic[T]):
]
assert len(childanntypes_l) == 1
return self._value_from_input(cls, fieldpath, childanntypes_l[0],
- value)
+ value, ioattrs)
# Everything below this point assumes the annotation type resolves
# to a concrete type. (This should have been verified at prep time).
@@ -949,13 +978,15 @@ class _Inputter(Generic[T]):
if origin in {list, set}:
return self._sequence_from_input(cls, fieldpath, anntype, value,
- origin)
+ origin, ioattrs)
if origin is tuple:
- return self._tuple_from_input(cls, fieldpath, anntype, value)
+ return self._tuple_from_input(cls, fieldpath, anntype, value,
+ ioattrs)
if origin is dict:
- return self._dict_from_input(cls, fieldpath, anntype, value)
+ return self._dict_from_input(cls, fieldpath, anntype, value,
+ ioattrs)
if dataclasses.is_dataclass(origin):
return self._dataclass_from_input(origin, fieldpath, value)
@@ -964,7 +995,7 @@ class _Inputter(Generic[T]):
return enum_by_value(origin, value)
if issubclass(origin, datetime.datetime):
- return self._datetime_from_input(cls, fieldpath, value)
+ return self._datetime_from_input(cls, fieldpath, value, ioattrs)
if origin is bytes:
return self._bytes_from_input(origin, fieldpath, value)
@@ -1039,12 +1070,12 @@ class _Inputter(Generic[T]):
else:
fieldname = field.name
anntype = prep.annotations[fieldname]
- anntype, _ioattrs = _parse_annotated(anntype)
+ anntype, ioattrs = _parse_annotated(anntype)
subfieldpath = (f'{fieldpath}.{fieldname}'
if fieldpath else fieldname)
args[key] = self._value_from_input(cls, subfieldpath, anntype,
- value)
+ value, ioattrs)
try:
out = cls(**args)
except Exception as exc:
@@ -1056,8 +1087,9 @@ class _Inputter(Generic[T]):
return out
def _dict_from_input(self, cls: Type, fieldpath: str, anntype: Any,
- value: Any) -> Any:
+ value: Any, ioattrs: Optional[IOAttrs]) -> Any:
# pylint: disable=too-many-branches
+ # pylint: disable=too-many-locals
if not isinstance(value, dict):
raise TypeError(f'Expected a dict for \'{fieldpath}\' on {cls};'
@@ -1092,7 +1124,7 @@ class _Inputter(Generic[T]):
f' dict key at \'{fieldpath}\' on {cls};'
f' expected a str.')
out[key] = self._value_from_input(cls, fieldpath,
- valanntype, val)
+ valanntype, val, ioattrs)
# int keys are stored in json as str versions of themselves.
elif keyanntype is int:
@@ -1110,7 +1142,7 @@ class _Inputter(Generic[T]):
f' dict key at \'{fieldpath}\' on {cls};'
f' expected an int in string form.') from exc
out[keyint] = self._value_from_input(
- cls, fieldpath, valanntype, val)
+ cls, fieldpath, valanntype, val, ioattrs)
elif issubclass(keyanntype, Enum):
# In prep we verified that all these enums' values have
@@ -1129,7 +1161,7 @@ class _Inputter(Generic[T]):
f' expected a value corresponding to'
f' a {keyanntype}.') from exc
out[enumval] = self._value_from_input(
- cls, fieldpath, valanntype, val)
+ cls, fieldpath, valanntype, val, ioattrs)
else:
for key, val in value.items():
try:
@@ -1141,7 +1173,7 @@ class _Inputter(Generic[T]):
f' expected {keyanntype} value (though'
f' in string form).') from exc
out[enumval] = self._value_from_input(
- cls, fieldpath, valanntype, val)
+ cls, fieldpath, valanntype, val, ioattrs)
else:
raise RuntimeError(f'Unhandled dict in-key-type {keyanntype}')
@@ -1149,7 +1181,8 @@ class _Inputter(Generic[T]):
return out
def _sequence_from_input(self, cls: Type, fieldpath: str, anntype: Any,
- value: Any, seqtype: Type) -> Any:
+ value: Any, seqtype: Type,
+ ioattrs: Optional[IOAttrs]) -> Any:
# Because we are json-centric, we expect a list for all sequences.
if type(value) is not list:
@@ -1171,11 +1204,11 @@ class _Inputter(Generic[T]):
assert len(childanntypes) == 1
childanntype = childanntypes[0]
return seqtype(
- self._value_from_input(cls, fieldpath, childanntype, i)
+ self._value_from_input(cls, fieldpath, childanntype, i, ioattrs)
for i in value)
- def _datetime_from_input(self, cls: Type, fieldpath: str,
- value: Any) -> Any:
+ def _datetime_from_input(self, cls: Type, fieldpath: str, value: Any,
+ ioattrs: Optional[IOAttrs]) -> Any:
# For firestore we expect a datetime object.
if self._codec is Codec.FIRESTORE:
@@ -1199,11 +1232,14 @@ class _Inputter(Generic[T]):
raise TypeError(
f'Invalid input value for "{fieldpath}" on "{cls}";'
f' expected a list of 7 ints.')
- return datetime.datetime( # type: ignore
+ out = datetime.datetime( # type: ignore
*value, tzinfo=datetime.timezone.utc)
+ if ioattrs is not None:
+ ioattrs.validate_datetime(out, fieldpath)
+ return out
def _tuple_from_input(self, cls: Type, fieldpath: str, anntype: Any,
- value: Any) -> Any:
+ value: Any, ioattrs: Optional[IOAttrs]) -> Any:
out: List = []
@@ -1235,7 +1271,7 @@ class _Inputter(Generic[T]):
else:
out.append(
self._value_from_input(cls, fieldpath, childanntype,
- childval))
+ childval, ioattrs))
assert len(out) == len(childanntypes)
return tuple(out)
diff --git a/tools/efro/util.py b/tools/efro/util.py
index 9a0a8c42..162d2aab 100644
--- a/tools/efro/util.py
+++ b/tools/efro/util.py
@@ -71,6 +71,25 @@ def utc_now() -> datetime.datetime:
return datetime.datetime.now(datetime.timezone.utc)
+def utc_today() -> datetime.datetime:
+ """Get offset-aware midnight in the utc time zone."""
+ now = datetime.datetime.now(datetime.timezone.utc)
+ return datetime.datetime(year=now.year,
+ month=now.month,
+ day=now.day,
+ tzinfo=now.tzinfo)
+
+
+def utc_this_hour() -> datetime.datetime:
+ """Get offset-aware beginning of the current hour in the utc time zone."""
+ now = datetime.datetime.now(datetime.timezone.utc)
+ return datetime.datetime(year=now.year,
+ month=now.month,
+ day=now.day,
+ hour=now.hour,
+ tzinfo=now.tzinfo)
+
+
def empty_weakref(objtype: Type[T]) -> ReferenceType[T]:
"""Return an invalidated weak-reference for the specified type."""
# At runtime, all weakrefs are the same; our type arg is just