From 369e9b39377dfab610328a1f8de7d83273b82f55 Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Thu, 13 Oct 2016 00:57:42 +0200 Subject: [PATCH] Permit creation of NumPy arrays with a "base" object that owns the data This patch adds an extra base handle parameter to most ``py::array`` and ``py::array_t<>`` constructors. If specified along with a pointer to data, the base object will be registered within NumPy, which increases the base's reference count. This feature is useful to create shallow copies of C++ or Python arrays while ensuring that the owners of the underlying can't be garbage collected while referenced by NumPy. The commit also adds a simple test function involving a ``wrap()`` function that creates shallow copies of various N-D arrays. --- include/pybind11/numpy.h | 87 +++++++++++++++++++++++++++----------- tests/test_numpy_array.cpp | 10 +++++ tests/test_numpy_array.py | 51 ++++++++++++++++++++++ 3 files changed, 124 insertions(+), 24 deletions(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 1125fd71..445a6368 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -22,8 +22,8 @@ #include #if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant +# pragma warning(push) +# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant #endif /* This will be true on all flat address space platforms and allows us to reduce the @@ -156,8 +156,10 @@ NAMESPACE_END(detail) (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr) #define PyArrayDescr_GET_(ptr, attr) \ (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr) +#define PyArray_FLAGS_(ptr) \ + (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags) #define PyArray_CHKFLAGS_(ptr, flag) \ - (flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag)) + (flag == (PyArray_FLAGS_(ptr) & flag)) class dtype : public object { public: @@ -258,38 +260,62 @@ public: forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_ }; - array(const pybind11::dtype& dt, const std::vector& shape, - const std::vector& strides, const void *ptr = nullptr) { + array(const pybind11::dtype &dt, const std::vector &shape, + const std::vector &strides, const void *ptr = nullptr, + handle base = handle()) { auto& api = detail::npy_api::get(); auto ndim = shape.size(); if (shape.size() != strides.size()) pybind11_fail("NumPy: shape ndim doesn't match strides ndim"); auto descr = dt; + + int flags = 0; + if (base && ptr) { + array base_array(base, true); + if (base_array.check()) + /* Copy flags from base (except baseship bit) */ + flags = base_array.flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_; + else + /* Writable by default, easy to downgrade later on if needed */ + flags = detail::npy_api::NPY_ARRAY_WRITEABLE_; + } + object tmp(api.PyArray_NewFromDescr_( api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(), - (Py_intptr_t *) strides.data(), const_cast(ptr), 0, nullptr), false); + (Py_intptr_t *) strides.data(), const_cast(ptr), flags, nullptr), false); if (!tmp) pybind11_fail("NumPy: unable to create array!"); - if (ptr) - tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false); + if (ptr) { + if (base) { + PyArray_GET_(tmp.ptr(), base) = base.inc_ref().ptr(); + } else { + tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false); + } + } m_ptr = tmp.release().ptr(); } - array(const pybind11::dtype& dt, const std::vector& shape, const void *ptr = nullptr) - : array(dt, shape, default_strides(shape, dt.itemsize()), ptr) { } + array(const pybind11::dtype &dt, const std::vector &shape, + const void *ptr = nullptr, handle base = handle()) + : array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { } - array(const pybind11::dtype& dt, size_t count, const void *ptr = nullptr) - : array(dt, std::vector { count }, ptr) { } + array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr, + handle base = handle()) + : array(dt, std::vector{ count }, ptr, base) { } template array(const std::vector& shape, - const std::vector& strides, const T* ptr) - : array(pybind11::dtype::of(), shape, strides, (void *) ptr) { } + const std::vector& strides, + const T* ptr, handle base = handle()) + : array(pybind11::dtype::of(), shape, strides, (void *) ptr, base) { } - template array(const std::vector& shape, const T* ptr) - : array(shape, default_strides(shape, sizeof(T)), ptr) { } + template + array(const std::vector &shape, const T *ptr, + handle base = handle()) + : array(shape, default_strides(shape, sizeof(T)), ptr, base) { } - template array(size_t count, const T* ptr) - : array(std::vector { count }, ptr) { } + template + array(size_t count, const T *ptr, handle base = handle()) + : array(std::vector{ count }, ptr, base) { } array(const buffer_info &info) : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { } @@ -319,6 +345,11 @@ public: return (size_t) PyArray_GET_(m_ptr, nd); } + /// Base object + object base() const { + return object(PyArray_GET_(m_ptr, base), true); + } + /// Dimensions of the array const size_t* shape() const { return reinterpret_cast(PyArray_GET_(m_ptr, dimensions)); @@ -343,6 +374,11 @@ public: return strides()[dim]; } + /// Return the NumPy array flags + int flags() const { + return PyArray_FLAGS_(m_ptr); + } + /// If set, the array is writeable (otherwise the buffer is read-only) bool writeable() const { return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_); @@ -436,14 +472,17 @@ public: array_t(const buffer_info& info) : array(info) { } - array_t(const std::vector& shape, const std::vector& strides, const T* ptr = nullptr) - : array(shape, strides, ptr) { } + array_t(const std::vector &shape, + const std::vector &strides, const T *ptr = nullptr, + handle base = handle()) + : array(shape, strides, ptr, base) { } - array_t(const std::vector& shape, const T* ptr = nullptr) - : array(shape, ptr) { } + array_t(const std::vector &shape, const T *ptr = nullptr, + handle base = handle()) + : array(shape, ptr, base) { } - array_t(size_t count, const T* ptr = nullptr) - : array(count, ptr) { } + array_t(size_t count, const T *ptr = nullptr, handle base = handle()) + : array(count, ptr, base) { } constexpr size_t itemsize() const { return sizeof(T); diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index 37a89836..a6bf50de 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -99,4 +99,14 @@ test_initializer numpy_array([](py::module &m) { sm.def("make_c_array", [] { return py::array_t({ 2, 2 }, { 8, 4 }); }); + + sm.def("wrap", [](py::array a) { + return py::array( + a.dtype(), + std::vector(a.shape(), a.shape() + a.ndim()), + std::vector(a.strides(), a.strides() + a.ndim()), + a.data(), + a + ); + }); }); diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index 85775e4f..52350f69 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -149,6 +149,7 @@ def test_bounds_check(arr): index_at(arr, 0, 4) assert str(excinfo.value) == 'index 4 is out of bounds for axis 1 with size 3' + @pytest.requires_numpy def test_make_c_f_array(): from pybind11_tests.array import ( @@ -158,3 +159,53 @@ def test_make_c_f_array(): assert not make_c_array().flags.f_contiguous assert make_f_array().flags.f_contiguous assert not make_f_array().flags.c_contiguous + + +@pytest.requires_numpy +def test_wrap(): + from pybind11_tests.array import wrap + + def assert_references(A, B): + assert A is not B + assert A.__array_interface__['data'][0] == \ + B.__array_interface__['data'][0] + assert A.shape == B.shape + assert A.strides == B.strides + assert A.flags.c_contiguous == B.flags.c_contiguous + assert A.flags.f_contiguous == B.flags.f_contiguous + assert A.flags.writeable == B.flags.writeable + assert A.flags.aligned == B.flags.aligned + assert A.flags.updateifcopy == B.flags.updateifcopy + assert np.all(A == B) + assert not B.flags.owndata + assert B.base is A + if A.flags.writeable and A.ndim == 2: + A[0, 0] = 1234 + assert B[0, 0] == 1234 + + A1 = np.array([1, 2], dtype=np.int16) + assert A1.flags.owndata and A1.base is None + A2 = wrap(A1) + assert_references(A1, A2) + + A1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='F') + assert A1.flags.owndata and A1.base is None + A2 = wrap(A1) + assert_references(A1, A2) + + A1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='C') + A1.flags.writeable = False + A2 = wrap(A1) + assert_references(A1, A2) + + A1 = np.random.random((4, 4, 4)) + A2 = wrap(A1) + assert_references(A1, A2) + + A1 = A1.transpose() + A2 = wrap(A1) + assert_references(A1, A2) + + A1 = A1.diagonal() + A2 = wrap(A1) + assert_references(A1, A2)