From e9d31b012c3e0d028e78f0d2857b7845cf7349c8 Mon Sep 17 00:00:00 2001 From: Eric Date: Wed, 31 Jan 2024 13:58:45 -0800 Subject: [PATCH] Latest public/internal sync. --- .efrocachemap | 56 ++++---- CHANGELOG.md | 4 +- src/assets/ba_data/python/baenv.py | 8 +- src/ballistica/shared/ballistica.cc | 2 +- tests/test_efro/test_dataclassio.py | 189 ++++++++++++++++++++++++++- tools/bacommon/transfer.py | 8 +- tools/efro/dataclassio/__init__.py | 3 +- tools/efro/dataclassio/_api.py | 32 ++++- tools/efro/dataclassio/_base.py | 148 ++++++++++++++------- tools/efro/dataclassio/_inputter.py | 94 +++++++++---- tools/efro/dataclassio/_outputter.py | 61 ++++++++- tools/efro/dataclassio/_prep.py | 14 +- 12 files changed, 504 insertions(+), 115 deletions(-) diff --git a/.efrocachemap b/.efrocachemap index 398381a0..873034ca 100644 --- a/.efrocachemap +++ b/.efrocachemap @@ -4060,26 +4060,26 @@ "build/assets/windows/Win32/ucrtbased.dll": "2def5335207d41b21b9823f6805997f1", "build/assets/windows/Win32/vc_redist.x86.exe": "b08a55e2e77623fe657bea24f223a3ae", "build/assets/windows/Win32/vcruntime140d.dll": "865b2af4d1e26a1a8073c89acb06e599", - "build/prefab/full/linux_arm64_gui/debug/ballisticakit": "26eea64d4509875c9a88da74f49e675c", - "build/prefab/full/linux_arm64_gui/release/ballisticakit": "0a39319a89364641f3bb0598821b4288", - "build/prefab/full/linux_arm64_server/debug/dist/ballisticakit_headless": "84567063607be0227ef779027e12d19d", - "build/prefab/full/linux_arm64_server/release/dist/ballisticakit_headless": "f4458855192dedd13a28d36dc3962890", - "build/prefab/full/linux_x86_64_gui/debug/ballisticakit": "4c0679b0157c2dd63519e5225d99359d", - "build/prefab/full/linux_x86_64_gui/release/ballisticakit": "335a3f06dc6dd361d6122fd9143124ae", - "build/prefab/full/linux_x86_64_server/debug/dist/ballisticakit_headless": "041a300c9fa99c82395e1ebc66e81fe3", - "build/prefab/full/linux_x86_64_server/release/dist/ballisticakit_headless": "181145bf30e752991860acd0e44f972c", - "build/prefab/full/mac_arm64_gui/debug/ballisticakit": "8531542c35242bcbffc0309cef10b2b8", - "build/prefab/full/mac_arm64_gui/release/ballisticakit": "48cdebbdea839f6b8fc8f5cb69d7f961", - "build/prefab/full/mac_arm64_server/debug/dist/ballisticakit_headless": "159003daac99048702c74120be565bad", - "build/prefab/full/mac_arm64_server/release/dist/ballisticakit_headless": "51c9582a1efaae50e1c435c13c390855", - "build/prefab/full/mac_x86_64_gui/debug/ballisticakit": "d66c11ebe6d9035ea7e86b362f8505a1", - "build/prefab/full/mac_x86_64_gui/release/ballisticakit": "1f8113ffba1d000120bf83ac268c603b", - "build/prefab/full/mac_x86_64_server/debug/dist/ballisticakit_headless": "6f2a68c0370061a2913278d97b039ecc", - "build/prefab/full/mac_x86_64_server/release/dist/ballisticakit_headless": "471e7f81fac96b4db752c5cdaeed7168", - "build/prefab/full/windows_x86_gui/debug/BallisticaKit.exe": "94916e80a9d7bc7801db666beceea026", - "build/prefab/full/windows_x86_gui/release/BallisticaKit.exe": "1bc098ae93dd18143fb64ae5cbc33c19", - "build/prefab/full/windows_x86_server/debug/dist/BallisticaKitHeadless.exe": "da99cef03f12a6ff2c0065f4616262f2", - "build/prefab/full/windows_x86_server/release/dist/BallisticaKitHeadless.exe": "14b67157a3bf57b9de067089476f79d5", + "build/prefab/full/linux_arm64_gui/debug/ballisticakit": "d1d989de9e44829ce7adc6348cad34f1", + "build/prefab/full/linux_arm64_gui/release/ballisticakit": "d27e236d62e3db407c61902f0768b209", + "build/prefab/full/linux_arm64_server/debug/dist/ballisticakit_headless": "148a0c692fd30c3027158866a1c6c157", + 
"build/prefab/full/linux_arm64_server/release/dist/ballisticakit_headless": "0836f235c538b20dd2187071dc82a9c0", + "build/prefab/full/linux_x86_64_gui/debug/ballisticakit": "c928cdc074b9cb8f752ca049fb30fcf9", + "build/prefab/full/linux_x86_64_gui/release/ballisticakit": "b66bd051975628898fb66d291188824f", + "build/prefab/full/linux_x86_64_server/debug/dist/ballisticakit_headless": "7b3579d629ad99f032c4b2d821f7e348", + "build/prefab/full/linux_x86_64_server/release/dist/ballisticakit_headless": "08d11c347fed9b4d2b6f582c92321ed0", + "build/prefab/full/mac_arm64_gui/debug/ballisticakit": "f51e6dbccdeb8b64163029d58168d6d3", + "build/prefab/full/mac_arm64_gui/release/ballisticakit": "5c250868de853f0bcdbfd671e5863e0b", + "build/prefab/full/mac_arm64_server/debug/dist/ballisticakit_headless": "bc49209413eacf23bd6aa8cae47f7324", + "build/prefab/full/mac_arm64_server/release/dist/ballisticakit_headless": "40f7edd3b8e2d5cf2869cdaf12459fbe", + "build/prefab/full/mac_x86_64_gui/debug/ballisticakit": "9530446001824359b438d64054a4fa39", + "build/prefab/full/mac_x86_64_gui/release/ballisticakit": "5db3cac8a2cfdb5d56cb7579d32f17c6", + "build/prefab/full/mac_x86_64_server/debug/dist/ballisticakit_headless": "db8d6083d7bbdf78855c70affc3792df", + "build/prefab/full/mac_x86_64_server/release/dist/ballisticakit_headless": "c4d5c5387cc15f9c83dd41ce75e5cba5", + "build/prefab/full/windows_x86_gui/debug/BallisticaKit.exe": "2accd53f262abd82afcc9f9b73f26f2e", + "build/prefab/full/windows_x86_gui/release/BallisticaKit.exe": "3e2d7f9d4c7c350af1e21a8acbb3dec6", + "build/prefab/full/windows_x86_server/debug/dist/BallisticaKitHeadless.exe": "4d629a6f6029e191dd341e0a2a21d50b", + "build/prefab/full/windows_x86_server/release/dist/BallisticaKitHeadless.exe": "feeeb28a230759fb5283474f82fc2451", "build/prefab/lib/linux_arm64_gui/debug/libballisticaplus.a": "8709ad96140d71760c2f493ee8bd7c43", "build/prefab/lib/linux_arm64_gui/release/libballisticaplus.a": "ee829cd5488e9750570dc6f602d65589", "build/prefab/lib/linux_arm64_server/debug/libballisticaplus.a": "8709ad96140d71760c2f493ee8bd7c43", @@ -4096,14 +4096,14 @@ "build/prefab/lib/mac_x86_64_gui/release/libballisticaplus.a": "79117cbfdf695298e1d9ae997d990c4d", "build/prefab/lib/mac_x86_64_server/debug/libballisticaplus.a": "984f0990a8e4cca29a382d70e51cc051", "build/prefab/lib/mac_x86_64_server/release/libballisticaplus.a": "79117cbfdf695298e1d9ae997d990c4d", - "build/prefab/lib/windows/Debug_Win32/BallisticaKitGenericPlus.lib": "97a0aee0716397c0394c620b0cdc8cfa", - "build/prefab/lib/windows/Debug_Win32/BallisticaKitGenericPlus.pdb": "5edf5fd129429079b24368da6c792c44", - "build/prefab/lib/windows/Debug_Win32/BallisticaKitHeadlessPlus.lib": "e453446a36102733a1f0db636fafb704", - "build/prefab/lib/windows/Debug_Win32/BallisticaKitHeadlessPlus.pdb": "dfb843bbc924daf7a2e2a2eb6b4811df", - "build/prefab/lib/windows/Release_Win32/BallisticaKitGenericPlus.lib": "09bb45bcbfad7c0f63b9494ceca669cc", - "build/prefab/lib/windows/Release_Win32/BallisticaKitGenericPlus.pdb": "c8d10517d61dc5c4d7c94a5eccecab4a", - "build/prefab/lib/windows/Release_Win32/BallisticaKitHeadlessPlus.lib": "4944d18bb54894b0488cbdaa7b2ef06f", - "build/prefab/lib/windows/Release_Win32/BallisticaKitHeadlessPlus.pdb": "d17c4758367051e734601018b081f786", + "build/prefab/lib/windows/Debug_Win32/BallisticaKitGenericPlus.lib": "810dd57e398827fb3abc488a4185a0b3", + "build/prefab/lib/windows/Debug_Win32/BallisticaKitGenericPlus.pdb": "c10f8fb6e748f6244753adef81ef5ed4", + 
"build/prefab/lib/windows/Debug_Win32/BallisticaKitHeadlessPlus.lib": "d30f36c9c925e52e94ef39afc8b0a35e", + "build/prefab/lib/windows/Debug_Win32/BallisticaKitHeadlessPlus.pdb": "88431c1a6372b951ce85c2c73bc7f8c5", + "build/prefab/lib/windows/Release_Win32/BallisticaKitGenericPlus.lib": "928e9fd64e81de0d43e433a4474826cb", + "build/prefab/lib/windows/Release_Win32/BallisticaKitGenericPlus.pdb": "ae0ce0dd3541770cb7bd93997aca3e04", + "build/prefab/lib/windows/Release_Win32/BallisticaKitHeadlessPlus.lib": "0de083bc720affcbab4fbf0c121a84fe", + "build/prefab/lib/windows/Release_Win32/BallisticaKitHeadlessPlus.pdb": "299be2aa45e5d943b56828f065911172", "src/assets/ba_data/python/babase/_mgen/__init__.py": "f885fed7f2ed98ff2ba271f9dbe3391c", "src/assets/ba_data/python/babase/_mgen/enums.py": "b611c090513a21e2fe90e56582724e9d", "src/ballistica/base/mgen/pyembed/binding_base.inc": "72bfed2cce8ff19741989dec28302f3f", diff --git a/CHANGELOG.md b/CHANGELOG.md index 00f77a1e..7be40ad5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -### 1.7.33 (build 21762, api 8, 2024-01-24) +### 1.7.33 (build 21763, api 8, 2024-01-31) - Stress test input-devices are now a bit smarter; they won't press any buttons while UIs are up (this could cause lots of chaos if it happened). - Added a 'Show Demos When Idle' option in advanced settings. If enabled, the @@ -20,6 +20,8 @@ languages; I feel it helps keep logic more understandable and should help us catch problems where a base class changes or removes a method and child classes forget to adapt to the change. +- Implemented `efro.dataclassio.IOMultiType` which will make my life a lot + easier. ### 1.7.32 (build 21741, api 8, 2023-12-20) - Fixed a screen message that no one will ever see (Thanks vishal332008?...) diff --git a/src/assets/ba_data/python/baenv.py b/src/assets/ba_data/python/baenv.py index 01c47c91..8a4b85c3 100644 --- a/src/assets/ba_data/python/baenv.py +++ b/src/assets/ba_data/python/baenv.py @@ -52,7 +52,7 @@ if TYPE_CHECKING: # Build number and version of the ballistica binary we expect to be # using. -TARGET_BALLISTICA_BUILD = 21762 +TARGET_BALLISTICA_BUILD = 21763 TARGET_BALLISTICA_VERSION = '1.7.33' @@ -287,9 +287,9 @@ def _setup_certs(contains_python_dist: bool) -> None: import certifi # Let both OpenSSL and requests (if present) know to use this. - os.environ['SSL_CERT_FILE'] = os.environ[ - 'REQUESTS_CA_BUNDLE' - ] = certifi.where() + os.environ['SSL_CERT_FILE'] = os.environ['REQUESTS_CA_BUNDLE'] = ( + certifi.where() + ) def _setup_paths( diff --git a/src/ballistica/shared/ballistica.cc b/src/ballistica/shared/ballistica.cc index 75f4b12f..f88717ed 100644 --- a/src/ballistica/shared/ballistica.cc +++ b/src/ballistica/shared/ballistica.cc @@ -39,7 +39,7 @@ auto main(int argc, char** argv) -> int { namespace ballistica { // These are set automatically via script; don't modify them here. 
-const int kEngineBuildNumber = 21762; +const int kEngineBuildNumber = 21763; const char* kEngineVersion = "1.7.33"; const int kEngineApiVersion = 8; diff --git a/tests/test_efro/test_dataclassio.py b/tests/test_efro/test_dataclassio.py index 1c91d47d..427ffff3 100644 --- a/tests/test_efro/test_dataclassio.py +++ b/tests/test_efro/test_dataclassio.py @@ -5,10 +5,17 @@ from __future__ import annotations -from enum import Enum import datetime +from enum import Enum from dataclasses import field, dataclass -from typing import TYPE_CHECKING, Any, Sequence, Annotated +from typing import ( + TYPE_CHECKING, + Any, + Sequence, + Annotated, + assert_type, + assert_never, +) from typing_extensions import override import pytest @@ -24,10 +31,11 @@ from efro.dataclassio import ( Codec, DataclassFieldLookup, IOExtendedData, + IOMultiType, ) if TYPE_CHECKING: - pass + from typing import Self class _EnumTest(Enum): @@ -1069,3 +1077,178 @@ def test_soft_default() -> None: todict = dataclass_to_dict(orig) assert todict == {'ival': 2} assert dataclass_from_dict(_TestClassE8, todict) == orig + + +class MTTestTypeID(Enum): + """IDs for our multi-type class.""" + + CLASS_1 = 'm1' + CLASS_2 = 'm2' + + +class MTTestBase(IOMultiType[MTTestTypeID]): + """Our multi-type class. + + These top level multi-type classes are special parent classes + that know about all of their child classes and how to serialize + & deserialize them using explicit type ids. We can then use the + parent class in annotations and dataclassio will do the right thing. + Useful for stuff like Message classes. + """ + + @override + @classmethod + def get_type(cls, type_id: MTTestTypeID) -> type[MTTestBase]: + """Return the subclass for each of our type-ids.""" + + # This uses assert_never() to ensure we cover all cases in the + # enum. Though this is less efficient than looking up by dict + # would be. If we had lots of values we could also support lazy + # loading by importing classes only when their value is being + # requested. + val: type[MTTestBase] + if type_id is MTTestTypeID.CLASS_1: + val = MTTestClass1 + elif type_id is MTTestTypeID.CLASS_2: + val = MTTestClass2 + else: + assert_never(type_id) + return val + + @override + @classmethod + def get_type_id(cls) -> MTTestTypeID: + """Provide the type-id for this subclass.""" + # If we wanted, we could just maintain a static mapping + # of types-to-ids here, but there are benefits to letting + # each child class speak for itself. Namely that we can + # do lazy-loading and don't need to have all types present + # here. + + # So we'll let all our child classes override this. + raise NotImplementedError() + + +@ioprepped +@dataclass(frozen=True) # Frozen so we can test in set() +class MTTestClass1(MTTestBase): + """A test child-class for use with our multi-type class.""" + + ival: int + + @override + @classmethod + def get_type_id(cls) -> MTTestTypeID: + return MTTestTypeID.CLASS_1 + + +@ioprepped +@dataclass(frozen=True) # Frozen so we can test in set() +class MTTestClass2(MTTestBase): + """Another test child-class for use with our multi-type class.""" + + sval: str + + @override + @classmethod + def get_type_id(cls) -> MTTestTypeID: + return MTTestTypeID.CLASS_2 + + +def test_multi_type() -> None: + """Test IOMultiType stuff.""" + # pylint: disable=too-many-locals + + # Test converting single instances back and forth. 
+ val1: MTTestBase = MTTestClass1(ival=123) + tpname = MTTestBase.ID_STORAGE_NAME + outdict = dataclass_to_dict(val1) + assert outdict == {'ival': 123, tpname: 'm1'} + val2: MTTestBase = MTTestClass2(sval='whee') + outdict2 = dataclass_to_dict(val2) + assert outdict2 == {'sval': 'whee', tpname: 'm2'} + + # Make sure types and values work for both concrete types and the + # multi-type. + assert_type(dataclass_from_dict(MTTestClass1, outdict), MTTestClass1) + assert_type(dataclass_from_dict(MTTestBase, outdict), MTTestBase) + + assert dataclass_from_dict(MTTestClass1, outdict) == val1 + assert dataclass_from_dict(MTTestClass2, outdict2) == val2 + assert dataclass_from_dict(MTTestBase, outdict) == val1 + assert dataclass_from_dict(MTTestBase, outdict2) == val2 + + # Now test our multi-type embedded in other classes. We should be + # able to throw a mix of things in there and have them deserialize + # back the types we started with. + + # Individual values: + + @ioprepped + @dataclass + class _TestContainerClass1: + obj_a: MTTestBase + obj_b: MTTestBase + + container1 = _TestContainerClass1( + obj_a=MTTestClass1(234), obj_b=MTTestClass2('987') + ) + outdict = dataclass_to_dict(container1) + container1b = dataclass_from_dict(_TestContainerClass1, outdict) + assert container1 == container1b + + # Lists: + + @ioprepped + @dataclass + class _TestContainerClass2: + objs: list[MTTestBase] + + container2 = _TestContainerClass2( + objs=[MTTestClass1(111), MTTestClass2('bbb')] + ) + outdict = dataclass_to_dict(container2) + container2b = dataclass_from_dict(_TestContainerClass2, outdict) + assert container2 == container2b + + # Dict values: + + @ioprepped + @dataclass + class _TestContainerClass3: + objs: dict[int, MTTestBase] + + container3 = _TestContainerClass3( + objs={1: MTTestClass1(456), 2: MTTestClass2('gronk')} + ) + outdict = dataclass_to_dict(container3) + container3b = dataclass_from_dict(_TestContainerClass3, outdict) + assert container3 == container3b + + # Tuples: + + @ioprepped + @dataclass + class _TestContainerClass4: + objs: tuple[MTTestBase, MTTestBase] + + container4 = _TestContainerClass4( + objs=(MTTestClass1(932), MTTestClass2('potato')) + ) + outdict = dataclass_to_dict(container4) + container4b = dataclass_from_dict(_TestContainerClass4, outdict) + assert container4 == container4b + + # Sets (note: dataclasses must be frozen for this to work): + + @ioprepped + @dataclass + class _TestContainerClass5: + objs: set[MTTestBase] + + container5 = _TestContainerClass5( + objs={MTTestClass1(424), MTTestClass2('goo')} + ) + outdict = dataclass_to_dict(container5) + container5b = dataclass_from_dict(_TestContainerClass5, outdict) + assert container5 == container5b diff --git a/tools/bacommon/transfer.py b/tools/bacommon/transfer.py index a53c6153..65221cbd 100644 --- a/tools/bacommon/transfer.py +++ b/tools/bacommon/transfer.py @@ -18,10 +18,10 @@ if TYPE_CHECKING: @ioprepped @dataclass class DirectoryManifestFile: - """Describes metadata and hashes for a file in a manifest.""" + """Describes a file in a manifest.""" - filehash: Annotated[str, IOAttrs('h')] - filesize: Annotated[int, IOAttrs('s')] + hash_sha256: Annotated[str, IOAttrs('h')] + size: Annotated[int, IOAttrs('s')] @ioprepped @@ -67,7 +67,7 @@ class DirectoryManifest: return ( filepath, DirectoryManifestFile( - filehash=sha.hexdigest(), filesize=filesize + hash_sha256=sha.hexdigest(), size=filesize ), ) diff --git a/tools/efro/dataclassio/__init__.py b/tools/efro/dataclassio/__init__.py index 56c87b10..9c6ae98a 100644 --- 
a/tools/efro/dataclassio/__init__.py +++ b/tools/efro/dataclassio/__init__.py @@ -11,7 +11,7 @@ data formats in a nondestructive manner. from __future__ import annotations from efro.util import set_canonical_module_names -from efro.dataclassio._base import Codec, IOAttrs, IOExtendedData +from efro.dataclassio._base import Codec, IOAttrs, IOExtendedData, IOMultiType from efro.dataclassio._prep import ( ioprep, ioprepped, @@ -33,6 +33,7 @@ __all__ = [ 'Codec', 'IOAttrs', 'IOExtendedData', + 'IOMultiType', 'ioprep', 'ioprepped', 'will_ioprep', diff --git a/tools/efro/dataclassio/_api.py b/tools/efro/dataclassio/_api.py index ddadd4d8..2fa308fa 100644 --- a/tools/efro/dataclassio/_api.py +++ b/tools/efro/dataclassio/_api.py @@ -27,7 +27,7 @@ class JsonStyle(Enum): """Different style types for json.""" # Single line, no spaces, no sorting. Not deterministic. - # Use this for most storage purposes. + # Use this where speed is more important than determinism. FAST = 'fast' # Single line, no spaces, sorted keys. Deterministic. @@ -40,7 +40,9 @@ class JsonStyle(Enum): def dataclass_to_dict( - obj: Any, codec: Codec = Codec.JSON, coerce_to_float: bool = True + obj: Any, + codec: Codec = Codec.JSON, + coerce_to_float: bool = True, ) -> dict: """Given a dataclass object, return a json-friendly dict. @@ -89,6 +91,28 @@ def dataclass_to_json( return json.dumps(jdict, separators=(',', ':'), sort_keys=sort_keys) +# @overload +# def dataclass_from_dict( +# cls: type[T], +# values: dict, +# codec: Codec = Codec.JSON, +# coerce_to_float: bool = True, +# allow_unknown_attrs: bool = True, +# discard_unknown_attrs: bool = False, +# ) -> T: ... + + +# @overload +# def dataclass_from_dict( +# cls: IOTypeMap, +# values: dict, +# codec: Codec = Codec.JSON, +# coerce_to_float: bool = True, +# allow_unknown_attrs: bool = True, +# discard_unknown_attrs: bool = False, +# ) -> Any: ... + + def dataclass_from_dict( cls: type[T], values: dict, @@ -120,13 +144,15 @@ def dataclass_from_dict( exported back to a dict, unless discard_unknown_attrs is True, in which case they will simply be discarded. """ - return _Inputter( + val = _Inputter( cls, codec=codec, coerce_to_float=coerce_to_float, allow_unknown_attrs=allow_unknown_attrs, discard_unknown_attrs=discard_unknown_attrs, ).run(values) + assert isinstance(val, cls) + return val def dataclass_from_json( diff --git a/tools/efro/dataclassio/_base.py b/tools/efro/dataclassio/_base.py index afe19deb..abcb24a7 100644 --- a/tools/efro/dataclassio/_base.py +++ b/tools/efro/dataclassio/_base.py @@ -8,13 +8,13 @@ import dataclasses import typing import datetime from enum import Enum -from typing import TYPE_CHECKING, get_args +from typing import TYPE_CHECKING, get_args, TypeVar, Generic # noinspection PyProtectedMember from typing import _AnnotatedAlias # type: ignore if TYPE_CHECKING: - from typing import Any, Callable + from typing import Any, Callable, Literal, ClassVar, Self # Types which we can pass through as-is. SIMPLE_TYPES = {int, bool, str, float, type(None)} @@ -24,23 +24,6 @@ SIMPLE_TYPES = {int, bool, str, float, type(None)} EXTRA_ATTRS_ATTR = '_DCIOEXATTRS' -def _raise_type_error( - fieldpath: str, valuetype: type, expected: tuple[type, ...] 
-) -> None: - """Raise an error when a field value's type does not match expected.""" - assert isinstance(expected, tuple) - assert all(isinstance(e, type) for e in expected) - if len(expected) == 1: - expected_str = expected[0].__name__ - else: - expected_str = ' | '.join(t.__name__ for t in expected) - raise TypeError( - f'Invalid value type for "{fieldpath}";' - f' expected "{expected_str}", got' - f' "{valuetype.__name__}".' - ) - - class Codec(Enum): """Specifies expected data format exported to or imported from.""" @@ -78,32 +61,41 @@ class IOExtendedData: """ -def _is_valid_for_codec(obj: Any, codec: Codec) -> bool: - """Return whether a value consists solely of json-supported types. +KeyT = TypeVar('KeyT', bound=Enum) - Note that this does not include things like tuples which are - implicitly translated to lists by python's json module. + +class IOMultiType(Generic[KeyT]): + """A base class for types that can map to multiple dataclass types. + + This allows construction of high level base classes (for example + a 'Message' type). These types can then be used as annotations in + dataclasses, and dataclassio will serialize/deserialize instances + based on their subtype plus simple embedded type-id values. + + See tests/test_efro/test_dataclassio.py for an example of this. """ - if obj is None: - return True - objtype = type(obj) - if objtype in (int, float, str, bool): - return True - if objtype is dict: - # JSON 'objects' supports only string dict keys, but all value types. - return all( - isinstance(k, str) and _is_valid_for_codec(v, codec) - for k, v in obj.items() - ) - if objtype is list: - return all(_is_valid_for_codec(elem, codec) for elem in obj) + # Serialized data will store individual object ids to this key. If + # this value is ever problematic, it should be possible to override + # it in a subclass. + ID_STORAGE_NAME = '_iotype' - # A few things are valid in firestore but not json. - if issubclass(objtype, datetime.datetime) or objtype is bytes: - return codec is Codec.FIRESTORE + @classmethod + def get_key_type(cls) -> type[Enum]: + """Return the enum type we use as a key.""" + out: type[Enum] = cls.__orig_bases__[0].__args__[0] # type: ignore + assert issubclass(out, Enum) + return out - return False + @classmethod + def get_type_id(cls) -> KeyT: + """Return the type id for this subclass.""" + raise NotImplementedError() + + @classmethod + def get_type(cls, type_id: KeyT) -> type[Self]: + """Return a specific subclass given an id.""" + raise NotImplementedError() class IOAttrs: @@ -192,7 +184,7 @@ class IOAttrs: """Ensure the IOAttrs instance is ok to use with the provided field.""" # Turning off store_default requires the field to have either - # a default or a a default_factory or for us to have soft equivalents. + # a default or a default_factory or for us to have soft equivalents. if not self.store_default: field_default_factory: Any = field.default_factory @@ -241,6 +233,52 @@ class IOAttrs: ) +def _raise_type_error( + fieldpath: str, valuetype: type, expected: tuple[type, ...] +) -> None: + """Raise an error when a field value's type does not match expected.""" + assert isinstance(expected, tuple) + assert all(isinstance(e, type) for e in expected) + if len(expected) == 1: + expected_str = expected[0].__name__ + else: + expected_str = ' | '.join(t.__name__ for t in expected) + raise TypeError( + f'Invalid value type for "{fieldpath}";' + f' expected "{expected_str}", got' + f' "{valuetype.__name__}".' 
+ ) + + +def _is_valid_for_codec(obj: Any, codec: Codec) -> bool: + """Return whether a value consists solely of json-supported types. + + Note that this does not include things like tuples which are + implicitly translated to lists by python's json module. + """ + if obj is None: + return True + + objtype = type(obj) + if objtype in (int, float, str, bool): + return True + if objtype is dict: + # JSON 'objects' supports only string dict keys, but all value + # types. + return all( + isinstance(k, str) and _is_valid_for_codec(v, codec) + for k, v in obj.items() + ) + if objtype is list: + return all(_is_valid_for_codec(elem, codec) for elem in obj) + + # A few things are valid in firestore but not json. + if issubclass(objtype, datetime.datetime) or objtype is bytes: + return codec is Codec.FIRESTORE + + return False + + def _get_origin(anntype: Any) -> Any: """Given a type annotation, return its origin or itself if there is none. @@ -255,9 +293,9 @@ def _get_origin(anntype: Any) -> Any: def _parse_annotated(anntype: Any) -> tuple[Any, IOAttrs | None]: """Parse Annotated() constructs, returning annotated type & IOAttrs.""" - # If we get an Annotated[foo, bar, eep] we take - # foo as the actual type, and we look for IOAttrs instances in - # bar/eep to affect our behavior. + # If we get an Annotated[foo, bar, eep] we take foo as the actual + # type, and we look for IOAttrs instances in bar/eep to affect our + # behavior. ioattrs: IOAttrs | None = None if isinstance(anntype, _AnnotatedAlias): annargs = get_args(anntype) @@ -270,8 +308,8 @@ def _parse_annotated(anntype: Any) -> tuple[Any, IOAttrs | None]: ) ioattrs = annarg - # I occasionally just throw a 'x' down when I mean IOAttrs('x'); - # catch these mistakes. + # I occasionally just throw a 'x' down when I mean + # IOAttrs('x'); catch these mistakes. elif isinstance(annarg, (str, int, float, bool)): raise RuntimeError( f'Raw {type(annarg)} found in Annotated[] entry:' @@ -279,3 +317,21 @@ def _parse_annotated(anntype: Any) -> tuple[Any, IOAttrs | None]: ) anntype = annargs[0] return anntype, ioattrs + + +def _get_multitype_type( + cls: type[IOMultiType], fieldpath: str, val: Any +) -> type[Any]: + if not isinstance(val, dict): + raise ValueError( + f"Found a {type(val)} at '{fieldpath}'; expected a dict." + ) + storename = cls.ID_STORAGE_NAME + id_val = val.get(storename) + if id_val is None: + raise ValueError( + f"Expected a '{storename}'" f" value for object at '{fieldpath}'." 
+ ) + id_enum_type = cls.get_key_type() + id_enum = id_enum_type(id_val) + return cls.get_type(id_enum) diff --git a/tools/efro/dataclassio/_inputter.py b/tools/efro/dataclassio/_inputter.py index 97075528..4c41582f 100644 --- a/tools/efro/dataclassio/_inputter.py +++ b/tools/efro/dataclassio/_inputter.py @@ -13,7 +13,7 @@ import dataclasses import typing import types import datetime -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING from efro.util import enum_by_value, check_utc from efro.dataclassio._base import ( @@ -25,6 +25,8 @@ from efro.dataclassio._base import ( SIMPLE_TYPES, _raise_type_error, IOExtendedData, + _get_multitype_type, + IOMultiType, ) from efro.dataclassio._prep import PrepSession @@ -34,13 +36,11 @@ if TYPE_CHECKING: from efro.dataclassio._base import IOAttrs from efro.dataclassio._outputter import _Outputter -T = TypeVar('T') - -class _Inputter(Generic[T]): +class _Inputter: def __init__( self, - cls: type[T], + cls: type[Any], codec: Codec, coerce_to_float: bool, allow_unknown_attrs: bool = True, @@ -59,27 +59,39 @@ class _Inputter(Generic[T]): ' when allow_unknown_attrs is False.' ) - def run(self, values: dict) -> T: + def run(self, values: dict) -> Any: """Do the thing.""" - # For special extended data types, call their 'will_output' callback. - tcls = self._cls + outcls: type[Any] - if issubclass(tcls, IOExtendedData): + # If we're dealing with a multi-type class, figure out the + # top level type we're going to. + if issubclass(self._cls, IOMultiType): + type_id_val = values.get(self._cls.ID_STORAGE_NAME) + if type_id_val is None: + raise ValueError( + f'No type id value present for multi-type object:' + f' {values}.' + ) + type_id_enum = self._cls.get_key_type() + enum_val = type_id_enum(type_id_val) + outcls = self._cls.get_type(enum_val) + else: + outcls = self._cls + + # FIXME - should probably move this into _dataclass_from_input + # so it can work on nested values. + if issubclass(outcls, IOExtendedData): is_ext = True - tcls.will_input(values) + outcls.will_input(values) else: is_ext = False - out = self._dataclass_from_input(self._cls, '', values) - assert isinstance(out, self._cls) + out = self._dataclass_from_input(outcls, '', values) + assert isinstance(out, outcls) if is_ext: - # mypy complains that we're no longer returning a T - # if we operate on out directly. - out2 = out - assert isinstance(out2, IOExtendedData) - out2.did_input() + out.did_input() return out @@ -111,8 +123,8 @@ class _Inputter(Generic[T]): # noinspection PyPep8 if origin is typing.Union or origin is types.UnionType: # Currently, the only unions we support are None/Value - # (translated from Optional), which we verified on prep. - # So let's treat this as a simple optional case. + # (translated from Optional), which we verified on prep. So + # let's treat this as a simple optional case. if value is None: return None childanntypes_l = [ @@ -123,13 +135,15 @@ class _Inputter(Generic[T]): cls, fieldpath, childanntypes_l[0], value, ioattrs ) - # Everything below this point assumes the annotation type resolves - # to a concrete type. (This should have been verified at prep time). + # Everything below this point assumes the annotation type + # resolves to a concrete type. (This should have been verified + # at prep time). assert isinstance(origin, type) if origin in SIMPLE_TYPES: if type(value) is not origin: - # Special case: if they want to coerce ints to floats, do so. + # Special case: if they want to coerce ints to floats, + # do so. 
if ( self._coerce_to_float and origin is float @@ -157,6 +171,16 @@ class _Inputter(Generic[T]): if dataclasses.is_dataclass(origin): return self._dataclass_from_input(origin, fieldpath, value) + # ONLY consider something as a multi-type when it's not a + # dataclass (all dataclasses inheriting from the multi-type + # should just be processed as dataclasses). + if issubclass(origin, IOMultiType): + return self._dataclass_from_input( + _get_multitype_type(anntype, fieldpath, value), + fieldpath, + value, + ) + if issubclass(origin, Enum): return enum_by_value(origin, value) @@ -228,10 +252,23 @@ class _Inputter(Generic[T]): f.name: _parse_annotated(prep.annotations[f.name]) for f in fields } + # Special case: if this is a multi-type class it probably has a + # type attr. Ignore that while parsing since we already have a + # definite type and it will just pollute extra-attrs otherwise. + if issubclass(cls, IOMultiType): + type_id_store_name = cls.ID_STORAGE_NAME + else: + type_id_store_name = None + # Go through all data in the input, converting it to either dataclass # args or extra data. args: dict[str, Any] = {} for rawkey, value in values.items(): + + # Ignore _iotype or whatnot. + if type_id_store_name is not None and rawkey == type_id_store_name: + continue + key = prep.storage_names_to_attr_names.get(rawkey, rawkey) field = fields_by_name.get(key) @@ -473,6 +510,19 @@ class _Inputter(Generic[T]): # We contain elements of some specified type. assert len(childanntypes) == 1 childanntype = childanntypes[0] + + # If our annotation type inherits from IOMultiType, use type-id + # values to determine which type to load for each element. + if issubclass(childanntype, IOMultiType): + return seqtype( + self._dataclass_from_input( + _get_multitype_type(childanntype, fieldpath, i), + fieldpath, + i, + ) + for i in value + ) + return seqtype( self._value_from_input(cls, fieldpath, childanntype, i, ioattrs) for i in value diff --git a/tools/efro/dataclassio/_outputter.py b/tools/efro/dataclassio/_outputter.py index 03e1d20f..198ad93e 100644 --- a/tools/efro/dataclassio/_outputter.py +++ b/tools/efro/dataclassio/_outputter.py @@ -25,6 +25,7 @@ from efro.dataclassio._base import ( SIMPLE_TYPES, _raise_type_error, IOExtendedData, + IOMultiType, ) from efro.dataclassio._prep import PrepSession @@ -49,6 +50,8 @@ class _Outputter: assert dataclasses.is_dataclass(self._obj) # For special extended data types, call their 'will_output' callback. + # FIXME - should probably move this into _process_dataclass so it + # can work on nested values. if isinstance(self._obj, IOExtendedData): self._obj.will_output() @@ -139,6 +142,17 @@ class _Outputter: if self._create: assert out is not None out.update(extra_attrs) + + # If this obj inherits from multi-type, store its type id. + if isinstance(obj, IOMultiType): + type_id = obj.get_type_id() + # Sanity checks; make sure looking up this id gets us this type. + assert isinstance(type_id.value, str) + assert obj.get_type(type_id) is type(obj) + if self._create: + assert out is not None + out[obj.ID_STORAGE_NAME] = type_id.value + return out def _process_value( @@ -231,6 +245,7 @@ class _Outputter: f'Expected a list for {fieldpath};' f' found a {type(value)}' ) + childanntypes = typing.get_args(anntype) # 'Any' type children; make sure they are valid values for @@ -246,8 +261,37 @@ class _Outputter: # Hmm; should we do a copy here? return value if self._create else None - # We contain elements of some specified type. 
+ # We contain elements of some single specified type. assert len(childanntypes) == 1 + childanntype = childanntypes[0] + + # If that type is a multi-type, we determine our type per-object. + if issubclass(childanntype, IOMultiType): + # In the multi-type case, we use each object's own type + # to do its conversion, but lets at least make sure each + # of those types inherits from the annotated multi-type + # class. + for x in value: + if not isinstance(x, childanntype): + raise ValueError( + f"Found a {type(x)} value under '{fieldpath}'." + f' Everything must inherit from' + f' {childanntype}.' + ) + + if self._create: + out: list[Any] = [] + for x in value: + # We know these are dataclasses so no need to do + # the generic _process_value. + out.append(self._process_dataclass(cls, x, fieldpath)) + return out + for x in value: + # We know these are dataclasses so no need to do + # the generic _process_value. + self._process_dataclass(cls, x, fieldpath) + + # Normal non-multitype case; everything's got the same type. if self._create: return [ self._process_value( @@ -307,6 +351,21 @@ class _Outputter: ) return self._process_dataclass(cls, value, fieldpath) + # ONLY consider something as a multi-type when it's not a + # dataclass (all dataclasses inheriting from the multi-type should + # just be processed as dataclasses). + if issubclass(origin, IOMultiType): + # In the multi-type case, we use each object's own type to + # do its conversion, but lets at least make sure each of + # those types inherits from the annotated multi-type class. + if not isinstance(value, origin): + raise ValueError( + f"Found a {type(value)} value at '{fieldpath}'." + f' It is expected to inherit from {origin}.' + ) + + return self._process_dataclass(cls, value, fieldpath) + if issubclass(origin, Enum): if not isinstance(value, origin): raise TypeError( diff --git a/tools/efro/dataclassio/_prep.py b/tools/efro/dataclassio/_prep.py index 0800b559..d7f8e828 100644 --- a/tools/efro/dataclassio/_prep.py +++ b/tools/efro/dataclassio/_prep.py @@ -17,7 +17,12 @@ import datetime from typing import TYPE_CHECKING, TypeVar, get_type_hints # noinspection PyProtectedMember -from efro.dataclassio._base import _parse_annotated, _get_origin, SIMPLE_TYPES +from efro.dataclassio._base import ( + _parse_annotated, + _get_origin, + SIMPLE_TYPES, + IOMultiType, +) if TYPE_CHECKING: from typing import Any @@ -260,6 +265,13 @@ class PrepSession: origin = _get_origin(anntype) + # If we inherit from IOMultiType, we use its type map to + # determine which type we're going to instead of the annotation. + # And we can't really check those types because they are + # lazy-loaded. So I guess we're done here. + if issubclass(origin, IOMultiType): + return + # noinspection PyPep8 if origin is typing.Union or origin is types.UnionType: self.prep_union(
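# ---------------------------------------------------------------------
# Usage sketch for the new efro.dataclassio.IOMultiType support added in
# this sync, condensed from tests/test_efro/test_dataclassio.py above.
# The ShapeTypeID / Shape / Circle / Square names are illustrative only;
# the imported symbols and the '_iotype' storage key come from the
# modules touched in this commit.

from __future__ import annotations

from enum import Enum
from dataclasses import dataclass
from typing import assert_never

from typing_extensions import override

from efro.dataclassio import (
    ioprepped,
    IOMultiType,
    dataclass_to_dict,
    dataclass_from_dict,
)


class ShapeTypeID(Enum):
    """Type-ids stored alongside serialized Shape data."""

    CIRCLE = 'c'
    SQUARE = 's'


class Shape(IOMultiType[ShapeTypeID]):
    """Multi-type base; usable directly in dataclass annotations."""

    @override
    @classmethod
    def get_type(cls, type_id: ShapeTypeID) -> type[Shape]:
        # Map each type-id to its concrete subclass.
        if type_id is ShapeTypeID.CIRCLE:
            return Circle
        if type_id is ShapeTypeID.SQUARE:
            return Square
        assert_never(type_id)

    @override
    @classmethod
    def get_type_id(cls) -> ShapeTypeID:
        # Each concrete subclass answers for itself.
        raise NotImplementedError()


@ioprepped
@dataclass
class Circle(Shape):
    """Concrete subclass; serialized with type-id 'c'."""

    radius: float

    @override
    @classmethod
    def get_type_id(cls) -> ShapeTypeID:
        return ShapeTypeID.CIRCLE


@ioprepped
@dataclass
class Square(Shape):
    """Concrete subclass; serialized with type-id 's'."""

    size: float

    @override
    @classmethod
    def get_type_id(cls) -> ShapeTypeID:
        return ShapeTypeID.SQUARE


# Serializing through the base class embeds the type-id under
# Shape.ID_STORAGE_NAME ('_iotype' by default), which is how the
# concrete subclass gets recovered when loading back through the base
# class, including when Shape appears in container annotations such as
# list[Shape] or dict[int, Shape].
data = dataclass_to_dict(Circle(radius=2.0))
assert data == {'radius': 2.0, '_iotype': 'c'}
assert dataclass_from_dict(Shape, data) == Circle(radius=2.0)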