diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 14ea8dc3..f326bd72 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1426,11 +1426,11 @@ struct enum_base { }), none(), none(), "" ); - #define PYBIND11_ENUM_OP_STRICT(op, expr) \ + #define PYBIND11_ENUM_OP_STRICT(op, expr, strict_behavior) \ m_base.attr(op) = cpp_function( \ [](object a, object b) { \ if (!a.get_type().is(b.get_type())) \ - throw type_error("Expected an enumeration of matching type!"); \ + strict_behavior; \ return expr; \ }, \ is_method(m_base)) @@ -1460,14 +1460,16 @@ struct enum_base { PYBIND11_ENUM_OP_CONV("__rxor__", a ^ b); } } else { - PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b))); - PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b))); + PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)), return false); + PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b)), return true); if (is_arithmetic) { - PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b)); - PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b)); - PYBIND11_ENUM_OP_STRICT("__le__", int_(a) <= int_(b)); - PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b)); + #define THROW throw type_error("Expected an enumeration of matching type!"); + PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b), THROW); + PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b), THROW); + PYBIND11_ENUM_OP_STRICT("__le__", int_(a) <= int_(b), THROW); + PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b), THROW); + #undef THROW } } diff --git a/tests/test_enum.py b/tests/test_enum.py index b1a5089e..d0989adc 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -47,10 +47,12 @@ Members: EOne : Docstring for EOne''' - # no TypeError exception for unscoped enum ==/!= int comparisons + # Unscoped enums will accept ==/!= int comparisons y = m.UnscopedEnum.ETwo assert y == 2 + assert 2 == y assert y != 3 + assert 3 != y assert int(m.UnscopedEnum.ETwo) == 2 assert str(m.UnscopedEnum(2)) == "UnscopedEnum.ETwo" @@ -75,11 +77,20 @@ def test_scoped_enum(): z = m.ScopedEnum.Two assert m.test_scoped_enum(z) == "ScopedEnum::Two" - # expected TypeError exceptions for scoped enum ==/!= int comparisons + # Scoped enums will *NOT* accept ==/!= int comparisons (Will always return False) + assert not z == 3 + assert not 3 == z + assert z != 3 + assert 3 != z + # Scoped enums will *NOT* accept >, <, >= and <= int comparisons (Will throw exceptions) with pytest.raises(TypeError): - assert z == 2 + z > 3 with pytest.raises(TypeError): - assert z != 3 + z < 3 + with pytest.raises(TypeError): + z >= 3 + with pytest.raises(TypeError): + z <= 3 # order assert m.ScopedEnum.Two < m.ScopedEnum.Three