diff --git a/docs/advanced/pycpp/numpy.rst b/docs/advanced/pycpp/numpy.rst index f9d6acbf..6bcc4671 100644 --- a/docs/advanced/pycpp/numpy.rst +++ b/docs/advanced/pycpp/numpy.rst @@ -340,12 +340,39 @@ To obtain the proxy from an ``array`` object, you must specify both the data type and number of dimensions as template arguments, such as ``auto r = myarray.mutable_unchecked<float, 2>()``. +If the number of dimensions is not known at compile time, you can omit the +dimensions template parameter (i.e. calling ``arr_t.unchecked()`` or +``arr.unchecked<T>()``). This will give you a proxy object that works in the +same way, but results in less optimizable code and thus a small efficiency +loss in tight loops. + Note that the returned proxy object directly references the array's data, and only reads its shape, strides, and writeable flag when constructed. You must take care to ensure that the referenced array is not destroyed or reshaped for the duration of the returned object, typically by limiting the scope of the returned instance. +The returned proxy object supports some of the same methods as ``py::array`` so +that it can be used as a drop-in replacement for some existing, index-checked +uses of ``py::array``: + +- ``r.ndim()`` returns the number of dimensions. + +- ``r.data(1, 2, ...)`` and ``r.mutable_data(1, 2, ...)`` return a pointer to + the ``const T`` or ``T`` data, respectively, at the given indices. The + latter is only available to proxies obtained via ``a.mutable_unchecked()``. + +- ``itemsize()`` returns the size of an item in bytes, i.e. ``sizeof(T)``. + +- ``ndim()`` returns the number of dimensions. + +- ``shape(n)`` returns the size of dimension ``n``. + +- ``size()`` returns the total number of elements (i.e. the product of the shapes). + +- ``nbytes()`` returns the number of bytes used by the referenced elements + (i.e. ``itemsize()`` times ``size()``). + ..
seealso:: The file :file:`tests/test_numpy_array.cpp` contains additional examples diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index a5f68cce..3227a12e 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -242,67 +242,107 @@ size_t byte_offset_unsafe(const Strides &strides, size_t i, Ix... index) { } /** Proxy class providing unsafe, unchecked const access to array data. This is constructed through - * the `unchecked()` method of `array` or the `unchecked()` method of `array_t`. + * the `unchecked()` method of `array` or the `unchecked()` method of `array_t`. `Dims` + * will be -1 for dimensions determined at runtime. */ -template +template class unchecked_reference { protected: + static constexpr bool Dynamic = Dims < 0; const unsigned char *data_; // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to - // make large performance gains on big, nested loops. - std::array shape_, strides_; + // make large performance gains on big, nested loops, but requires compile-time dimensions + conditional_t> + shape_, strides_; + const size_t dims_; friend class pybind11::array; - unchecked_reference(const void *data, const size_t *shape, const size_t *strides) - : data_{reinterpret_cast(data)} { - for (size_t i = 0; i < Dims; i++) { + // Constructor for compile-time dimensions: + template + unchecked_reference(const void *data, const size_t *shape, const size_t *strides, enable_if_t) + : data_{reinterpret_cast(data)}, dims_{Dims} { + for (size_t i = 0; i < dims_; i++) { shape_[i] = shape[i]; strides_[i] = strides[i]; } } + // Constructor for runtime dimensions: + template + unchecked_reference(const void *data, const size_t *shape, const size_t *strides, enable_if_t dims) + : data_{reinterpret_cast(data)}, shape_{shape}, strides_{strides}, dims_{dims} {} public: - /** Unchecked const reference access to data at the given indices. Omiting trailing indices - * is equivalent to specifying them as 0. 
+ /** Unchecked const reference access to data at the given indices. For a compile-time known + * number of dimensions, this requires the correct number of arguments; for run-time + * dimensionality, this is not checked (and so is up to the caller to use safely). */ - template const T& operator()(Ix... index) const { - static_assert(sizeof...(Ix) <= Dims, "Invalid number of indices for unchecked array reference"); - return *reinterpret_cast(data_ + byte_offset_unsafe(strides_, size_t{index}...)); + template const T &operator()(Ix... index) const { + static_assert(sizeof...(Ix) == Dims || Dynamic, + "Invalid number of indices for unchecked array reference"); + return *reinterpret_cast(data_ + byte_offset_unsafe(strides_, size_t(index)...)); } /** Unchecked const reference access to data; this operator only participates if the reference * is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`. */ - template > + template > const T &operator[](size_t index) const { return operator()(index); } + /// Pointer access to the data at the given indices. + template const T *data(Ix... ix) const { return &operator()(size_t(ix)...); } + + /// Returns the item size, i.e. sizeof(T) + constexpr static size_t itemsize() { return sizeof(T); } + /// Returns the shape (i.e. size) of dimension `dim` size_t shape(size_t dim) const { return shape_[dim]; } /// Returns the number of dimensions of the array - constexpr static size_t ndim() { return Dims; } + size_t ndim() const { return dims_; } + + /// Returns the total number of elements in the referenced array, i.e. the product of the shapes + template + enable_if_t size() const { + return std::accumulate(shape_.begin(), shape_.end(), (size_t) 1, std::multiplies()); + } + template + enable_if_t size() const { + return std::accumulate(shape_, shape_ + ndim(), (size_t) 1, std::multiplies()); + } + + /// Returns the total number of bytes used by the referenced data. 
Note that the actual span in + /// memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice). + size_t nbytes() const { + return size() * itemsize(); + } }; -template +template class unchecked_mutable_reference : public unchecked_reference { friend class pybind11::array; using ConstBase = unchecked_reference; using ConstBase::ConstBase; + using ConstBase::Dynamic; public: /// Mutable, unchecked access to data at the given indices. template T& operator()(Ix... index) { - static_assert(sizeof...(Ix) == Dims, "Invalid number of indices for unchecked array reference"); + static_assert(sizeof...(Ix) == Dims || Dynamic, + "Invalid number of indices for unchecked array reference"); return const_cast(ConstBase::operator()(index...)); } /** Mutable, unchecked access data at the given index; this operator only participates if the - * reference is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`. + * reference is to a 1-dimensional array (or has runtime dimensions). When present, this is + * exactly equivalent to `obj(index)`. */ - template > + template > T &operator[](size_t index) { return operator()(index); } + + /// Mutable pointer access to the data at the given indices. + template T *mutable_data(Ix... ix) { return &operator()(size_t(ix)...); } }; template struct type_caster> { - static_assert(Dim == (size_t) -1 /* always fail */, "unchecked array proxy object is not castable"); + static_assert(Dim == 0 && Dim > 0 /* always fail */, "unchecked array proxy object is not castable"); }; template struct type_caster> : type_caster> {}; @@ -580,11 +620,11 @@ public: * care: the array must not be destroyed or reshaped for the duration of the returned object, * and the caller must take care not to access invalid dimensions or dimension indices. 
*/ - template detail::unchecked_mutable_reference mutable_unchecked() { - if (ndim() != Dims) + template detail::unchecked_mutable_reference mutable_unchecked() { + if (Dims >= 0 && ndim() != (size_t) Dims) throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) + "; expected " + std::to_string(Dims)); - return detail::unchecked_mutable_reference(mutable_data(), shape(), strides()); + return detail::unchecked_mutable_reference(mutable_data(), shape(), strides(), ndim()); } /** Returns a proxy object that provides const access to the array's data without bounds or @@ -593,11 +633,11 @@ public: * reshaped for the duration of the returned object, and the caller must take care not to access * invalid dimensions or dimension indices. */ - template detail::unchecked_reference unchecked() const { - if (ndim() != Dims) + template detail::unchecked_reference unchecked() const { + if (Dims >= 0 && ndim() != (size_t) Dims) throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) + "; expected " + std::to_string(Dims)); - return detail::unchecked_reference(data(), shape(), strides()); + return detail::unchecked_reference(data(), shape(), strides(), ndim()); } /// Return a new view with all of the dimensions of length 1 removed @@ -625,7 +665,7 @@ protected: template size_t byte_offset(Ix... index) const { check_dimensions(index...); - return detail::byte_offset_unsafe(strides(), size_t{index}...); + return detail::byte_offset_unsafe(strides(), size_t(index)...); } void check_writeable() const { @@ -736,7 +776,7 @@ public: * care: the array must not be destroyed or reshaped for the duration of the returned object, * and the caller must take care not to access invalid dimensions or dimension indices. 
*/ - template detail::unchecked_mutable_reference mutable_unchecked() { + template detail::unchecked_mutable_reference mutable_unchecked() { return array::mutable_unchecked(); } @@ -746,7 +786,7 @@ public: * for the duration of the returned object, and the caller must take care not to access invalid * dimensions or dimension indices. */ - template detail::unchecked_reference unchecked() const { + template detail::unchecked_reference unchecked() const { return array::unchecked(); } diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index 461c9c00..cd648724 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -68,6 +68,21 @@ template arr_t& mutate_at_t(arr_t& a, Ix... idx) { a.mutable_at( sm.def(#name, [](type a, int i, int j) { return name(a, i, j); }); \ sm.def(#name, [](type a, int i, int j, int k) { return name(a, i, j, k); }); +template py::handle auxiliaries(T &&r, T2 &&r2) { + if (r.ndim() != 2) throw std::domain_error("error: ndim != 2"); + py::list l; + l.append(*r.data(0, 0)); + l.append(*r2.mutable_data(0, 0)); + l.append(r.data(0, 1) == r2.mutable_data(0, 1)); + l.append(r.ndim()); + l.append(r.itemsize()); + l.append(r.shape(0)); + l.append(r.shape(1)); + l.append(r.size()); + l.append(r.nbytes()); + return l.release(); +} + test_initializer numpy_array([](py::module &m) { auto sm = m.def_submodule("array"); @@ -191,6 +206,7 @@ test_initializer numpy_array([](py::module &m) { for (size_t j = 0; j < r.shape(1); j++) r(i, j) += v; }, py::arg().noconvert(), py::arg()); + sm.def("proxy_init3", [](double start) { py::array_t a({ 3, 3, 3 }); auto r = a.mutable_unchecked<3>(); @@ -216,4 +232,36 @@ test_initializer numpy_array([](py::module &m) { sumsq += r[i] * r(i); // Either notation works for a 1D array return sumsq; }); + + sm.def("proxy_auxiliaries2", [](py::array_t a) { + auto r = a.unchecked<2>(); + auto r2 = a.mutable_unchecked<2>(); + return auxiliaries(r, r2); + }); + + // Same as the above, but without a 
compile-time dimensions specification: + sm.def("proxy_add2_dyn", [](py::array_t a, double v) { + auto r = a.mutable_unchecked(); + if (r.ndim() != 2) throw std::domain_error("error: ndim != 2"); + for (size_t i = 0; i < r.shape(0); i++) + for (size_t j = 0; j < r.shape(1); j++) + r(i, j) += v; + }, py::arg().noconvert(), py::arg()); + sm.def("proxy_init3_dyn", [](double start) { + py::array_t a({ 3, 3, 3 }); + auto r = a.mutable_unchecked(); + if (r.ndim() != 3) throw std::domain_error("error: ndim != 3"); + for (size_t i = 0; i < r.shape(0); i++) + for (size_t j = 0; j < r.shape(1); j++) + for (size_t k = 0; k < r.shape(2); k++) + r(i, j, k) = start++; + return a; + }); + sm.def("proxy_auxiliaries2_dyn", [](py::array_t a) { + return auxiliaries(a.unchecked(), a.mutable_unchecked()); + }); + + sm.def("array_auxiliaries2", [](py::array_t a) { + return auxiliaries(a, a); + }); }); diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index 9081f8cd..14c25b37 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -341,8 +341,9 @@ def test_greedy_string_overload(): # issue 685 assert issue685(123) == "other" -def test_array_unchecked(msg): - from pybind11_tests.array import proxy_add2, proxy_init3F, proxy_init3, proxy_squared_L2_norm +def test_array_unchecked_fixed_dims(msg): + from pybind11_tests.array import (proxy_add2, proxy_init3F, proxy_init3, proxy_squared_L2_norm, + proxy_auxiliaries2, array_auxiliaries2) z1 = np.array([[1, 2], [3, 4]], dtype='float64') proxy_add2(z1, 10) @@ -359,3 +360,20 @@ def test_array_unchecked(msg): assert proxy_squared_L2_norm(np.array(range(6))) == 55 assert proxy_squared_L2_norm(np.array(range(6), dtype="float64")) == 55 + + assert proxy_auxiliaries2(z1) == [11, 11, True, 2, 8, 2, 2, 4, 32] + assert proxy_auxiliaries2(z1) == array_auxiliaries2(z1) + + +def test_array_unchecked_dyn_dims(msg): + from pybind11_tests.array import (proxy_add2_dyn, proxy_init3_dyn, proxy_auxiliaries2_dyn, + 
array_auxiliaries2) + z1 = np.array([[1, 2], [3, 4]], dtype='float64') + proxy_add2_dyn(z1, 10) + assert np.all(z1 == [[11, 12], [13, 14]]) + + expect_c = np.ndarray(shape=(3, 3, 3), buffer=np.array(range(3, 30)), dtype='int') + assert np.all(proxy_init3_dyn(3.0) == expect_c) + + assert proxy_auxiliaries2_dyn(z1) == [11, 11, True, 2, 8, 2, 2, 4, 32] + assert proxy_auxiliaries2_dyn(z1) == array_auxiliaries2(z1)