Added enum key support to ba.entity

This commit is contained in:
Eric Froemling 2021-03-11 16:00:26 -06:00
parent 4bd1e12f50
commit 1c8e3fd01e
No known key found for this signature in database
GPG Key ID: 89C93F0F8D6D5A98
10 changed files with 161 additions and 47 deletions

View File

@ -1098,6 +1098,7 @@
<w>keepalives</w> <w>keepalives</w>
<w>keepaway</w> <w>keepaway</w>
<w>keeprefs</w> <w>keeprefs</w>
<w>keyfilt</w>
<w>keylayout</w> <w>keylayout</w>
<w>keypresses</w> <w>keypresses</w>
<w>keystr</w> <w>keystr</w>

View File

@ -479,6 +479,7 @@
<w>jmessage</w> <w>jmessage</w>
<w>keepalives</w> <w>keepalives</w>
<w>keycode</w> <w>keycode</w>
<w>keyfilt</w>
<w>keysyms</w> <w>keysyms</w>
<w>keywds</w> <w>keywds</w>
<w>khronos</w> <w>khronos</w>

View File

@ -1,5 +1,5 @@
<!-- THIS FILE IS AUTO GENERATED; DO NOT EDIT BY HAND --> <!-- THIS FILE IS AUTO GENERATED; DO NOT EDIT BY HAND -->
<h4><em>last updated on 2021-03-03 for Ballistica version 1.6.0 build 20319</em></h4> <h4><em>last updated on 2021-03-11 for Ballistica version 1.6.0 build 20323</em></h4>
<p>This page documents the Python classes and functions in the 'ba' module, <p>This page documents the Python classes and functions in the 'ba' module,
which are the ones most relevant to modding in Ballistica. If you come across something you feel should be included here or could be better explained, please <a href="mailto:support@froemling.net">let me know</a>. Happy modding!</p> which are the ones most relevant to modding in Ballistica. If you come across something you feel should be included here or could be better explained, please <a href="mailto:support@froemling.net">let me know</a>. Happy modding!</p>
<hr> <hr>

View File

