mirror of
https://github.com/RYDE-WORK/ballistica.git
synced 2026-01-27 01:13:13 +08:00
56 lines
2.2 KiB
Python
56 lines
2.2 KiB
Python
# Released under the MIT License. See LICENSE for details.
|
|
#
|
|
"""Functionality related to capturing nested dataclass paths."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
from typing import TYPE_CHECKING
|
|
|
|
from efro.dataclassio._base import _parse_annotated, _get_origin
|
|
from efro.dataclassio._prep import PrepSession
|
|
|
|
if TYPE_CHECKING:
|
|
from typing import Any, Dict, Type, Tuple, Optional, List, Set
|
|
|
|
|
|
class FieldStoragePathCapture:
|
|
"""Utility for obtaining dataclass storage paths in a type safe way.
|
|
|
|
Given dataclass instance foo, FieldStoragePathCapture(foo).bar.eep
|
|
will return 'bar.eep' (or something like 'b.e' if storagenames are
|
|
overridden). This can be combined with type-checking tricks that
|
|
return foo in the type-checker's eyes while returning
|
|
FieldStoragePathCapture(foo) at runtime in order to grant a measure
|
|
of type safety to specifying field paths for things such as db
|
|
queries. Be aware, however, that the type-checker will incorrectly
|
|
think these lookups are returning actual attr values when they
|
|
are actually returning strings.
|
|
"""
|
|
|
|
def __init__(self, obj: Any, path: List[str] = None):
|
|
if path is None:
|
|
path = []
|
|
if not dataclasses.is_dataclass(obj):
|
|
raise TypeError(f'Expected a dataclass type/instance;'
|
|
f' got {type(obj)}.')
|
|
self._cls = obj if isinstance(obj, type) else type(obj)
|
|
self._path = path
|
|
|
|
def __getattr__(self, name: str) -> Any:
|
|
prep = PrepSession(explicit=False).prep_dataclass(self._cls,
|
|
recursion_level=0)
|
|
try:
|
|
anntype = prep.annotations[name]
|
|
except KeyError as exc:
|
|
raise AttributeError(f'{type(self)} has no {name} field.') from exc
|
|
anntype, ioattrs = _parse_annotated(anntype)
|
|
storagename = (name if (ioattrs is None or ioattrs.storagename is None)
|
|
else ioattrs.storagename)
|
|
origin = _get_origin(anntype)
|
|
path = self._path + [storagename]
|
|
|
|
if dataclasses.is_dataclass(origin):
|
|
return FieldStoragePathCapture(origin, path=path)
|
|
return '.'.join(path)
|