# Source: ballistica/tools/efro/dataclassio/_pathcapture.py
#
# (Viewer metadata: 62 lines, 2.4 KiB, Python.)

# Released under the MIT License. See LICENSE for details.
#
"""Functionality for importing, exporting, and validating dataclasses.
This allows complex nested dataclasses to be flattened to json-compatible
data and restored from said data. It also gracefully handles and preserves
unrecognized attribute data, allowing older clients to interact with newer
data formats in a nondestructive manner.
"""
from __future__ import annotations
import dataclasses
from typing import TYPE_CHECKING
from efro.dataclassio._base import _parse_annotated, _get_origin
from efro.dataclassio._prep import PrepSession
if TYPE_CHECKING:
from typing import Any, Dict, Type, Tuple, Optional, List, Set
class FieldStoragePathCapture:
    """Utility for obtaining dataclass storage paths in a type safe way.

    Given dataclass instance foo, FieldStoragePathCapture(foo).bar.eep
    will return 'bar.eep' (or something like 'b.e' if storagenames are
    overridden). This can be combined with type-checking tricks that
    return foo in the type-checker's eyes while returning
    FieldStoragePathCapture(foo) at runtime in order to grant a measure
    of type safety to specifying field paths for things such as db
    queries. Be aware, however, that the type-checker will incorrectly
    think these lookups are returning actual attr values when they
    are actually returning strings.
    """

    def __init__(self, obj: Any, path: Optional[List[str]] = None):
        """Start (or continue) a capture rooted at a dataclass.

        Args:
            obj: A dataclass type or an instance of one. Only its type
                is retained.
            path: Storage names already accumulated by parent captures.
                Defaults to an empty path.

        Raises:
            TypeError: If obj is neither a dataclass type nor instance.
        """
        if path is None:
            path = []
        if not dataclasses.is_dataclass(obj):
            raise TypeError(f'Expected a dataclass type/instance;'
                            f' got {type(obj)}.')
        self._cls = obj if isinstance(obj, type) else type(obj)
        self._path = path

    def __getattr__(self, name: str) -> Any:
        """Resolve a field name to its storage path.

        Returns a nested FieldStoragePathCapture when the field's type
        is itself a dataclass (so further attribute access can extend
        the path); otherwise returns the dot-joined storage path as a
        string.

        Raises:
            AttributeError: If the captured dataclass has no such field.
        """
        # __getattr__ only runs when normal lookup fails, so reaching
        # here for one of our own attributes means __init__ has not run
        # (for example on an instance created via __new__ by copy or
        # pickle). Bail out immediately; otherwise the self._cls access
        # below would re-enter this method forever and surface as a
        # RecursionError instead of a clean AttributeError.
        if name in ('_cls', '_path'):
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}')

        prep = PrepSession(explicit=False).prep_dataclass(self._cls,
                                                          recursion_level=0)
        try:
            anntype = prep.annotations[name]
        except KeyError as exc:
            raise AttributeError(f'{type(self)} has no {name} field.') from exc

        # Strip any Annotated[] wrapper; ioattrs may carry an overridden
        # storage name for this field.
        anntype, ioattrs = _parse_annotated(anntype)
        storagename = (name if (ioattrs is None or ioattrs.storagename is None)
                       else ioattrs.storagename)
        origin = _get_origin(anntype)

        # Build a new list so sibling captures sharing our path are
        # never affected.
        path = self._path + [storagename]

        if dataclasses.is_dataclass(origin):
            return FieldStoragePathCapture(origin, path=path)
        return '.'.join(path)