@ -28,6 +28,13 @@ class EnumTest(Enum):
SECOND = 1 SECOND = 1
@unique
class EnumTest2(Enum):
"""Testing..."""
FIRST = 0
SECOND = 1
class SubCompoundTest(entity.CompoundValue): class SubCompoundTest(entity.CompoundValue):
"""Testing...""" """Testing..."""
subval = entity.Field('b', entity.BoolValue()) subval = entity.Field('b', entity.BoolValue())
@ -58,12 +65,14 @@ class EntityTest(entity.Entity):
slval = entity.ListField('sl', entity.StringValue()) slval = entity.ListField('sl', entity.StringValue())
tval2 = entity.Field('t2', entity.DateTimeValue()) tval2 = entity.Field('t2', entity.DateTimeValue())
str_int_dict = entity.DictField('sd', str, entity.IntValue()) str_int_dict = entity.DictField('sd', str, entity.IntValue())
enum_int_dict = entity.DictField('ed', EnumTest, entity.IntValue())
compoundlist = entity.CompoundListField('l', CompoundTest()) compoundlist = entity.CompoundListField('l', CompoundTest())
compoundlist2 = entity.CompoundListField('l2', CompoundTest()) compoundlist2 = entity.CompoundListField('l2', CompoundTest())
compoundlist3 = entity.CompoundListField('l3', CompoundTest2()) compoundlist3 = entity.CompoundListField('l3', CompoundTest2())
compounddict = entity.CompoundDictField('td', str, CompoundTest()) compounddict = entity.CompoundDictField('td', str, CompoundTest())
compounddict2 = entity.CompoundDictField('td2', str, CompoundTest()) compounddict2 = entity.CompoundDictField('td2', str, CompoundTest())
compounddict3 = entity.CompoundDictField('td3', str, CompoundTest2()) compounddict3 = entity.CompoundDictField('td3', str, CompoundTest2())
compounddict4 = entity.CompoundDictField('td4', EnumTest, CompoundTest())
fval2 = entity.Field('f2', entity.Float3Value()) fval2 = entity.Field('f2', entity.Float3Value())
@ -117,6 +126,27 @@ def test_entity_values() -> None:
assert static_type_equals(ent.str_int_dict['foo'], int) assert static_type_equals(ent.str_int_dict['foo'], int)
assert ent.str_int_dict['foo'] == 123 assert ent.str_int_dict['foo'] == 123
# Simple dict with enum key.
ent.enum_int_dict[EnumTest.FIRST] = 234
assert ent.enum_int_dict[EnumTest.FIRST] == 234
# Set with incorrect key type should give TypeError.
with pytest.raises(TypeError):
ent.enum_int_dict[0] = 123 # type: ignore
with pytest.raises(TypeError):
ent.enum_int_dict[EnumTest2.FIRST] = 123 # type: ignore
# And set with incorrect value type should do same.
with pytest.raises(TypeError):
ent.enum_int_dict[EnumTest.FIRST] = 'bar' # type: ignore
# Make sure is stored as underlying type.
assert ent.d_data['ed'] == {0: 234}
# Make sure invalid raw enum values are caught.
ent2 = EntityTest()
ent2.set_data({})
ent2.set_data({'ed': {0: 111}})
with pytest.raises(ValueError):
ent2.set_data({'ed': {5: 111}})
# Waaah; this works at runtime, but it seems that we'd need # Waaah; this works at runtime, but it seems that we'd need
# to have BoundDictField inherit from Mapping for mypy to accept this. # to have BoundDictField inherit from Mapping for mypy to accept this.
# (which seems to get a bit ugly, but may be worth revisiting) # (which seems to get a bit ugly, but may be worth revisiting)
@ -164,7 +194,7 @@ def test_entity_values_2() -> None:
with pytest.raises(TypeError): with pytest.raises(TypeError):
_cdval2 = ent.compounddict.add(1) # type: ignore _cdval2 = ent.compounddict.add(1) # type: ignore
# Hmm; should this throw a TypeError and not a KeyError?.. # Hmm; should this throw a TypeError and not a KeyError?..
with pytest.raises(KeyError): with pytest.raises(TypeError):
_cdval3 = ent.compounddict[1] # type: ignore _cdval3 = ent.compounddict[1] # type: ignore
assert static_type_equals(ent.compounddict['foo'], CompoundTest) assert static_type_equals(ent.compounddict['foo'], CompoundTest)
@ -172,7 +202,19 @@ def test_entity_values_2() -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
# noinspection PyTypeHints # noinspection PyTypeHints
ent.enumval = None # type: ignore ent.enumval = None # type: ignore
assert ent.enumval == EnumTest.FIRST assert ent.enumval is EnumTest.FIRST
# Compound dict with enum key.
assert not ent.compounddict4 # bool operator
_cd4val = ent.compounddict4.add(EnumTest.FIRST)
assert ent.compounddict4 # bool operator
ent.compounddict4[EnumTest.FIRST].isubval = 222
assert ent.compounddict4[EnumTest.FIRST].isubval == 222
with pytest.raises(TypeError):
ent.compounddict4[0].isubval = 222 # type: ignore
assert static_type_equals(ent.compounddict4[EnumTest.FIRST], CompoundTest)
# Make sure enum keys are stored as underlying type.
assert ent.d_data['td4'] == {0: {'i': 222, 'l': []}}
# Optional Enum value # Optional Enum value
ent.enumval2 = None ent.enumval2 = None
@ -186,6 +228,9 @@ def test_entity_values_2() -> None:
assert static_type_equals(ent.grp.compoundlist[0], SubCompoundTest) assert static_type_equals(ent.grp.compoundlist[0], SubCompoundTest)
assert static_type_equals(ent.grp.compoundlist[0].subval, bool) assert static_type_equals(ent.grp.compoundlist[0].subval, bool)
# Make sure we can digest the same data we spit out.
ent.set_data(ent.d_data)
def test_field_copies() -> None: def test_field_copies() -> None:
"""Test copying various values between fields.""" """Test copying various values between fields."""

