From 4aa178e337e6a9865d3ec535db51a2edecb5605e Mon Sep 17 00:00:00 2001
From: Eric Froemling This page documents the Python classes and functions in the 'ba' module,
which are the ones most relevant to modding in Ballistica. If you come across something you feel should be included here or could be better explained, please let me know. Happy modding!last updated on 2021-01-08 for Ballistica version 1.5.30 build 20266
+last updated on 2021-01-14 for Ballistica version 1.5.30 build 20267
@@ -86,7 +86,6 @@
enum_by_value(cls: Type[ET], value: Any) -> ET
- -Create an enum from a value.
- -Category: General Utility Functions
- -This is basically the same as doing 'obj = EnumType(value)' except -that it works around an issue where a reference loop is created -if an exception is thrown due to an invalid value. Since we disable -the cyclic garbage collector for most of the time, such loops can lead -to our objects sticking around longer than we want. -This issue has been submitted to Python as a bug so hopefully we can -remove this eventually if it gets fixed: https://bugs.python.org/issue42248
-existing(obj: Optional[ExistableType]) -> Optional[ExistableType]
diff --git a/tests/test_efro/test_dataclasses.py b/tests/test_efro/test_dataclasses.py index 2185cd2b..952ee23d 100644 --- a/tests/test_efro/test_dataclasses.py +++ b/tests/test_efro/test_dataclasses.py @@ -4,19 +4,33 @@ from __future__ import annotations +from enum import Enum from dataclasses import dataclass, field from typing import TYPE_CHECKING import pytest -from efro.dataclasses import dataclass_assign, dataclass_validate +from efro.dataclasses import (dataclass_validate, dataclass_from_dict, + dataclass_to_dict) if TYPE_CHECKING: - from typing import Optional, List + from typing import Optional, List, Set + + +class _EnumTest(Enum): + TEST1 = 'test1' + TEST2 = 'test2' + + +@dataclass +class _NestedClass: + ival: int = 0 + sval: str = 'foo' def test_assign() -> None: """Testing various assignments.""" + # pylint: disable=too-many-statements @dataclass @@ -25,120 +39,181 @@ def test_assign() -> None: sval: str = '' bval: bool = True fval: float = 1.0 + nval: _NestedClass = field(default_factory=_NestedClass) + enval: _EnumTest = _EnumTest.TEST1 oival: Optional[int] = None osval: Optional[str] = None obval: Optional[bool] = None ofval: Optional[float] = None + oenval: Optional[_EnumTest] = _EnumTest.TEST1 lsval: List[str] = field(default_factory=list) lival: List[int] = field(default_factory=list) lbval: List[bool] = field(default_factory=list) lfval: List[float] = field(default_factory=list) - - tclass = _TestClass() + lenval: List[_EnumTest] = field(default_factory=list) + ssval: Set[str] = field(default_factory=set) class _TestClass2: pass - tclass2 = _TestClass2() - - # Arg types: + # Attempting to use with non-dataclass should fail. with pytest.raises(TypeError): - dataclass_assign(tclass2, {}) + dataclass_from_dict(_TestClass2, {}) + # Attempting to pass non-dicts should fail. with pytest.raises(TypeError): - dataclass_assign(tclass, []) # type: ignore + dataclass_from_dict(_TestClass, []) # type: ignore + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, None) # type: ignore - # Invalid attrs. + # Passing an attr not in the dataclass should fail. with pytest.raises(AttributeError): - dataclass_assign(tclass, {'nonexistent': 'foo'}) + dataclass_from_dict(_TestClass, {'nonexistent': 'foo'}) - # Correct types. - dataclass_assign( - tclass, { + # A dict containing *ALL* values should match what we + # get when creating a dataclass and then converting back + # to a dict. + dict1 = { + 'ival': 1, + 'sval': 'foo', + 'bval': True, + 'fval': 2.0, + 'nval': { 'ival': 1, - 'sval': 'foo', - 'bval': True, - 'fval': 2.0, - 'lsval': ['foo'], - 'lival': [10], - 'lbval': [False], - 'lfval': [1.0] - }) - dataclass_assign( - tclass, { - 'oival': None, - 'osval': None, - 'obval': None, - 'ofval': None, - 'lsval': [], - 'lival': [], - 'lbval': [], - 'lfval': [] - }) - dataclass_assign( - tclass, { - 'oival': 1, - 'osval': 'foo', - 'obval': True, - 'ofval': 2.0, - 'lsval': ['foo', 'bar', 'eep'], - 'lival': [10, 11, 12], - 'lbval': [False, True], - 'lfval': [1.0, 2.0, 3.0] - }) + 'sval': 'bar' + }, + 'enval': 'test1', + 'oival': 1, + 'osval': 'foo', + 'obval': True, + 'ofval': 1.0, + 'oenval': 'test2', + 'lsval': ['foo'], + 'lival': [10], + 'lbval': [False], + 'lfval': [1.0], + 'lenval': ['test1', 'test2'], + 'ssval': ['foo'] + } + dc1 = dataclass_from_dict(_TestClass, dict1) + assert dataclass_to_dict(dc1) == dict1 - # Type mismatches. - with pytest.raises(TypeError): - dataclass_assign(tclass, {'ival': 'foo'}) + # A few other assignment checks. + assert isinstance( + dataclass_from_dict( + _TestClass, { + 'oival': None, + 'osval': None, + 'obval': None, + 'ofval': None, + 'lsval': [], + 'lival': [], + 'lbval': [], + 'lfval': [], + 'ssval': [] + }), _TestClass) + assert isinstance( + dataclass_from_dict( + _TestClass, { + 'oival': 1, + 'osval': 'foo', + 'obval': True, + 'ofval': 2.0, + 'lsval': ['foo', 'bar', 'eep'], + 'lival': [10, 11, 12], + 'lbval': [False, True], + 'lfval': [1.0, 2.0, 3.0] + }), _TestClass) + # Attr assigns mismatched with their value types should fail. with pytest.raises(TypeError): - dataclass_assign(tclass, {'sval': 1}) + dataclass_from_dict(_TestClass, {'ival': 'foo'}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'sval': 1}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'bval': 2}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'oival': 'foo'}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'osval': 1}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'obval': 2}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'ofval': 'blah'}) + with pytest.raises(ValueError): + dataclass_from_dict(_TestClass, {'oenval': 'test3'}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'lsval': 'blah'}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'lsval': ['blah', None]}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'lsval': [1]}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'lsval': (1, )}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'lbval': [None]}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'lival': ['foo']}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'lfval': [True]}) + with pytest.raises(ValueError): + dataclass_from_dict(_TestClass, {'lenval': ['test1', 'test3']}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'ssval': [True]}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'ssval': {}}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'ssval': set()}) + # More subtle attr/type mismatches that should fail + # (we currently require EXACT type matches). with pytest.raises(TypeError): - dataclass_assign(tclass, {'bval': 2}) + dataclass_from_dict(_TestClass, {'ival': True}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'fval': 2}, coerce_to_float=False) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'bval': 1}) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'ofval': 1}, coerce_to_float=False) + with pytest.raises(TypeError): + dataclass_from_dict(_TestClass, {'lfval': [1]}, coerce_to_float=False) - with pytest.raises(TypeError): - dataclass_assign(tclass, {'oival': 'foo'}) - with pytest.raises(TypeError): - dataclass_assign(tclass, {'osval': 1}) +def test_coerce() -> None: + """Test value coercion.""" - with pytest.raises(TypeError): - dataclass_assign(tclass, {'obval': 2}) + @dataclass + class _TestClass: + ival: int = 0 + fval: float = 0.0 + # Float value present for int should never work. + obj = _TestClass() + # noinspection PyTypeHints + obj.ival = 1.0 # type: ignore with pytest.raises(TypeError): - dataclass_assign(tclass, {'ofval': 'blah'}) + dataclass_validate(obj, coerce_to_float=True) + with pytest.raises(TypeError): + dataclass_validate(obj, coerce_to_float=False) + # Int value present for float should work only with coerce on. + obj = _TestClass() + obj.fval = 1 + dataclass_validate(obj, coerce_to_float=True) with pytest.raises(TypeError): - dataclass_assign(tclass, {'lsval': 'blah'}) + dataclass_validate(obj, coerce_to_float=False) + # Likewise, passing in an int for a float field should work only + # with coerce on. + dataclass_from_dict(_TestClass, {'fval': 1}, coerce_to_float=True) with pytest.raises(TypeError): - dataclass_assign(tclass, {'lsval': [1]}) + dataclass_from_dict(_TestClass, {'fval': 1}, coerce_to_float=False) + # Passing in floats for an int field should never work. with pytest.raises(TypeError): - dataclass_assign(tclass, {'lbval': [None]}) - + dataclass_from_dict(_TestClass, {'ival': 1.0}, coerce_to_float=True) with pytest.raises(TypeError): - dataclass_assign(tclass, {'lival': ['foo']}) - - with pytest.raises(TypeError): - dataclass_assign(tclass, {'lfval': [True]}) - - # More subtle ones (we currently require EXACT type matches) - with pytest.raises(TypeError): - dataclass_assign(tclass, {'ival': True}) - - with pytest.raises(TypeError): - dataclass_assign(tclass, {'fval': 2}) - - with pytest.raises(TypeError): - dataclass_assign(tclass, {'bval': 1}) - - with pytest.raises(TypeError): - dataclass_assign(tclass, {'ofval': 1}) - - with pytest.raises(TypeError): - dataclass_assign(tclass, {'lfval': [1]}) + dataclass_from_dict(_TestClass, {'ival': 1.0}, coerce_to_float=False) def test_validate() -> None: @@ -159,10 +234,10 @@ def test_validate() -> None: tclass = _TestClass() dataclass_validate(tclass) - # No longer valid. + # No longer valid (without coerce) tclass.fval = 1 with pytest.raises(TypeError): - dataclass_validate(tclass) + dataclass_validate(tclass, coerce_to_float=False) # Should pass by default. tclass = _TestClass() diff --git a/tools/efro/dataclasses.py b/tools/efro/dataclasses.py index 4b20d74f..7c29fbd5 100644 --- a/tools/efro/dataclasses.py +++ b/tools/efro/dataclasses.py @@ -1,122 +1,295 @@ # Released under the MIT License. See LICENSE for details. # """Custom functionality for dealing with dataclasses.""" +# Note: We do lots of comparing of exact types here which is normally +# frowned upon (stuff like isinstance() is usually encouraged). +# pylint: disable=unidiomatic-typecheck + from __future__ import annotations import dataclasses -from typing import TYPE_CHECKING +import inspect +from enum import Enum +from typing import TYPE_CHECKING, TypeVar, Generic + +from efro.util import enum_by_value if TYPE_CHECKING: - from typing import Any, Dict, Type, Tuple + from typing import Any, Dict, Type, Tuple, Optional -# For fields with these string types, we require a passed value's type -# to exactly match one of the tuple values to consider the assignment valid. -_SIMPLE_ASSIGN_TYPES: Dict[str, Tuple[Type, ...]] = { - 'int': (int, ), - 'str': (str, ), - 'bool': (bool, ), - 'float': (float, ), - 'Optional[int]': (int, type(None)), - 'Optional[str]': (str, type(None)), - 'Optional[bool]': (bool, type(None)), - 'Optional[float]': (float, type(None)), -} - -_LIST_ASSIGN_TYPES: Dict[str, Tuple[Type, ...]] = { - 'List[int]': (int, ), - 'List[str]': (str, ), - 'List[bool]': (bool, ), - 'List[float]': (float, ), +T = TypeVar('T') + +SIMPLE_NAMES_TO_TYPES: Dict[str, Type] = { + 'int': int, + 'bool': bool, + 'str': str, + 'float': float, } +SIMPLE_TYPES_TO_NAMES = {tp: nm for nm, tp in SIMPLE_NAMES_TO_TYPES.items()} -def dataclass_assign(instance: Any, values: Dict[str, Any]) -> None: - """Safely assign values from a dict to a dataclass instance. +def dataclass_to_dict(obj: Any, coerce_to_float: bool = True) -> dict: + """Given a dataclass object, emit a json-friendly dict. - A TypeError will be raised if types to not match the dataclass fields - or are unsupported by this function. Note that a limited number of - types are supported. More can be added as needed. + All values will be checked to ensure they match the types specified + on fields. Note that only a limited set of types is supported. - Exact types are strictly checked, so a bool cannot be passed for - an int field, an int can't be passed for a float, etc. - (can reexamine this strictness if it proves to be a problem) - - An AttributeError will be raised if attributes are passed which are - not present on the dataclass as fields. - - This function may add significant overhead compared to passing dict - values to a dataclass' constructor or other more direct methods, but - the increased safety checks may be worth the speed tradeoff in some - cases. + If coerce_to_float is True, integer values present on float typed fields + will be converted to floats in the dict output. If False, a TypeError + will be triggered. """ - _dataclass_validate(instance, values) - for key, value in values.items(): - setattr(instance, key, value) + + out = _Outputter(obj, create=True, coerce_to_float=coerce_to_float).run() + assert isinstance(out, dict) + return out -def dataclass_validate(instance: Any) -> None: - """Ensure values in a dataclass are correct types. +def dataclass_from_dict(cls: Type[T], + values: dict, + coerce_to_float: bool = True) -> T: + """Given a dict, instantiates a dataclass of the given type. - Note that this will always fail if a dataclass contains field types - not supported by this module. + The dict must be in the json-friendly format as emitted from + dataclass_to_dict. This means that sequence values such as tuples or + sets should be passed as lists, enums should be passed as their + associated values, and nested dataclasses should be passed as dicts. + + If coerce_to_float is True, int values passed for float typed fields + will be converted to float values. Otherwise a TypeError is raised. """ - _dataclass_validate(instance, dataclasses.asdict(instance)) + return _Inputter(cls, coerce_to_float=coerce_to_float).run(values) -def _dataclass_validate(instance: Any, values: Dict[str, Any]) -> None: - # pylint: disable=too-many-branches - if not dataclasses.is_dataclass(instance): - raise TypeError(f'Passed instance {instance} is not a dataclass.') - if not isinstance(values, dict): - raise TypeError("Expected a dict for 'values' arg.") - fields = dataclasses.fields(instance) - fieldsdict = {f.name: f for f in fields} - for key, value in values.items(): - if key not in fieldsdict: - raise AttributeError( - f"'{type(instance).__name__}' has no '{key}' field.") - field = fieldsdict[key] +def dataclass_validate(obj: Any, coerce_to_float: bool = True) -> None: + """Ensure that current values in a dataclass are the correct types.""" + _Outputter(obj, create=False, coerce_to_float=coerce_to_float).run() - # We expect to be operating under 'from __future__ import annotations' - # so field types should always be strings for us; not an actual types. - # Complain if we come across an actual type. - fieldtype: str = field.type # type: ignore - if not isinstance(fieldtype, str): - raise RuntimeError( - f'Dataclass {type(instance).__name__} seems to have' - f' been created without "from __future__ import annotations";' - f' those dataclasses are unsupported here.') - if fieldtype in _SIMPLE_ASSIGN_TYPES: - reqtypes = _SIMPLE_ASSIGN_TYPES[fieldtype] - valuetype = type(value) - if not any(valuetype is t for t in reqtypes): - if len(reqtypes) == 1: - expected = reqtypes[0].__name__ - else: - names = ', '.join(t.__name__ for t in reqtypes) - expected = f'Union[{names}]' - raise TypeError(f'Invalid value type for "{key}";' - f' expected "{expected}", got' - f' "{valuetype.__name__}".') +def _field_type_str(cls: Type, field: dataclasses.Field) -> str: + # We expect to be operating under 'from __future__ import annotations' + # so field types should always be strings for us; not actual types. + # (Can pull this check out once we get to Python 3.10) + typestr: str = field.type # type: ignore - elif fieldtype in _LIST_ASSIGN_TYPES: - reqtypes = _LIST_ASSIGN_TYPES[fieldtype] + if not isinstance(typestr, str): + raise RuntimeError( + f'Dataclass {cls.__name__} seems to have' + f' been created without "from __future__ import annotations";' + f' those dataclasses are unsupported here.') + return typestr + + +def _raise_type_error(fieldpath: str, valuetype: Type, + expected: Tuple[Type, ...]) -> None: + """Raise an error when a field value's type does not match expected.""" + assert isinstance(expected, tuple) + assert all(isinstance(e, type) for e in expected) + if len(expected) == 1: + expected_str = expected[0].__name__ + else: + names = ', '.join(t.__name__ for t in expected) + expected_str = f'Union[{names}]' + raise TypeError(f'Invalid value type for "{fieldpath}";' + f' expected "{expected_str}", got' + f' "{valuetype.__name__}".') + + +class _Outputter: + + def __init__(self, obj: Any, create: bool, coerce_to_float: bool) -> None: + self._obj = obj + self._create = create + self._coerce_to_float = coerce_to_float + + def run(self) -> Any: + """Do the thing.""" + return self._dataclass_to_output(self._obj, '') + + def _value_to_output(self, fieldpath: str, typestr: str, + value: Any) -> Any: + # pylint: disable=too-many-return-statements + # pylint: disable=too-many-branches + + # For simple flat types, look for exact matches: + simpletype = SIMPLE_NAMES_TO_TYPES.get(typestr) + if simpletype is not None: + if type(value) is not simpletype: + # Special case: if they want to coerce ints to floats, do so. + if (self._coerce_to_float and simpletype is float + and type(value) is int): + return float(value) if self._create else None + _raise_type_error(fieldpath, type(value), (simpletype, )) + return value + + if typestr.startswith('Optional[') and typestr.endswith(']'): + subtypestr = typestr[9:-1] + # Handle the 'None' case special and do the default otherwise. + if value is None: + return None + return self._value_to_output(fieldpath, subtypestr, value) + + if typestr.startswith('List[') and typestr.endswith(']'): + subtypestr = typestr[5:-1] if not isinstance(value, list): - raise TypeError( - f'Invalid value for "{key}";' - f' expected a list, got a "{type(value).__name__}"') - for subvalue in value: - subvaluetype = type(subvalue) - if not any(subvaluetype is t for t in reqtypes): - if len(reqtypes) == 1: - expected = reqtypes[0].__name__ - else: - names = ', '.join(t.__name__ for t in reqtypes) - expected = f'Union[{names}]' - raise TypeError(f'Invalid value type for "{key}";' - f' expected list of "{expected}", found' - f' "{subvaluetype.__name__}".') + raise TypeError(f'Expected a list for {fieldpath};' + f' found a {type(value)}') + if self._create: + return [ + self._value_to_output(fieldpath, subtypestr, x) + for x in value + ] + for x in value: + self._value_to_output(fieldpath, subtypestr, x) + return None - else: - raise TypeError(f'Field type "{fieldtype}" is unsupported here.') + if typestr.startswith('Set[') and typestr.endswith(']'): + subtypestr = typestr[4:-1] + if not isinstance(value, set): + raise TypeError(f'Expected a set for {fieldpath};' + f' found a {type(value)}') + if self._create: + # Note: we output json-friendly values so this becomes a list. + return [ + self._value_to_output(fieldpath, subtypestr, x) + for x in value + ] + for x in value: + self._value_to_output(fieldpath, subtypestr, x) + return None + + if dataclasses.is_dataclass(value): + return self._dataclass_to_output(value, fieldpath) + + if isinstance(value, Enum): + enumvalue = value.value + if type(enumvalue) not in SIMPLE_TYPES_TO_NAMES: + raise TypeError(f'Invalid enum value type {type(enumvalue)}' + f' for "{fieldpath}".') + return enumvalue + + raise TypeError( + f"Field '{fieldpath}' of type '{typestr}' is unsupported here.") + + def _dataclass_to_output(self, obj: Any, fieldpath: str) -> Any: + if not dataclasses.is_dataclass(obj): + raise TypeError(f'Passed obj {obj} is not a dataclass.') + fields = dataclasses.fields(obj) + out: Optional[Dict[str, Any]] = {} if self._create else None + + for field in fields: + fieldname = field.name + + if fieldpath: + subfieldpath = f'{fieldpath}.{fieldname}' + else: + subfieldpath = fieldname + typestr = _field_type_str(type(obj), field) + value = getattr(obj, fieldname) + outvalue = self._value_to_output(subfieldpath, typestr, value) + if self._create: + assert out is not None + out[fieldname] = outvalue + + return out + + +class _Inputter(Generic[T]): + + def __init__(self, cls: Type[T], coerce_to_float: bool): + self._cls = cls + self._coerce_to_float = coerce_to_float + + def run(self, values: dict) -> T: + """Do the thing.""" + return self._dataclass_from_input( # type: ignore + self._cls, '', values) + + def _value_from_input(self, cls: Type, fieldpath: str, typestr: str, + value: Any) -> Any: + """Convert an assigned value to what a dataclass field expects.""" + # pylint: disable=too-many-return-statements + + simpletype = SIMPLE_NAMES_TO_TYPES.get(typestr) + if simpletype is not None: + if type(value) is not simpletype: + # Special case: if they want to coerce ints to floats, do so. + if (self._coerce_to_float and simpletype is float + and type(value) is int): + return float(value) + _raise_type_error(fieldpath, type(value), (simpletype, )) + return value + if typestr.startswith('List[') and typestr.endswith(']'): + return self._sequence_from_input(cls, fieldpath, typestr, value, + 'List', list) + if typestr.startswith('Set[') and typestr.endswith(']'): + return self._sequence_from_input(cls, fieldpath, typestr, value, + 'Set', set) + if typestr.startswith('Optional[') and typestr.endswith(']'): + subtypestr = typestr[9:-1] + # Handle the 'None' case special and do the default + # thing otherwise. + if value is None: + return None + return self._value_from_input(cls, fieldpath, subtypestr, value) + + # Ok, its not a builtin type. It might be an enum or nested dataclass. + cls2 = getattr(inspect.getmodule(cls), typestr, None) + if cls2 is None: + raise RuntimeError(f"Unable to resolve '{typestr}'" + f" used by class '{cls.__name__}';" + f' make sure all nested types are declared' + f' in the global namespace of the module where' + f" '{cls.__name__} is defined.") + + if dataclasses.is_dataclass(cls2): + return self._dataclass_from_input(cls2, fieldpath, value) + + if issubclass(cls2, Enum): + return enum_by_value(cls2, value) + + raise TypeError( + f"Field '{fieldpath}' of type '{typestr}' is unsupported here.") + + def _dataclass_from_input(self, cls: Type, fieldpath: str, + values: dict) -> Any: + """Given a dict, instantiates a dataclass of the given type. + + The dict must be in the json-friendly format as emitted from + dataclass_to_dict. This means that sequence values such as tuples or + sets should be passed as lists, enums should be passed as their + associated values, and nested dataclasses should be passed as dicts. + """ + if not dataclasses.is_dataclass(cls): + raise TypeError(f'Passed class {cls} is not a dataclass.') + if not isinstance(values, dict): + raise TypeError("Expected a dict for 'values' arg.") + + # noinspection PyDataclass + fields = dataclasses.fields(cls) + fields_by_name = {f.name: f for f in fields} + args: Dict[str, Any] = {} + for key, value in values.items(): + field = fields_by_name.get(key) + if field is None: + raise AttributeError(f"'{cls.__name__}' has no '{key}' field.") + + typestr = _field_type_str(cls, field) + + subfieldpath = (f'{fieldpath}.{field.name}' + if fieldpath else field.name) + args[key] = self._value_from_input(cls, subfieldpath, typestr, + value) + + return cls(**args) + + def _sequence_from_input(self, cls: Type, fieldpath: str, typestr: str, + value: Any, seqtypestr: str, + seqtype: Type) -> Any: + # Because we are json-centric, we expect a list for all sequences. + if type(value) is not list: + raise TypeError(f'Invalid input value for "{fieldpath}";' + f' expected a list, got a {type(value).__name__}') + subtypestr = typestr[len(seqtypestr) + 1:-1] + return seqtype( + self._value_from_input(cls, fieldpath, subtypestr, i) + for i in value) diff --git a/tools/efro/util.py b/tools/efro/util.py index 9e02eda4..1969130d 100644 --- a/tools/efro/util.py +++ b/tools/efro/util.py @@ -7,6 +7,7 @@ from __future__ import annotations import datetime import time import weakref +from enum import Enum from typing import TYPE_CHECKING, cast, TypeVar, Generic if TYPE_CHECKING: @@ -18,12 +19,38 @@ T = TypeVar('T') TVAL = TypeVar('TVAL') TARG = TypeVar('TARG') TRET = TypeVar('TRET') +TENUM = TypeVar('TENUM', bound=Enum) class _EmptyObj: pass +def enum_by_value(cls: Type[TENUM], value: Any) -> TENUM: + """Create an enum from a value. + + This is basically the same as doing 'obj = EnumType(value)' except + that it works around an issue where a reference loop is created + if an exception is thrown due to an invalid value. Since we disable + the cyclic garbage collector for most of the time, such loops can lead + to our objects sticking around longer than we want. + This issue has been submitted to Python as a bug so hopefully we can + remove this eventually if it gets fixed: https://bugs.python.org/issue42248 + """ + + # Note: we don't recreate *ALL* the functionality of the Enum constructor + # such as the _missing_ hook; but this should cover our basic needs. + value2member_map = getattr(cls, '_value2member_map_') + assert value2member_map is not None + try: + out = value2member_map[value] + assert isinstance(out, cls) + return out + except KeyError: + raise ValueError('%r is not a valid %s' % + (value, cls.__name__)) from None + + def utc_now() -> datetime.datetime: """Get offset-aware current utc time.