diff --git a/tests/test_efro/test_entity.py b/tests/test_efro/test_entity.py index 1d47bc6d..915fbe5b 100644 --- a/tests/test_efro/test_entity.py +++ b/tests/test_efro/test_entity.py @@ -78,6 +78,8 @@ class EntityTest(entity.Entity): def test_entity_values() -> None: """Test various entity assigns for value and type correctness.""" + # pylint: disable=too-many-statements + ent = EntityTest() # Simple int field. @@ -127,7 +129,9 @@ def test_entity_values() -> None: assert ent.str_int_dict['foo'] == 123 # Simple dict with enum key. + assert EnumTest.FIRST not in ent.enum_int_dict ent.enum_int_dict[EnumTest.FIRST] = 234 + assert EnumTest.FIRST in ent.enum_int_dict assert ent.enum_int_dict[EnumTest.FIRST] == 234 # Set with incorrect key type should give TypeError. with pytest.raises(TypeError): @@ -158,6 +162,7 @@ def test_entity_values() -> None: def test_entity_values_2() -> None: """Test various entity assigns for value and type correctness.""" + # pylint: disable=too-many-statements ent = EntityTest() @@ -206,8 +211,10 @@ def test_entity_values_2() -> None: # Compound dict with enum key. assert not ent.compounddict4 # bool operator + assert EnumTest.FIRST not in ent.compounddict4 _cd4val = ent.compounddict4.add(EnumTest.FIRST) assert ent.compounddict4 # bool operator + assert EnumTest.FIRST in ent.compounddict4 ent.compounddict4[EnumTest.FIRST].isubval = 222 assert ent.compounddict4[EnumTest.FIRST].isubval == 222 with pytest.raises(TypeError): diff --git a/tools/efro/entity/_support.py b/tools/efro/entity/_support.py index d696079b..99ceeded 100644 --- a/tools/efro/entity/_support.py +++ b/tools/efro/entity/_support.py @@ -241,10 +241,12 @@ class BoundDictField(Generic[TKey, T]): error=True) def __contains__(self, key: TKey) -> bool: - return key in self.d_data + keyfilt = dict_key_to_raw(key, self._keytype) + return keyfilt in self.d_data def __delitem__(self, key: TKey) -> None: - del self.d_data[key] + keyfilt = dict_key_to_raw(key, self._keytype) + del self.d_data[keyfilt] def keys(self) -> List[TKey]: """Return a list of our keys.""" @@ -451,10 +453,12 @@ class BoundCompoundDictField(Generic[TKey, TCompound]): return len(self.d_data) def __contains__(self, key: TKey) -> bool: - return key in self.d_data + keyfilt = dict_key_to_raw(key, self.d_field.d_keytype) + return keyfilt in self.d_data def __delitem__(self, key: TKey) -> None: - del self.d_data[key] + keyfilt = dict_key_to_raw(key, self.d_field.d_keytype) + del self.d_data[keyfilt] def keys(self) -> List[TKey]: """Return a list of our keys."""