View File

@ -264,3 +264,5 @@ if TYPE_CHECKING:
# noinspection PyPep8Naming # noinspection PyPep8Naming
def Call(*_args: Any, **_keywds: Any) -> Any: def Call(*_args: Any, **_keywds: Any) -> Any:
... ...
Call = Call

View File

@ -4,10 +4,28 @@
from __future__ import annotations from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any from typing import Any, Type
def dict_key_to_raw(key: Any, keytype: Type) -> Any:
"""Given a key value from the world, filter to stored key."""
if not isinstance(key, keytype):
raise TypeError(
f'Invalid key type; expected {keytype}, got {type(key)}.')
if issubclass(keytype, Enum):
return key.value
return key
def dict_key_from_raw(key: Any, keytype: Type) -> Any:
"""Given internal key, filter to world visible type."""
if issubclass(keytype, Enum):
return keytype(key)
return key
class DataHandler: class DataHandler:

View File

@ -6,8 +6,10 @@ from __future__ import annotations
import copy import copy
import logging import logging
from enum import Enum
from typing import TYPE_CHECKING, Generic, TypeVar, overload from typing import TYPE_CHECKING, Generic, TypeVar, overload
from efro.util import enum_by_value
from efro.entity._base import BaseField from efro.entity._base import BaseField
from efro.entity._support import (BoundCompoundValue, BoundListField, from efro.entity._support import (BoundCompoundValue, BoundListField,
BoundDictField, BoundCompoundListField, BoundDictField, BoundCompoundListField,
@ -186,7 +188,6 @@ class ListField(BaseField, Generic[T]):
# When accessed on a FieldInspector we return a sub-field FieldInspector. # When accessed on a FieldInspector we return a sub-field FieldInspector.
# When accessed on an instance we return a BoundListField. # When accessed on an instance we return a BoundListField.
# noinspection DuplicatedCode
if TYPE_CHECKING: if TYPE_CHECKING:
# Access via type gives our field; via an instance gives a bound field. # Access via type gives our field; via an instance gives a bound field.
@ -233,7 +234,6 @@ class DictField(BaseField, Generic[TK, T]):
def get_default_data(self) -> dict: def get_default_data(self) -> dict:
return {} return {}
# noinspection DuplicatedCode
def filter_input(self, data: Any, error: bool) -> Any: def filter_input(self, data: Any, error: bool) -> Any:
# If we were passed a BoundDictField, operate on its raw values # If we were passed a BoundDictField, operate on its raw values
@ -247,12 +247,29 @@ class DictField(BaseField, Generic[TK, T]):
data = {} data = {}
data_out = {} data_out = {}
for key, val in data.items(): for key, val in data.items():
if not isinstance(key, self._keytype):
# For enum keys, make sure its a valid enum.
if issubclass(self._keytype, Enum):
try:
_enumval = enum_by_value(self._keytype, key)
except Exception as exc:
if error:
raise ValueError(f'No enum of type {self._keytype}'
f' exists with value {key}') from exc
logging.error('Ignoring invalid key type for %s: %s', self,
data)
continue
# For all other keys we can check for exact types.
elif not isinstance(key, self._keytype):
if error: if error:
raise TypeError('invalid key type') raise TypeError(
f'Invalid key type; expected {self._keytype},'
f' got {type(key)}.')
logging.error('Ignoring invalid key type for %s: %s', self, logging.error('Ignoring invalid key type for %s: %s', self,
data) data)
continue continue
data_out[key] = self.d_value.filter_input(val, error=error) data_out[key] = self.d_value.filter_input(val, error=error)
return data_out return data_out
@ -261,7 +278,6 @@ class DictField(BaseField, Generic[TK, T]):
# change the dict, but we can prune completely if empty (and allowed) # change the dict, but we can prune completely if empty (and allowed)
return not data and not self._store_default return not data and not self._store_default
# noinspection DuplicatedCode
if TYPE_CHECKING: if TYPE_CHECKING:
# Return our field if accessed via type and bound-dict-field # Return our field if accessed via type and bound-dict-field
@ -339,7 +355,6 @@ class CompoundListField(BaseField, Generic[TC]):
# We can also optionally prune the whole list if empty and allowed. # We can also optionally prune the whole list if empty and allowed.
return not data and not self._store_default return not data and not self._store_default
# noinspection DuplicatedCode
if TYPE_CHECKING: if TYPE_CHECKING:
@overload @overload
@ -436,10 +451,10 @@ class CompoundDictField(BaseField, Generic[TK, TC]):
# This doesnt actually exist for us, but want the type-checker # This doesnt actually exist for us, but want the type-checker
# to think it does (see TYPE_CHECKING note below). # to think it does (see TYPE_CHECKING note below).
self.d_data: Any self.d_data: Any
self.d_keytype = keytype self.d_keytype = keytype
self._store_default = store_default self._store_default = store_default
# noinspection DuplicatedCode
def filter_input(self, data: Any, error: bool) -> dict: def filter_input(self, data: Any, error: bool) -> dict:
if not isinstance(data, dict): if not isinstance(data, dict):
if error: if error:
@ -448,12 +463,29 @@ class CompoundDictField(BaseField, Generic[TK, TC]):
data = {} data = {}
data_out = {} data_out = {}
for key, val in data.items(): for key, val in data.items():
if not isinstance(key, self.d_keytype):
# For enum keys, make sure its a valid enum.
if issubclass(self.d_keytype, Enum):
try:
_enumval = enum_by_value(self.d_keytype, key)
except Exception as exc:
if error:
raise ValueError(f'No enum of type {self.d_keytype}'
f' exists with value {key}') from exc
logging.error('Ignoring invalid key type for %s: %s', self,
data)
continue
# For all other keys we can check for exact types.
elif not isinstance(key, self.d_keytype):
if error: if error:
raise TypeError('invalid key type') raise TypeError(
f'Invalid key type; expected {self.d_keytype},'
f' got {type(key)}.')
logging.error('Ignoring invalid key type for %s: %s', self, logging.error('Ignoring invalid key type for %s: %s', self,
data) data)
continue continue
data_out[key] = self.d_value.filter_input(val, error=error) data_out[key] = self.d_value.filter_input(val, error=error)
return data_out return data_out
@ -472,7 +504,6 @@ class CompoundDictField(BaseField, Generic[TK, TC]):
# ONLY overriding these in type-checker land to clarify types. # ONLY overriding these in type-checker land to clarify types.
# (see note in BaseField) # (see note in BaseField)
# noinspection DuplicatedCode
if TYPE_CHECKING: if TYPE_CHECKING:
@overload @overload

View File

@ -6,7 +6,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, TypeVar, Generic, overload from typing import TYPE_CHECKING, TypeVar, Generic, overload
from efro.entity._base import BaseField from efro.entity._base import (BaseField, dict_key_to_raw, dict_key_from_raw)
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import (Optional, Tuple, Type, Any, Dict, List, Union) from typing import (Optional, Tuple, Type, Any, Dict, List, Union)
@ -215,35 +215,30 @@ class BoundDictField(Generic[TKey, T]):
def __repr__(self) -> str: def __repr__(self) -> str:
return '{' + ', '.join( return '{' + ', '.join(
repr(key) + ': ' + repr(self.d_field.d_value.filter_output(val)) repr(dict_key_from_raw(key, self._keytype)) + ': ' +
repr(self.d_field.d_value.filter_output(val))
for key, val in self.d_data.items()) + '}' for key, val in self.d_data.items()) + '}'
def __len__(self) -> int: def __len__(self) -> int:
return len(self.d_data) return len(self.d_data)
def __getitem__(self, key: TKey) -> T: def __getitem__(self, key: TKey) -> T:
if not isinstance(key, self._keytype): keyfilt = dict_key_to_raw(key, self._keytype)
raise TypeError( typedval: T = self.d_field.d_value.filter_output(self.d_data[keyfilt])
f'Invalid key type {type(key)}; expected {self._keytype}')
assert isinstance(key, self._keytype)
typedval: T = self.d_field.d_value.filter_output(self.d_data[key])
return typedval return typedval
def get(self, key: TKey, default: Optional[T] = None) -> Optional[T]: def get(self, key: TKey, default: Optional[T] = None) -> Optional[T]:
"""Get a value if present, or a default otherwise.""" """Get a value if present, or a default otherwise."""
if not isinstance(key, self._keytype): keyfilt = dict_key_to_raw(key, self._keytype)
raise TypeError( if keyfilt not in self.d_data:
f'Invalid key type {type(key)}; expected {self._keytype}')
assert isinstance(key, self._keytype)
if key not in self.d_data:
return default return default
typedval: T = self.d_field.d_value.filter_output(self.d_data[key]) typedval: T = self.d_field.d_value.filter_output(self.d_data[keyfilt])
return typedval return typedval
def __setitem__(self, key: TKey, value: T) -> None: def __setitem__(self, key: TKey, value: T) -> None:
if not isinstance(key, self._keytype): keyfilt = dict_key_to_raw(key, self._keytype)
raise TypeError('Expected str index.') self.d_data[keyfilt] = self.d_field.d_value.filter_input(value,
self.d_data[key] = self.d_field.d_value.filter_input(value, error=True) error=True)
def __contains__(self, key: TKey) -> bool: def __contains__(self, key: TKey) -> bool:
return key in self.d_data return key in self.d_data
@ -253,7 +248,9 @@ class BoundDictField(Generic[TKey, T]):
def keys(self) -> List[TKey]: def keys(self) -> List[TKey]:
"""Return a list of our keys.""" """Return a list of our keys."""
return list(self.d_data.keys()) return [
dict_key_from_raw(k, self._keytype) for k in self.d_data.keys()
]
def values(self) -> List[T]: def values(self) -> List[T]:
"""Return a list of our values.""" """Return a list of our values."""
@ -264,7 +261,8 @@ class BoundDictField(Generic[TKey, T]):
def items(self) -> List[Tuple[TKey, T]]: def items(self) -> List[Tuple[TKey, T]]:
"""Return a list of item/value pairs.""" """Return a list of item/value pairs."""
return [(key, self.d_field.d_value.filter_output(value)) return [(dict_key_from_raw(key, self._keytype),
self.d_field.d_value.filter_output(value))
for key, value in self.d_data.items()] for key, value in self.d_data.items()]
@ -413,13 +411,16 @@ class BoundCompoundDictField(Generic[TKey, TCompound]):
def get(self, key): def get(self, key):
"""return a value if present; otherwise None.""" """return a value if present; otherwise None."""
data = self.d_data.get(key) keyfilt = dict_key_to_raw(key, self.d_field.d_keytype)
data = self.d_data.get(keyfilt)
if data is not None: if data is not None:
return BoundCompoundValue(self.d_field.d_value, data) return BoundCompoundValue(self.d_field.d_value, data)
return None return None
def __getitem__(self, key): def __getitem__(self, key):
return BoundCompoundValue(self.d_field.d_value, self.d_data[key]) keyfilt = dict_key_to_raw(key, self.d_field.d_keytype)
return BoundCompoundValue(self.d_field.d_value,
self.d_data[keyfilt])
def values(self): def values(self):
"""Return a list of our values.""" """Return a list of our values."""
@ -429,21 +430,22 @@ class BoundCompoundDictField(Generic[TKey, TCompound]):
def items(self): def items(self):
"""Return key/value pairs for all dict entries.""" """Return key/value pairs for all dict entries."""
return [(key, BoundCompoundValue(self.d_field.d_value, value)) return [(dict_key_from_raw(key, self.d_field.d_keytype),
BoundCompoundValue(self.d_field.d_value, value))
for key, value in self.d_data.items()] for key, value in self.d_data.items()]
def add(self, key: TKey) -> TCompound: def add(self, key: TKey) -> TCompound:
"""Add an entry into the dict, returning it. """Add an entry into the dict, returning it.
Any existing value is replaced.""" Any existing value is replaced."""
if not isinstance(key, self.d_field.d_keytype): keyfilt = dict_key_to_raw(key, self.d_field.d_keytype)
raise TypeError(f'expected key type {self.d_field.d_keytype};'
f' got {type(key)}')
# Push the entity default into data and then let it fill in # Push the entity default into data and then let it fill in
# any children/etc. # any children/etc.
self.d_data[key] = (self.d_field.d_value.filter_input( self.d_data[keyfilt] = (self.d_field.d_value.filter_input(
self.d_field.d_value.get_default_data(), error=True)) self.d_field.d_value.get_default_data(), error=True))
return BoundCompoundValue(self.d_field.d_value, self.d_data[key]) return BoundCompoundValue(self.d_field.d_value,
self.d_data[keyfilt])
def __len__(self) -> int: def __len__(self) -> int:
return len(self.d_data) return len(self.d_data)
@ -456,4 +458,7 @@ class BoundCompoundDictField(Generic[TKey, TCompound]):
def keys(self) -> List[TKey]: def keys(self) -> List[TKey]:
"""Return a list of our keys.""" """Return a list of our keys."""
return list(self.d_data.keys()) return [
dict_key_from_raw(k, self.d_field.d_keytype)
for k in self.d_data.keys()
]

View File

@ -7,11 +7,13 @@ from __future__ import annotations
import datetime import datetime
import time import time
import weakref import weakref
import functools
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, cast, TypeVar, Generic from typing import TYPE_CHECKING, cast, TypeVar, Generic
if TYPE_CHECKING: if TYPE_CHECKING:
import asyncio import asyncio
from efro.call import Call as Call # 'as Call' so we re-export.
from weakref import ReferenceType from weakref import ReferenceType
from typing import Any, Dict, Callable, Optional, Type from typing import Any, Dict, Callable, Optional, Type
@ -27,6 +29,12 @@ class _EmptyObj:
pass pass
if TYPE_CHECKING:
Call = Call
else:
Call = functools.partial
def enum_by_value(cls: Type[TENUM], value: Any) -> TENUM: def enum_by_value(cls: Type[TENUM], value: Any) -> TENUM:
"""Create an enum from a value. """Create an enum from a value.

View File

@ -79,7 +79,7 @@ def build_apple(arch: str, debug: bool = False) -> None:
# txt = replace_one(txt, '_lzma _', '#_lzma _') # txt = replace_one(txt, '_lzma _', '#_lzma _')
# Turn off bzip2 module. # Turn off bzip2 module.
txt = replace_one(txt, '_bz2 _b', '#_bz2 _b') # txt = replace_one(txt, '_bz2 _b', '#_bz2 _b')
# Turn off openssl module (only if not doing openssl). # Turn off openssl module (only if not doing openssl).
if not ENABLE_OPENSSL: if not ENABLE_OPENSSL:
@ -150,11 +150,14 @@ def build_apple(arch: str, debug: bool = False) -> None:
# libs we're not using. # libs we're not using.
srctxt = '$$(PYTHON_DIR-$1)/dist/lib/libpython$(PYTHON_VER).a: ' srctxt = '$$(PYTHON_DIR-$1)/dist/lib/libpython$(PYTHON_VER).a: '
if PY38: if PY38:
txt = replace_one( # Note: now just keeping everything on.
txt, srctxt, assert ENABLE_OPENSSL
'$$(PYTHON_DIR-$1)/dist/lib/libpython$(PYTHON_VER).a: ' + if bool(False):
('build/$2/Support/OpenSSL ' if ENABLE_OPENSSL else '') + txt = replace_one(
'build/$2/Support/XZ $$(PYTHON_DIR-$1)/Makefile\n#' + srctxt) txt, srctxt,
'$$(PYTHON_DIR-$1)/dist/lib/libpython$(PYTHON_VER).a: ' +
('build/$2/Support/OpenSSL ' if ENABLE_OPENSSL else '') +
'build/$2/Support/XZ $$(PYTHON_DIR-$1)/Makefile\n#' + srctxt)
else: else:
txt = replace_one( txt = replace_one(
txt, srctxt, txt, srctxt,