From 5027c4f95b65dbf76c3f72ce0fcda89716fb5f62 Mon Sep 17 00:00:00 2001 From: Sylvain Corlay Date: Wed, 16 Nov 2016 08:53:37 -0800 Subject: [PATCH] Switch NumPy variadic indexing to per-value arguments (#500) * Also added unsafe version without checks --- include/pybind11/numpy.h | 61 +++++++++++++++++++++++--------------- tests/test_numpy_array.cpp | 20 ++++++------- 2 files changed, 47 insertions(+), 34 deletions(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 4120f282..5309997c 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -444,31 +444,31 @@ public: /// Pointer to the contained data. If index is not provided, points to the /// beginning of the buffer. May throw if the index would lead to out of bounds access. - template const void* data(Ix&&... index) const { + template const void* data(Ix... index) const { return static_cast(PyArray_GET_(m_ptr, data) + offset_at(index...)); } /// Mutable pointer to the contained data. If index is not provided, points to the /// beginning of the buffer. May throw if the index would lead to out of bounds access. /// May throw if the array is not writeable. - template void* mutable_data(Ix&&... index) { + template void* mutable_data(Ix... index) { check_writeable(); return static_cast(PyArray_GET_(m_ptr, data) + offset_at(index...)); } /// Byte offset from beginning of the array to a given index (full or partial). /// May throw if the index would lead to out of bounds access. - template size_t offset_at(Ix&&... index) const { + template size_t offset_at(Ix... index) const { if (sizeof...(index) > ndim()) fail_dim_check(sizeof...(index), "too many indices for an array"); - return get_byte_offset(index...); + return byte_offset(size_t(index)...); } size_t offset_at() const { return 0; } /// Item count from beginning of the array to a given index (full or partial). /// May throw if the index would lead to out of bounds access. - template size_t index_at(Ix&&... index) const { + template size_t index_at(Ix... index) const { return offset_at(index...) / itemsize(); } @@ -493,18 +493,16 @@ protected: " (ndim = " + std::to_string(ndim()) + ")"); } - template size_t get_byte_offset(Ix&&... index) const { - const size_t idx[] = { (size_t) index... }; - if (!std::equal(idx + 0, idx + sizeof...(index), shape(), std::less{})) { - auto mismatch = std::mismatch(idx + 0, idx + sizeof...(index), shape(), std::less{}); - throw index_error(std::string("index ") + std::to_string(*mismatch.first) + - " is out of bounds for axis " + std::to_string(mismatch.first - idx) + - " with size " + std::to_string(*mismatch.second)); - } - return std::inner_product(idx + 0, idx + sizeof...(index), strides(), (size_t) 0); + template size_t byte_offset(Ix... index) const { + check_dimensions(index...); + return byte_offset_unsafe(index...); } - size_t get_byte_offset() const { return 0; } + template size_t byte_offset_unsafe(size_t i, Ix... index) const { + return i * strides()[dim] + byte_offset_unsafe(index...); + } + + template size_t byte_offset_unsafe() const { return 0; } void check_writeable() const { if (!writeable()) @@ -522,6 +520,23 @@ protected: } return strides; } + +protected: + + template void check_dimensions(Ix... index) const { + check_dimensions_impl(size_t(0), shape(), size_t(index)...); + } + + void check_dimensions_impl(size_t, const size_t*) const { } + + template void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const { + if (i >= *shape) { + throw index_error(std::string("index ") + std::to_string(i) + + " is out of bounds for axis " + std::to_string(axis) + + " with size " + std::to_string(*shape)); + } + check_dimensions_impl(axis + 1, shape + 1, index...); + } }; template class array_t : public array { @@ -548,32 +563,30 @@ public: return sizeof(T); } - template size_t index_at(Ix&... index) const { + template size_t index_at(Ix... index) const { return offset_at(index...) / itemsize(); } - template const T* data(Ix&&... index) const { + template const T* data(Ix... index) const { return static_cast(array::data(index...)); } - template T* mutable_data(Ix&&... index) { + template T* mutable_data(Ix... index) { return static_cast(array::mutable_data(index...)); } // Reference to element at a given index - template const T& at(Ix&&... index) const { + template const T& at(Ix... index) const { if (sizeof...(index) != ndim()) fail_dim_check(sizeof...(index), "index dimension mismatch"); - // not using offset_at() / index_at() here so as to avoid another dimension check - return *(static_cast(array::data()) + get_byte_offset(index...) / itemsize()); + return *(static_cast(array::data()) + byte_offset(size_t(index)...) / itemsize()); } // Mutable reference to element at a given index - template T& mutable_at(Ix&&... index) { + template T& mutable_at(Ix... index) { if (sizeof...(index) != ndim()) fail_dim_check(sizeof...(index), "index dimension mismatch"); - // not using offset_at() / index_at() here so as to avoid another dimension check - return *(static_cast(array::mutable_data()) + get_byte_offset(index...) / itemsize()); + return *(static_cast(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize()); } static bool is_non_null(PyObject *ptr) { return ptr != nullptr; } diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index f8be7220..df6377eb 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -18,11 +18,11 @@ using arr = py::array; using arr_t = py::array_t; -template arr data(const arr& a, Ix&&... index) { +template arr data(const arr& a, Ix... index) { return arr(a.nbytes() - a.offset_at(index...), (const uint8_t *) a.data(index...)); } -template arr data_t(const arr_t& a, Ix&&... index) { +template arr data_t(const arr_t& a, Ix... index) { return arr(a.size() - a.index_at(index...), a.data(index...)); } @@ -40,26 +40,26 @@ arr_t& mutate_data_t(arr_t& a) { return a; } -template arr& mutate_data(arr& a, Ix&&... index) { +template arr& mutate_data(arr& a, Ix... index) { auto ptr = (uint8_t *) a.mutable_data(index...); for (size_t i = 0; i < a.nbytes() - a.offset_at(index...); i++) ptr[i] = (uint8_t) (ptr[i] * 2); return a; } -template arr_t& mutate_data_t(arr_t& a, Ix&&... index) { +template arr_t& mutate_data_t(arr_t& a, Ix... index) { auto ptr = a.mutable_data(index...); for (size_t i = 0; i < a.size() - a.index_at(index...); i++) ptr[i]++; return a; } -template size_t index_at(const arr& a, Ix&&... idx) { return a.index_at(idx...); } -template size_t index_at_t(const arr_t& a, Ix&&... idx) { return a.index_at(idx...); } -template size_t offset_at(const arr& a, Ix&&... idx) { return a.offset_at(idx...); } -template size_t offset_at_t(const arr_t& a, Ix&&... idx) { return a.offset_at(idx...); } -template size_t at_t(const arr_t& a, Ix&&... idx) { return a.at(idx...); } -template arr_t& mutate_at_t(arr_t& a, Ix&&... idx) { a.mutable_at(idx...)++; return a; } +template size_t index_at(const arr& a, Ix... idx) { return a.index_at(idx...); } +template size_t index_at_t(const arr_t& a, Ix... idx) { return a.index_at(idx...); } +template size_t offset_at(const arr& a, Ix... idx) { return a.offset_at(idx...); } +template size_t offset_at_t(const arr_t& a, Ix... idx) { return a.offset_at(idx...); } +template size_t at_t(const arr_t& a, Ix... idx) { return a.at(idx...); } +template arr_t& mutate_at_t(arr_t& a, Ix... idx) { a.mutable_at(idx...)++; return a; } #define def_index_fn(name, type) \ sm.def(#name, [](type a) { return name(a); }); \