diff --git a/docs/advanced.rst b/docs/advanced.rst index 8e1f7c3f..1158f053 100644 --- a/docs/advanced.rst +++ b/docs/advanced.rst @@ -298,13 +298,11 @@ helper class that is defined as follows: The macro :func:`PYBIND11_OVERLOAD_PURE` should be used for pure virtual functions, and :func:`PYBIND11_OVERLOAD` should be used for functions which have -a default implementation. - -There are also two alternate macros :func:`PYBIND11_OVERLOAD_PURE_NAME` and -:func:`PYBIND11_OVERLOAD_NAME` which take a string-valued name argument between -the *Parent class* and *Name of the function* slots. This is useful when the -C++ and Python versions of the function have different names, e.g. -``operator()`` vs ``__call__``. +a default implementation. There are also two alternate macros +:func:`PYBIND11_OVERLOAD_PURE_NAME` and :func:`PYBIND11_OVERLOAD_NAME` which +take a string-valued name argument between the *Parent class* and *Name of the +function* slots. This is useful when the C++ and Python versions of the +function have different names, e.g. ``operator()`` vs ``__call__``. The binding code also needs a few minor adaptations (highlighted): @@ -357,6 +355,25 @@ a virtual method call. Please take a look at the :ref:`macro_notes` before using this feature. +.. note:: + + When the overridden type returns a reference or pointer to a type that + pybind11 converts from Python (for example, numeric values, std::string, + and other built-in value-converting types), there are some limitations to + be aware of: + + - because in these cases there is no C++ variable to reference (the value + is stored in the referenced Python variable), pybind11 provides one in + the PYBIND11_OVERLOAD macros (when needed) with static storage duration. + Note that this means that invoking the overloaded method on *any* + instance will change the referenced value stored in *all* instances of + that type. + + - Attempts to modify a non-const reference will not have the desired + effect: it will change only the static cache variable, but this change + will not propagate to underlying Python instance, and the change will be + replaced the next time the overload is invoked. + .. seealso:: The file :file:`tests/test_virtual_functions.cpp` contains a complete diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 638e4245..0cf27585 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -867,14 +867,8 @@ template using cast_is_temporary_value_reference = bool_constant !std::is_base_of>::value >; - -NAMESPACE_END(detail) - -template T cast(const handle &handle) { - using type_caster = detail::make_caster; - static_assert(!detail::cast_is_temporary_value_reference::value, - "Unable to cast type to reference: value is local to type caster"); - type_caster conv; +template make_caster load_type(const handle &handle) { + make_caster conv; if (!conv.load(handle, true)) { #if defined(NDEBUG) throw cast_error("Unable to cast Python instance to C++ type (compile in debug mode for details)"); @@ -883,7 +877,16 @@ template T cast(const handle &handle) { (std::string) handle.get_type().str() + " to C++ type '" + type_id() + "''"); #endif } - return conv.operator typename type_caster::template cast_op_type(); + return conv; +} + +NAMESPACE_END(detail) + +template T cast(const handle &handle) { + static_assert(!detail::cast_is_temporary_value_reference::value, + "Unable to cast type to reference: value is local to type caster"); + using type_caster = detail::make_caster; + return detail::load_type(handle).operator typename type_caster::template cast_op_type(); } template object cast(const T &value, @@ -900,7 +903,7 @@ template T handle::cast() const { return pybind11::cast(*this); template <> inline void handle::cast() const { return; } template -typename std::enable_if::value || detail::move_if_unreferenced::value, T>::type move(object &&obj) { +detail::enable_if_t::value || detail::move_if_unreferenced::value, T> move(object &&obj) { if (obj.ref_count() > 1) #if defined(NDEBUG) throw cast_error("Unable to cast Python instance to C++ rvalue: instance has multiple references" @@ -910,18 +913,8 @@ typename std::enable_if::value || detail::move_if_unrefer " instance to C++ " + type_id() + " instance: instance has multiple references"); #endif - typedef detail::type_caster type_caster; - type_caster conv; - if (!conv.load(obj, true)) -#if defined(NDEBUG) - throw cast_error("Unable to cast Python instance to C++ type (compile in debug mode for details)"); -#else - throw cast_error("Unable to cast Python instance of type " + - (std::string) obj.get_type().str() + " to C++ type '" + type_id() + "''"); -#endif - // Move into a temporary and return that, because the reference may be a local value of `conv` - T ret = std::move(conv.operator T&()); + T ret = std::move(detail::load_type(obj).operator T&()); return ret; } @@ -930,24 +923,57 @@ typename std::enable_if::value || detail::move_if_unrefer // object has multiple references, but trying to copy will fail to compile. // - If both movable and copyable, check ref count: if 1, move; otherwise copy // - Otherwise (not movable), copy. -template typename std::enable_if::value, T>::type cast(object &&object) { +template detail::enable_if_t::value, T> cast(object &&object) { return move(std::move(object)); } -template typename std::enable_if::value, T>::type cast(object &&object) { +template detail::enable_if_t::value, T> cast(object &&object) { if (object.ref_count() > 1) return cast(object); else return move(std::move(object)); } -template typename std::enable_if::value, T>::type cast(object &&object) { +template detail::enable_if_t::value, T> cast(object &&object) { return cast(object); } +// Provide a ref_cast() with move support for objects (only participates for moveable types) +template detail::enable_if_t::value, T> +ref_cast(object &&object) { return cast(std::move(object)); } template T object::cast() const & { return pybind11::cast(*this); } template T object::cast() && { return pybind11::cast(std::move(*this)); } template <> inline void object::cast() const & { return; } template <> inline void object::cast() && { return; } +NAMESPACE_BEGIN(detail) + +struct overload_nothing {}; // Placeholder type for the unneeded (and dead code) static variable in the OVERLOAD_INT macro +template using overload_local_t = conditional_t< + cast_is_temporary_value_reference::value, intrinsic_t, overload_nothing>; + +template enable_if_t::value, T> storage_cast(intrinsic_t &v) { return v; } +template enable_if_t::value, T> storage_cast(intrinsic_t &v) { return &v; } + +// Trampoline use: for reference/pointer types to value-converted values, we do a value cast, then +// store the result in the given variable. For other types, this is a no-op. +template enable_if_t::value, T> cast_ref(object &&o, intrinsic_t &storage) { + using type_caster = make_caster; + using itype = intrinsic_t; + storage = std::move(load_type(o).operator typename type_caster::template cast_op_type()); + return storage_cast(storage); +} +template enable_if_t::value, T> cast_ref(object &&, overload_nothing &) { + pybind11_fail("Internal error: cast_ref fallback invoked"); } + +// Trampoline use: Having a pybind11::cast with an invalid reference type is going to static_assert, even +// though if it's in dead code, so we provide a "trampoline" to pybind11::cast that only does anything in +// cases where pybind11::cast is valid. +template enable_if_t::value, T> cast_safe(object &&o) { + return pybind11::cast(std::move(o)); } +template enable_if_t::value, T> cast_safe(object &&) { + pybind11_fail("Internal error: cast_safe fallback invoked"); } +template <> inline void cast_safe(object &&) {} + +NAMESPACE_END(detail) template function get_overload(const T *this_ptr, const char *name) { #define PYBIND11_OVERLOAD_INT(ret_type, cname, name, ...) { \ pybind11::gil_scoped_acquire gil; \ pybind11::function overload = pybind11::get_overload(static_cast(this), name); \ - if (overload) \ - return overload(__VA_ARGS__).template cast(); } + if (overload) { \ + pybind11::object o = overload(__VA_ARGS__); \ + if (pybind11::detail::cast_is_temporary_value_reference::value) { \ + static pybind11::detail::overload_local_t local_value; \ + return pybind11::detail::cast_ref(std::move(o), local_value); \ + } \ + else return pybind11::detail::cast_safe(std::move(o)); \ + } \ + } #define PYBIND11_OVERLOAD_NAME(ret_type, cname, name, fn, ...) \ PYBIND11_OVERLOAD_INT(ret_type, cname, name, __VA_ARGS__) \ diff --git a/tests/test_virtual_functions.cpp b/tests/test_virtual_functions.cpp index 1581d740..0f8ed2af 100644 --- a/tests/test_virtual_functions.cpp +++ b/tests/test_virtual_functions.cpp @@ -21,14 +21,22 @@ public: virtual int run(int value) { py::print("Original implementation of " - "ExampleVirt::run(state={}, value={})"_s.format(state, value)); + "ExampleVirt::run(state={}, value={}, str1={}, str2={})"_s.format(state, value, get_string1(), *get_string2())); return state + value; } virtual bool run_bool() = 0; virtual void pure_virtual() = 0; + + // Returning a reference/pointer to a type converted from python (numbers, strings, etc.) is a + // bit trickier, because the actual int& or std::string& or whatever only exists temporarily, so + // we have to handle it specially in the trampoline class (see below). + virtual const std::string &get_string1() { return str1; } + virtual const std::string *get_string2() { return &str2; } + private: int state; + const std::string str1{"default1"}, str2{"default2"}; }; /* This is a wrapper class that must be generated */ @@ -65,6 +73,27 @@ public: in the previous line is needed for some compilers */ ); } + + // We can return reference types for compatibility with C++ virtual interfaces that do so, but + // note they have some significant limitations (see the documentation). + const std::string &get_string1() override { + PYBIND11_OVERLOAD( + const std::string &, /* Return type */ + ExampleVirt, /* Parent class */ + get_string1, /* Name of function */ + /* (no arguments) */ + ); + } + + const std::string *get_string2() override { + PYBIND11_OVERLOAD( + const std::string *, /* Return type */ + ExampleVirt, /* Parent class */ + get_string2, /* Name of function */ + /* (no arguments) */ + ); + } + }; class NonCopyable { diff --git a/tests/test_virtual_functions.py b/tests/test_virtual_functions.py index ef05de80..5d55d5ec 100644 --- a/tests/test_virtual_functions.py +++ b/tests/test_virtual_functions.py @@ -20,13 +20,23 @@ def test_override(capture, msg): print('ExtendedExampleVirt::run_bool()') return False + def get_string1(self): + return "override1" + def pure_virtual(self): print('ExtendedExampleVirt::pure_virtual(): %s' % self.data) + class ExtendedExampleVirt2(ExtendedExampleVirt): + def __init__(self, state): + super(ExtendedExampleVirt2, self).__init__(state + 1) + + def get_string2(self): + return "override2" + ex12 = ExampleVirt(10) with capture: assert runExampleVirt(ex12, 20) == 30 - assert capture == "Original implementation of ExampleVirt::run(state=10, value=20)" + assert capture == "Original implementation of ExampleVirt::run(state=10, value=20, str1=default1, str2=default2)" with pytest.raises(RuntimeError) as excinfo: runExampleVirtVirtual(ex12) @@ -37,7 +47,7 @@ def test_override(capture, msg): assert runExampleVirt(ex12p, 20) == 32 assert capture == """ ExtendedExampleVirt::run(20), calling parent.. - Original implementation of ExampleVirt::run(state=11, value=21) + Original implementation of ExampleVirt::run(state=11, value=21, str1=override1, str2=default2) """ with capture: assert runExampleVirtBool(ex12p) is False @@ -46,11 +56,19 @@ def test_override(capture, msg): runExampleVirtVirtual(ex12p) assert capture == "ExtendedExampleVirt::pure_virtual(): Hello world" + ex12p2 = ExtendedExampleVirt2(15) + with capture: + assert runExampleVirt(ex12p2, 50) == 68 + assert capture == """ + ExtendedExampleVirt::run(50), calling parent.. + Original implementation of ExampleVirt::run(state=17, value=51, str1=override1, str2=override2) + """ + cstats = ConstructorStats.get(ExampleVirt) - assert cstats.alive() == 2 - del ex12, ex12p + assert cstats.alive() == 3 + del ex12, ex12p, ex12p2 assert cstats.alive() == 0 - assert cstats.values() == ['10', '11'] + assert cstats.values() == ['10', '11', '17'] assert cstats.copy_constructions == 0 assert cstats.move_constructions >= 0