diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 6e0785e8..0768343f 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -26,8 +26,14 @@ #pragma warning(disable: 4127) // warning C4127: Conditional expression is constant #endif +/* This will be true on all flat address space platforms and allows us to reduce the + whole npy_intp / size_t / Py_intptr_t business down to just size_t for all size + and dimension types (e.g. shape, strides, indexing), instead of inflicting this + upon the library user. */ +static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t"); + NAMESPACE_BEGIN(pybind11) -namespace detail { +NAMESPACE_BEGIN(detail) template struct npy_format_descriptor { }; template struct is_pod_struct; @@ -141,10 +147,12 @@ private: return api; } }; -} +NAMESPACE_END(detail) -#define PyArray_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr) -#define PyArrayDescr_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr) +#define PyArray_GET_(ptr, attr) \ + (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr) +#define PyArrayDescr_GET_(ptr, attr) \ + (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr) #define PyArray_CHKFLAGS_(ptr, flag) \ (flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag)) @@ -250,7 +258,7 @@ public: }; array(const pybind11::dtype& dt, const std::vector& shape, - const std::vector& strides, void *ptr = nullptr) { + const std::vector& strides, const void *ptr = nullptr) { auto& api = detail::npy_api::get(); auto ndim = shape.size(); if (shape.size() != strides.size()) @@ -258,7 +266,7 @@ public: auto descr = dt; object tmp(api.PyArray_NewFromDescr_( api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(), - (Py_intptr_t *) strides.data(), ptr, 0, nullptr), false); + (Py_intptr_t *) strides.data(), const_cast(ptr), 0, nullptr), false); if (!tmp) pybind11_fail("NumPy: unable to create array!"); if (ptr) @@ -266,20 +274,20 @@ public: m_ptr = tmp.release().ptr(); } - array(const pybind11::dtype& dt, const std::vector& shape, void *ptr = nullptr) + array(const pybind11::dtype& dt, const std::vector& shape, const void *ptr = nullptr) : array(dt, shape, default_strides(shape, dt.itemsize()), ptr) { } - array(const pybind11::dtype& dt, size_t count, void *ptr = nullptr) + array(const pybind11::dtype& dt, size_t count, const void *ptr = nullptr) : array(dt, std::vector { count }, ptr) { } template array(const std::vector& shape, - const std::vector& strides, T* ptr) + const std::vector& strides, const T* ptr) : array(pybind11::dtype::of(), shape, strides, (void *) ptr) { } - template array(const std::vector& shape, T* ptr) + template array(const std::vector& shape, const T* ptr) : array(shape, default_strides(shape, sizeof(T)), ptr) { } - template array(size_t count, T* ptr) + template array(size_t count, const T* ptr) : array(std::vector { count }, ptr) { } array(const buffer_info &info) @@ -312,27 +320,25 @@ public: /// Dimensions of the array const size_t* shape() const { - static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t"); return reinterpret_cast(PyArray_GET_(m_ptr, dimensions)); } /// Dimension along a given axis size_t shape(size_t dim) const { if (dim >= ndim()) - pybind11_fail("NumPy: attempted to index shape beyond ndim"); + fail_dim_check(dim, "invalid axis"); return shape()[dim]; } /// Strides of the array const size_t* strides() const { - static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t"); return reinterpret_cast(PyArray_GET_(m_ptr, strides)); } /// Stride along a given axis size_t strides(size_t dim) const { if (dim >= ndim()) - pybind11_fail("NumPy: attempted to index strides beyond ndim"); + fail_dim_check(dim, "invalid axis"); return strides()[dim]; } @@ -346,20 +352,61 @@ public: return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_); } - /// Direct pointer to contained buffer - const void* data() const { - return reinterpret_cast(PyArray_GET_(m_ptr, data)); + /// Pointer to the contained data. If index is not provided, points to the + /// beginning of the buffer. May throw if the index would lead to out of bounds access. + template const void* data(Ix&&... index) const { + return static_cast(PyArray_GET_(m_ptr, data) + offset_at(index...)); } - /// Direct mutable pointer to contained buffer (checks writeable flag) - void* mutable_data() { - if (!writeable()) - pybind11_fail("NumPy: cannot get mutable data of a read-only array"); - return reinterpret_cast(PyArray_GET_(m_ptr, data)); + /// Mutable pointer to the contained data. If index is not provided, points to the + /// beginning of the buffer. May throw if the index would lead to out of bounds access. + /// May throw if the array is not writeable. + template void* mutable_data(Ix&&... index) { + check_writeable(); + return static_cast(PyArray_GET_(m_ptr, data) + offset_at(index...)); + } + + /// Byte offset from beginning of the array to a given index (full or partial). + /// May throw if the index would lead to out of bounds access. + template size_t offset_at(Ix&&... index) const { + if (sizeof...(index) > ndim()) + fail_dim_check(sizeof...(index), "too many indices for an array"); + return get_byte_offset(index...); + } + + size_t offset_at() const { return 0; } + + /// Item count from beginning of the array to a given index (full or partial). + /// May throw if the index would lead to out of bounds access. + template size_t index_at(Ix&&... index) const { + return offset_at(index...) / itemsize(); } protected: - template friend struct detail::npy_format_descriptor; + template friend struct detail::npy_format_descriptor; + + void fail_dim_check(size_t dim, const std::string& msg) const { + throw index_error(msg + ": " + std::to_string(dim) + + " (ndim = " + std::to_string(ndim()) + ")"); + } + + template size_t get_byte_offset(Ix&&... index) const { + const size_t idx[] = { (size_t) index... }; + if (!std::equal(idx + 0, idx + sizeof...(index), shape(), std::less{})) { + auto mismatch = std::mismatch(idx + 0, idx + sizeof...(index), shape(), std::less{}); + throw index_error(std::string("index ") + std::to_string(*mismatch.first) + + " is out of bounds for axis " + std::to_string(mismatch.first - idx) + + " with size " + std::to_string(*mismatch.second)); + } + return std::inner_product(idx + 0, idx + sizeof...(index), strides(), (size_t) 0); + } + + size_t get_byte_offset() const { return 0; } + + void check_writeable() const { + if (!writeable()) + throw std::runtime_error("array is not writeable"); + } static std::vector default_strides(const std::vector& shape, size_t itemsize) { auto ndim = shape.size(); @@ -382,23 +429,45 @@ public: array_t(const buffer_info& info) : array(info) { } - array_t(const std::vector& shape, const std::vector& strides, T* ptr = nullptr) + array_t(const std::vector& shape, const std::vector& strides, const T* ptr = nullptr) : array(shape, strides, ptr) { } - array_t(const std::vector& shape, T* ptr = nullptr) + array_t(const std::vector& shape, const T* ptr = nullptr) : array(shape, ptr) { } - array_t(size_t count, T* ptr = nullptr) + array_t(size_t count, const T* ptr = nullptr) : array(count, ptr) { } - const T* data() const { - return reinterpret_cast(PyArray_GET_(m_ptr, data)); + constexpr size_t itemsize() const { + return sizeof(T); } - T* mutable_data() { - if (!writeable()) - pybind11_fail("NumPy: cannot get mutable data of a read-only array"); - return reinterpret_cast(PyArray_GET_(m_ptr, data)); + template size_t index_at(Ix&... index) const { + return offset_at(index...) / itemsize(); + } + + template const T* data(Ix&&... index) const { + return static_cast(array::data(index...)); + } + + template T* mutable_data(Ix&&... index) { + return static_cast(array::mutable_data(index...)); + } + + // Reference to element at a given index + template const T& at(Ix&&... index) const { + if (sizeof...(index) != ndim()) + fail_dim_check(sizeof...(index), "index dimension mismatch"); + // not using offset_at() / index_at() here so as to avoid another dimension check + return *(static_cast(array::data()) + get_byte_offset(index...) / itemsize()); + } + + // Mutable reference to element at a given index + template T& mutable_at(Ix&&... index) { + if (sizeof...(index) != ndim()) + fail_dim_check(sizeof...(index), "index dimension mismatch"); + // not using offset_at() / index_at() here so as to avoid another dimension check + return *(static_cast(array::mutable_data()) + get_byte_offset(index...) / itemsize()); } static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }