From b57281bb00582e1c8046aa6c39e0faf143742baa Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Mon, 3 Jul 2017 19:12:09 -0400 Subject: [PATCH] Use rvalue subcasting when casting an rvalue container This updates the std::tuple, std::pair and `stl.h` type casters to forward their contained value according to whether the container being cast is an lvalue or rvalue reference. This fixes an issue where subcaster casts were always called with a const lvalue which meant nested type casters didn't have the desired `cast()` overload invoked. For example, this caused Eigen values in a tuple to end up with a readonly flag (issue #935) and made it impossible to return a container of move-only types (issue #853). This fixes both issues by adding templated universal reference `cast()` methods to the various container types that forward container elements according to the container reference type. --- include/pybind11/cast.h | 31 +++++++++++----------- include/pybind11/stl.h | 48 +++++++++++++++++++++++----------- tests/pybind11_tests.h | 19 ++++++++++++++ tests/test_builtin_casters.cpp | 10 +++++++ tests/test_builtin_casters.py | 7 +++++ tests/test_stl.cpp | 40 ++++++++++++++++++++++++++++ tests/test_stl.py | 19 ++++++++++++++ tests/test_stl_binders.cpp | 5 ---- 8 files changed, 144 insertions(+), 35 deletions(-) diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index e814cab9..0c5c1bb1 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -1256,9 +1256,9 @@ public: }; // Base implementation for std::tuple and std::pair -template class TupleType, typename... Tuple> class tuple_caster { - using type = TupleType; - static constexpr auto size = sizeof...(Tuple); +template class Tuple, typename... Ts> class tuple_caster { + using type = Tuple; + static constexpr auto size = sizeof...(Ts); using indices = make_index_sequence; public: @@ -1271,12 +1271,13 @@ public: return load_impl(seq, convert, indices{}); } - static handle cast(const type &src, return_value_policy policy, handle parent) { - return cast_impl(src, policy, parent, indices{}); + template + static handle cast(T &&src, return_value_policy policy, handle parent) { + return cast_impl(std::forward(src), policy, parent, indices{}); } static PYBIND11_DESCR name() { - return type_descr(_("Tuple[") + detail::concat(make_caster::name()...) + _("]")); + return type_descr(_("Tuple[") + detail::concat(make_caster::name()...) + _("]")); } template using cast_op_type = type; @@ -1286,9 +1287,9 @@ public: protected: template - type implicit_cast(index_sequence) & { return type(cast_op(std::get(subcasters))...); } + type implicit_cast(index_sequence) & { return type(cast_op(std::get(subcasters))...); } template - type implicit_cast(index_sequence) && { return type(cast_op(std::move(std::get(subcasters)))...); } + type implicit_cast(index_sequence) && { return type(cast_op(std::move(std::get(subcasters)))...); } static constexpr bool load_impl(const sequence &, bool, index_sequence<>) { return true; } @@ -1301,10 +1302,10 @@ protected: } /* Implementation: Convert a C++ tuple into a Python tuple */ - template - static handle cast_impl(const type &src, return_value_policy policy, handle parent, index_sequence) { - std::array entries {{ - reinterpret_steal(make_caster::cast(std::get(src), policy, parent))... + template + static handle cast_impl(T &&src, return_value_policy policy, handle parent, index_sequence) { + std::array entries{{ + reinterpret_steal(make_caster::cast(std::get(std::forward(src)), policy, parent))... }}; for (const auto &entry: entries) if (!entry) @@ -1316,14 +1317,14 @@ protected: return result.release(); } - TupleType...> subcasters; + Tuple...> subcasters; }; template class type_caster> : public tuple_caster {}; -template class type_caster> - : public tuple_caster {}; +template class type_caster> + : public tuple_caster {}; /// Helper class which abstracts away certain actions. Users can provide specializations for /// custom holders, but it's only necessary if the type has a non-standard interface. diff --git a/include/pybind11/stl.h b/include/pybind11/stl.h index 535eb495..d07a81f9 100644 --- a/include/pybind11/stl.h +++ b/include/pybind11/stl.h @@ -49,6 +49,19 @@ NAMESPACE_BEGIN(pybind11) NAMESPACE_BEGIN(detail) +/// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for +/// forwarding a container element). Typically used indirect via forwarded_type(), below. +template +using forwarded_type = conditional_t< + std::is_lvalue_reference::value, remove_reference_t &, remove_reference_t &&>; + +/// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically +/// used for forwarding a container's elements. +template +forwarded_type forward_like(U &&u) { + return std::forward>(std::forward(u)); +} + template struct set_caster { using type = Type; using key_conv = make_caster; @@ -67,10 +80,11 @@ template struct set_caster { return true; } - static handle cast(const type &src, return_value_policy policy, handle parent) { + template + static handle cast(T &&src, return_value_policy policy, handle parent) { pybind11::set s; - for (auto const &value: src) { - auto value_ = reinterpret_steal(key_conv::cast(value, policy, parent)); + for (auto &value: src) { + auto value_ = reinterpret_steal(key_conv::cast(forward_like(value), policy, parent)); if (!value_ || !s.add(value_)) return handle(); } @@ -100,11 +114,12 @@ template struct map_caster { return true; } - static handle cast(const Type &src, return_value_policy policy, handle parent) { + template + static handle cast(T &&src, return_value_policy policy, handle parent) { dict d; - for (auto const &kv: src) { - auto key = reinterpret_steal(key_conv::cast(kv.first, policy, parent)); - auto value = reinterpret_steal(value_conv::cast(kv.second, policy, parent)); + for (auto &kv: src) { + auto key = reinterpret_steal(key_conv::cast(forward_like(kv.first), policy, parent)); + auto value = reinterpret_steal(value_conv::cast(forward_like(kv.second), policy, parent)); if (!key || !value) return handle(); d[key] = value; @@ -140,11 +155,12 @@ private: void reserve_maybe(sequence, void *) { } public: - static handle cast(const Type &src, return_value_policy policy, handle parent) { + template + static handle cast(T &&src, return_value_policy policy, handle parent) { list l(src.size()); size_t index = 0; - for (auto const &value: src) { - auto value_ = reinterpret_steal(value_conv::cast(value, policy, parent)); + for (auto &value: src) { + auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); if (!value_) return handle(); PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference @@ -193,11 +209,12 @@ public: return true; } - static handle cast(const ArrayType &src, return_value_policy policy, handle parent) { + template + static handle cast(T &&src, return_value_policy policy, handle parent) { list l(src.size()); size_t index = 0; - for (auto const &value: src) { - auto value_ = reinterpret_steal(value_conv::cast(value, policy, parent)); + for (auto &value: src) { + auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); if (!value_) return handle(); PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference @@ -230,10 +247,11 @@ template struct optional_caster { using value_conv = make_caster; - static handle cast(const T& src, return_value_policy policy, handle parent) { + template + static handle cast(T_ &&src, return_value_policy policy, handle parent) { if (!src) return none().inc_ref(); - return value_conv::cast(*src, policy, parent); + return value_conv::cast(*std::forward(src), policy, parent); } bool load(handle src, bool convert) { diff --git a/tests/pybind11_tests.h b/tests/pybind11_tests.h index dd8d1593..18672cd2 100644 --- a/tests/pybind11_tests.h +++ b/tests/pybind11_tests.h @@ -1,6 +1,11 @@ #pragma once #include +#if defined(_MSC_VER) && _MSC_VER < 1910 +// We get some really long type names here which causes MSVC 2015 to emit warnings +# pragma warning(disable: 4503) // warning C4503: decorated name length exceeded, name was truncated +#endif + namespace py = pybind11; using namespace pybind11::literals; @@ -43,3 +48,17 @@ public: IncType &operator=(const IncType &) = delete; IncType &operator=(IncType &&) = delete; }; + +/// Custom cast-only type that casts to a string "rvalue" or "lvalue" depending on the cast context. +/// Used to test recursive casters (e.g. std::tuple, stl containers). +struct RValueCaster {}; +NAMESPACE_BEGIN(pybind11) +NAMESPACE_BEGIN(detail) +template<> class type_caster { +public: + PYBIND11_TYPE_CASTER(RValueCaster, _("RValueCaster")); + static handle cast(RValueCaster &&, return_value_policy, handle) { return py::str("rvalue").release(); } + static handle cast(const RValueCaster &, return_value_policy, handle) { return py::str("lvalue").release(); } +}; +NAMESPACE_END(detail) +NAMESPACE_END(pybind11) diff --git a/tests/test_builtin_casters.cpp b/tests/test_builtin_casters.cpp index 67334023..bf972e19 100644 --- a/tests/test_builtin_casters.cpp +++ b/tests/test_builtin_casters.cpp @@ -86,6 +86,16 @@ TEST_SUBMODULE(builtin_casters, m) { return std::make_tuple(std::get<2>(input), std::get<1>(input), std::get<0>(input)); }, "Return a triple in reversed order"); m.def("empty_tuple", []() { return std::tuple<>(); }); + static std::pair lvpair; + static std::tuple lvtuple; + static std::pair>> lvnested; + m.def("rvalue_pair", []() { return std::make_pair(RValueCaster{}, RValueCaster{}); }); + m.def("lvalue_pair", []() -> const decltype(lvpair) & { return lvpair; }); + m.def("rvalue_tuple", []() { return std::make_tuple(RValueCaster{}, RValueCaster{}, RValueCaster{}); }); + m.def("lvalue_tuple", []() -> const decltype(lvtuple) & { return lvtuple; }); + m.def("rvalue_nested", []() { + return std::make_pair(RValueCaster{}, std::make_tuple(RValueCaster{}, std::make_pair(RValueCaster{}, RValueCaster{}))); }); + m.def("lvalue_nested", []() -> const decltype(lvnested) & { return lvnested; }); // test_builtins_cast_return_none m.def("return_none_string", []() -> std::string * { return nullptr; }); diff --git a/tests/test_builtin_casters.py b/tests/test_builtin_casters.py index 32eba45f..d7d49b69 100644 --- a/tests/test_builtin_casters.py +++ b/tests/test_builtin_casters.py @@ -201,6 +201,13 @@ def test_tuple(doc): Return a triple in reversed order """ + assert m.rvalue_pair() == ("rvalue", "rvalue") + assert m.lvalue_pair() == ("lvalue", "lvalue") + assert m.rvalue_tuple() == ("rvalue", "rvalue", "rvalue") + assert m.lvalue_tuple() == ("lvalue", "lvalue", "lvalue") + assert m.rvalue_nested() == ("rvalue", ("rvalue", ("rvalue", "rvalue"))) + assert m.lvalue_nested() == ("lvalue", ("lvalue", ("lvalue", "lvalue"))) + def test_builtins_cast_return_none(): """Casters produced with PYBIND11_TYPE_CASTER() should convert nullptr to None""" diff --git a/tests/test_stl.cpp b/tests/test_stl.cpp index 9762fb9a..839c1af0 100644 --- a/tests/test_stl.cpp +++ b/tests/test_stl.cpp @@ -61,6 +61,46 @@ TEST_SUBMODULE(stl, m) { return set.count("key1") && set.count("key2") && set.count("key3"); }); + // test_recursive_casting + m.def("cast_rv_vector", []() { return std::vector{2}; }); + m.def("cast_rv_array", []() { return std::array(); }); + // NB: map and set keys are `const`, so while we technically do move them (as `const Type &&`), + // casters don't typically do anything with that, which means they fall to the `const Type &` + // caster. + m.def("cast_rv_map", []() { return std::unordered_map{{"a", RValueCaster{}}}; }); + m.def("cast_rv_nested", []() { + std::vector>, 2>> v; + v.emplace_back(); // add an array + v.back()[0].emplace_back(); // add a map to the array + v.back()[0].back().emplace("b", RValueCaster{}); + v.back()[0].back().emplace("c", RValueCaster{}); + v.back()[1].emplace_back(); // add a map to the array + v.back()[1].back().emplace("a", RValueCaster{}); + return v; + }); + static std::vector lvv{2}; + static std::array lva; + static std::unordered_map lvm{{"a", RValueCaster{}}, {"b", RValueCaster{}}}; + static std::unordered_map>>> lvn; + lvn["a"].emplace_back(); // add a list + lvn["a"].back().emplace_back(); // add an array + lvn["a"].emplace_back(); // another list + lvn["a"].back().emplace_back(); // add an array + lvn["b"].emplace_back(); // add a list + lvn["b"].back().emplace_back(); // add an array + lvn["b"].back().emplace_back(); // add another array + m.def("cast_lv_vector", []() -> const decltype(lvv) & { return lvv; }); + m.def("cast_lv_array", []() -> const decltype(lva) & { return lva; }); + m.def("cast_lv_map", []() -> const decltype(lvm) & { return lvm; }); + m.def("cast_lv_nested", []() -> const decltype(lvn) & { return lvn; }); + // #853: + m.def("cast_unique_ptr_vector", []() { + std::vector> v; + v.emplace_back(new UserType{7}); + v.emplace_back(new UserType{42}); + return v; + }); + struct MoveOutContainer { struct Value { int value; }; diff --git a/tests/test_stl.py b/tests/test_stl.py index fdbca0da..0f019821 100644 --- a/tests/test_stl.py +++ b/tests/test_stl.py @@ -58,6 +58,25 @@ def test_set(doc): assert doc(m.load_set) == "load_set(arg0: Set[str]) -> bool" +def test_recursive_casting(): + """Tests that stl casters preserve lvalue/rvalue context for container values""" + assert m.cast_rv_vector() == ["rvalue", "rvalue"] + assert m.cast_lv_vector() == ["lvalue", "lvalue"] + assert m.cast_rv_array() == ["rvalue", "rvalue", "rvalue"] + assert m.cast_lv_array() == ["lvalue", "lvalue"] + assert m.cast_rv_map() == {"a": "rvalue"} + assert m.cast_lv_map() == {"a": "lvalue", "b": "lvalue"} + assert m.cast_rv_nested() == [[[{"b": "rvalue", "c": "rvalue"}], [{"a": "rvalue"}]]] + assert m.cast_lv_nested() == { + "a": [[["lvalue", "lvalue"]], [["lvalue", "lvalue"]]], + "b": [[["lvalue", "lvalue"], ["lvalue", "lvalue"]]] + } + + # Issue #853 test case: + z = m.cast_unique_ptr_vector() + assert z[0].value == 7 and z[1].value == 42 + + def test_move_out_container(): """Properties use the `reference_internal` policy by default. If the underlying function returns an rvalue, the policy is automatically changed to `move` to avoid referencing diff --git a/tests/test_stl_binders.cpp b/tests/test_stl_binders.cpp index f636c0b5..22ba16ea 100644 --- a/tests/test_stl_binders.cpp +++ b/tests/test_stl_binders.cpp @@ -15,11 +15,6 @@ #include #include -#ifdef _MSC_VER -// We get some really long type names here which causes MSVC to emit warnings -# pragma warning(disable: 4503) // warning C4503: decorated name length exceeded, name was truncated -#endif - class El { public: El() = delete;