diff --git a/.idea/dictionaries/ericf.xml b/.idea/dictionaries/ericf.xml index 48ebd455..27aacacf 100644 --- a/.idea/dictionaries/ericf.xml +++ b/.idea/dictionaries/ericf.xml @@ -1426,6 +1426,7 @@ mythingie myweakcall mywidget + namecap namedarg nametext nameval @@ -2108,6 +2109,7 @@ strt strval subargs + subc subclassof subcontainer subcontainerheight diff --git a/ballisticacore-cmake/.idea/dictionaries/ericf.xml b/ballisticacore-cmake/.idea/dictionaries/ericf.xml index 197c076e..e6333108 100644 --- a/ballisticacore-cmake/.idea/dictionaries/ericf.xml +++ b/ballisticacore-cmake/.idea/dictionaries/ericf.xml @@ -626,6 +626,7 @@ mynode mystatspage mywidget + namecap nameval ndebug nearbytab @@ -950,6 +951,7 @@ strs strtof subargs + subc subclsssing subentities subfieldpath diff --git a/tests/test_efro/test_dataclassio.py b/tests/test_efro/test_dataclassio.py index 5af21aa9..8a5bbc9b 100644 --- a/tests/test_efro/test_dataclassio.py +++ b/tests/test_efro/test_dataclassio.py @@ -9,13 +9,14 @@ import datetime from dataclasses import field, dataclass from typing import (TYPE_CHECKING, Optional, List, Set, Any, Dict, Sequence, Union, Tuple) -from typing_extensions import Annotated +from typing_extensions import Annotated import pytest from efro.util import utc_now from efro.dataclassio import (dataclass_validate, dataclass_from_dict, - dataclass_to_dict, ioprepped, IOAttrs, Codec) + dataclass_to_dict, ioprepped, IOAttrs, Codec, + FieldStoragePathCapture) if TYPE_CHECKING: pass @@ -651,3 +652,34 @@ def test_name_clashes() -> None: class _TestClass3: ival: Annotated[int, IOAttrs(store_default=False)] = 4 ival2: Annotated[int, IOAttrs('ival')] = 5 + + +@ioprepped +@dataclass +class _SPTestClass1: + barf: int = 5 + eep: str = 'blah' + barf2: Annotated[int, IOAttrs('b')] = 5 + + +@ioprepped +@dataclass +class _SPTestClass2: + rah: bool = False + subc: _SPTestClass1 = field(default_factory=_SPTestClass1) + subc2: Annotated[_SPTestClass1, + IOAttrs('s')] = field(default_factory=_SPTestClass1) + + +def test_field_storage_path_capture() -> None: + """Test FieldStoragePathCapture functionality.""" + + obj = _SPTestClass2() + + namecap = FieldStoragePathCapture(obj) + assert namecap.subc.barf == 'subc.barf' + assert namecap.subc2.barf == 's.barf' + assert namecap.subc2.barf2 == 's.b' + + with pytest.raises(AttributeError): + assert namecap.nonexistent.barf == 's.barf' diff --git a/tools/efro/dataclassio.py b/tools/efro/dataclassio.py index 6cd8e76d..a7d05337 100644 --- a/tools/efro/dataclassio.py +++ b/tools/efro/dataclassio.py @@ -90,6 +90,47 @@ class IOAttrs: f' store_default=False cannot be set for it.') +class FieldStoragePathCapture: + """Utility for obtaining dataclass storage paths in a type safe way. + + Given dataclass instance foo, FieldStoragePathCapture(foo).bar.eep + will return 'bar.eep' (or something like 'b.e' if storagenames are + overrridden). This can be combined with type-checking tricks that + return foo in the type-checker's eyes while returning + FieldStoragePathCapture(foo) at runtime in order to grant a measure + of type safety to specifying field paths for things such as db + queries. Be aware, however, that the type-checker will incorrectly + think these lookups are returning actual attr values when they + are actually returning strings. + """ + + def __init__(self, obj: Any, path: List[str] = None): + if path is None: + path = [] + if not dataclasses.is_dataclass(obj): + raise TypeError(f'Expected a dataclass type/instance;' + f' got {type(obj)}.') + self._cls = obj if isinstance(obj, type) else type(obj) + self._path = path + + def __getattr__(self, name: str) -> Any: + prep = PrepSession(explicit=False).prep_dataclass(self._cls, + recursion_level=0) + try: + anntype = prep.annotations[name] + except KeyError as exc: + raise AttributeError(f'{type(self)} has no {name} field.') from exc + anntype, ioattrs = _parse_annotated(anntype) + storagename = (name if (ioattrs is None or ioattrs.storagename is None) + else ioattrs.storagename) + origin = _get_origin(anntype) + path = self._path + [storagename] + + if dataclasses.is_dataclass(origin): + return FieldStoragePathCapture(origin, path=path) + return '.'.join(path) + + def dataclass_to_dict(obj: Any, codec: Codec = Codec.JSON, coerce_to_float: bool = True) -> dict: