From e5b42ef1fece01ec7207c7a3639355bfe5efeb13 Mon Sep 17 00:00:00 2001
From: Pim Schellart
Date: Tue, 2 Aug 2016 10:58:32 -0400
Subject: [PATCH] Enable comparisons between enums and their underlying types
---
example/example-constants-and-functions.py | 16 ++++++++++++++++
example/example-constants-and-functions.ref | 12 ++++++++++++
include/pybind11/pybind11.h | 19 +++++++++++--------
3 files changed, 39 insertions(+), 8 deletions(-)
diff --git a/example/example-constants-and-functions.py b/example/example-constants-and-functions.py
index 607450f9..f9292ee9 100755
--- a/example/example-constants-and-functions.py
+++ b/example/example-constants-and-functions.py
@@ -46,6 +46,22 @@ print("Inequality test 2: " + str(
ExampleWithEnum.test_function(ExampleWithEnum.EFirstMode) !=
ExampleWithEnum.test_function(ExampleWithEnum.ESecondMode)))
+print("Equality test 3: " + str(
+ ExampleWithEnum.test_function(ExampleWithEnum.EFirstMode) ==
+ int(ExampleWithEnum.test_function(ExampleWithEnum.EFirstMode))))
+
+print("Inequality test 3: " + str(
+ ExampleWithEnum.test_function(ExampleWithEnum.EFirstMode) !=
+ int(ExampleWithEnum.test_function(ExampleWithEnum.EFirstMode))))
+
+print("Equality test 4: " + str(
+ ExampleWithEnum.test_function(ExampleWithEnum.EFirstMode) ==
+ int(ExampleWithEnum.test_function(ExampleWithEnum.ESecondMode))))
+
+print("Inequality test 4: " + str(
+ ExampleWithEnum.test_function(ExampleWithEnum.EFirstMode) !=
+ int(ExampleWithEnum.test_function(ExampleWithEnum.ESecondMode))))
+
x = {
ExampleWithEnum.test_function(ExampleWithEnum.EFirstMode): 1,
ExampleWithEnum.test_function(ExampleWithEnum.ESecondMode): 2
diff --git a/example/example-constants-and-functions.ref b/example/example-constants-and-functions.ref
index d2e1731a..1d08223f 100644
--- a/example/example-constants-and-functions.ref
+++ b/example/example-constants-and-functions.ref
@@ -30,6 +30,18 @@ ExampleWithEnum::test_function(enum=1)
ExampleWithEnum::test_function(enum=2)
Inequality test 2: True
ExampleWithEnum::test_function(enum=1)
+ExampleWithEnum::test_function(enum=1)
+Equality test 3: True
+ExampleWithEnum::test_function(enum=1)
+ExampleWithEnum::test_function(enum=1)
+Inequality test 3: False
+ExampleWithEnum::test_function(enum=1)
+ExampleWithEnum::test_function(enum=2)
+Equality test 4: False
+ExampleWithEnum::test_function(enum=1)
+ExampleWithEnum::test_function(enum=2)
+Inequality test 4: True
+ExampleWithEnum::test_function(enum=1)
ExampleWithEnum::test_function(enum=2)
ExampleWithEnum::test_function(enum=1)
ExampleWithEnum::test_function(enum=2)
diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h
index a6255d8d..975cf144 100644
--- a/include/pybind11/pybind11.h
+++ b/include/pybind11/pybind11.h
@@ -1004,22 +1004,25 @@ private:
/// Binds C++ enumerations and enumeration classes to Python
template class enum_ : public class_ {
public:
+ using UnderlyingType = typename std::underlying_type::type;
template
enum_(const handle &scope, const char *name, const Extra&... extra)
: class_(scope, name, extra...), m_parent(scope) {
- auto entries = new std::unordered_map();
+ auto entries = new std::unordered_map();
this->def("__repr__", [name, entries](Type value) -> std::string {
- auto it = entries->find((int) value);
+ auto it = entries->find((UnderlyingType) value);
return std::string(name) + "." +
((it == entries->end()) ? std::string("???")
: std::string(it->second));
});
- this->def("__init__", [](Type& value, int i) { value = (Type)i; });
- this->def("__init__", [](Type& value, int i) { new (&value) Type((Type) i); });
- this->def("__int__", [](Type value) { return (int) value; });
+ this->def("__init__", [](Type& value, UnderlyingType i) { value = (Type)i; });
+ this->def("__init__", [](Type& value, UnderlyingType i) { new (&value) Type((Type) i); });
+ this->def("__int__", [](Type value) { return (UnderlyingType) value; });
this->def("__eq__", [](const Type &value, Type *value2) { return value2 && value == *value2; });
+ this->def("__eq__", [](const Type &value, UnderlyingType value2) { return value2 && value == value2; });
this->def("__ne__", [](const Type &value, Type *value2) { return !value2 || value != *value2; });
- this->def("__hash__", [](const Type &value) { return (int) value; });
+ this->def("__ne__", [](const Type &value, UnderlyingType value2) { return value != value2; });
+ this->def("__hash__", [](const Type &value) { return (UnderlyingType) value; });
m_entries = entries;
}
@@ -1036,11 +1039,11 @@ public:
/// Add an enumeration entry
enum_& value(char const* name, Type value) {
this->attr(name) = pybind11::cast(value, return_value_policy::copy);
- (*m_entries)[(int) value] = name;
+ (*m_entries)[(UnderlyingType) value] = name;
return *this;
}
private:
- std::unordered_map *m_entries;
+ std::unordered_map *m_entries;
handle m_parent;
};