mirror of
https://github.com/RYDE-WORK/pybind11.git
synced 2026-02-08 08:33:14 +08:00
array: add direct data access and indexing methods
This commit is contained in:
parent
91b3d681ad
commit
f2a0ad5855
@ -26,8 +26,14 @@
|
|||||||
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
|
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/* This will be true on all flat address space platforms and allows us to reduce the
|
||||||
|
whole npy_intp / size_t / Py_intptr_t business down to just size_t for all size
|
||||||
|
and dimension types (e.g. shape, strides, indexing), instead of inflicting this
|
||||||
|
upon the library user. */
|
||||||
|
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
|
||||||
|
|
||||||
NAMESPACE_BEGIN(pybind11)
|
NAMESPACE_BEGIN(pybind11)
|
||||||
namespace detail {
|
NAMESPACE_BEGIN(detail)
|
||||||
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
|
template <typename type, typename SFINAE = void> struct npy_format_descriptor { };
|
||||||
template <typename type> struct is_pod_struct;
|
template <typename type> struct is_pod_struct;
|
||||||
|
|
||||||
@ -141,10 +147,12 @@ private:
|
|||||||
return api;
|
return api;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
NAMESPACE_END(detail)
|
||||||
|
|
||||||
#define PyArray_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
|
#define PyArray_GET_(ptr, attr) \
|
||||||
#define PyArrayDescr_GET_(ptr, attr) (reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
|
(reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->attr)
|
||||||
|
#define PyArrayDescr_GET_(ptr, attr) \
|
||||||
|
(reinterpret_cast<::pybind11::detail::PyArrayDescr_Proxy*>(ptr)->attr)
|
||||||
#define PyArray_CHKFLAGS_(ptr, flag) \
|
#define PyArray_CHKFLAGS_(ptr, flag) \
|
||||||
(flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag))
|
(flag == (reinterpret_cast<::pybind11::detail::PyArray_Proxy*>(ptr)->flags & flag))
|
||||||
|
|
||||||
@ -250,7 +258,7 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
array(const pybind11::dtype& dt, const std::vector<size_t>& shape,
|
array(const pybind11::dtype& dt, const std::vector<size_t>& shape,
|
||||||
const std::vector<size_t>& strides, void *ptr = nullptr) {
|
const std::vector<size_t>& strides, const void *ptr = nullptr) {
|
||||||
auto& api = detail::npy_api::get();
|
auto& api = detail::npy_api::get();
|
||||||
auto ndim = shape.size();
|
auto ndim = shape.size();
|
||||||
if (shape.size() != strides.size())
|
if (shape.size() != strides.size())
|
||||||
@ -258,7 +266,7 @@ public:
|
|||||||
auto descr = dt;
|
auto descr = dt;
|
||||||
object tmp(api.PyArray_NewFromDescr_(
|
object tmp(api.PyArray_NewFromDescr_(
|
||||||
api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(),
|
api.PyArray_Type_, descr.release().ptr(), (int) ndim, (Py_intptr_t *) shape.data(),
|
||||||
(Py_intptr_t *) strides.data(), ptr, 0, nullptr), false);
|
(Py_intptr_t *) strides.data(), const_cast<void *>(ptr), 0, nullptr), false);
|
||||||
if (!tmp)
|
if (!tmp)
|
||||||
pybind11_fail("NumPy: unable to create array!");
|
pybind11_fail("NumPy: unable to create array!");
|
||||||
if (ptr)
|
if (ptr)
|
||||||
@ -266,20 +274,20 @@ public:
|
|||||||
m_ptr = tmp.release().ptr();
|
m_ptr = tmp.release().ptr();
|
||||||
}
|
}
|
||||||
|
|
||||||
array(const pybind11::dtype& dt, const std::vector<size_t>& shape, void *ptr = nullptr)
|
array(const pybind11::dtype& dt, const std::vector<size_t>& shape, const void *ptr = nullptr)
|
||||||
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr) { }
|
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr) { }
|
||||||
|
|
||||||
array(const pybind11::dtype& dt, size_t count, void *ptr = nullptr)
|
array(const pybind11::dtype& dt, size_t count, const void *ptr = nullptr)
|
||||||
: array(dt, std::vector<size_t> { count }, ptr) { }
|
: array(dt, std::vector<size_t> { count }, ptr) { }
|
||||||
|
|
||||||
template<typename T> array(const std::vector<size_t>& shape,
|
template<typename T> array(const std::vector<size_t>& shape,
|
||||||
const std::vector<size_t>& strides, T* ptr)
|
const std::vector<size_t>& strides, const T* ptr)
|
||||||
: array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr) { }
|
: array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr) { }
|
||||||
|
|
||||||
template<typename T> array(const std::vector<size_t>& shape, T* ptr)
|
template<typename T> array(const std::vector<size_t>& shape, const T* ptr)
|
||||||
: array(shape, default_strides(shape, sizeof(T)), ptr) { }
|
: array(shape, default_strides(shape, sizeof(T)), ptr) { }
|
||||||
|
|
||||||
template<typename T> array(size_t count, T* ptr)
|
template<typename T> array(size_t count, const T* ptr)
|
||||||
: array(std::vector<size_t> { count }, ptr) { }
|
: array(std::vector<size_t> { count }, ptr) { }
|
||||||
|
|
||||||
array(const buffer_info &info)
|
array(const buffer_info &info)
|
||||||
@ -312,27 +320,25 @@ public:
|
|||||||
|
|
||||||
/// Dimensions of the array
|
/// Dimensions of the array
|
||||||
const size_t* shape() const {
|
const size_t* shape() const {
|
||||||
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
|
|
||||||
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, dimensions));
|
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, dimensions));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Dimension along a given axis
|
/// Dimension along a given axis
|
||||||
size_t shape(size_t dim) const {
|
size_t shape(size_t dim) const {
|
||||||
if (dim >= ndim())
|
if (dim >= ndim())
|
||||||
pybind11_fail("NumPy: attempted to index shape beyond ndim");
|
fail_dim_check(dim, "invalid axis");
|
||||||
return shape()[dim];
|
return shape()[dim];
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Strides of the array
|
/// Strides of the array
|
||||||
const size_t* strides() const {
|
const size_t* strides() const {
|
||||||
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
|
|
||||||
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, strides));
|
return reinterpret_cast<const size_t *>(PyArray_GET_(m_ptr, strides));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stride along a given axis
|
/// Stride along a given axis
|
||||||
size_t strides(size_t dim) const {
|
size_t strides(size_t dim) const {
|
||||||
if (dim >= ndim())
|
if (dim >= ndim())
|
||||||
pybind11_fail("NumPy: attempted to index strides beyond ndim");
|
fail_dim_check(dim, "invalid axis");
|
||||||
return strides()[dim];
|
return strides()[dim];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -346,20 +352,61 @@ public:
|
|||||||
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
|
return PyArray_CHKFLAGS_(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Direct pointer to contained buffer
|
/// Pointer to the contained data. If index is not provided, points to the
|
||||||
const void* data() const {
|
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
|
||||||
return reinterpret_cast<const void *>(PyArray_GET_(m_ptr, data));
|
template<typename... Ix> const void* data(Ix&&... index) const {
|
||||||
|
return static_cast<const void *>(PyArray_GET_(m_ptr, data) + offset_at(index...));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Direct mutable pointer to contained buffer (checks writeable flag)
|
/// Mutable pointer to the contained data. If index is not provided, points to the
|
||||||
void* mutable_data() {
|
/// beginning of the buffer. May throw if the index would lead to out of bounds access.
|
||||||
if (!writeable())
|
/// May throw if the array is not writeable.
|
||||||
pybind11_fail("NumPy: cannot get mutable data of a read-only array");
|
template<typename... Ix> void* mutable_data(Ix&&... index) {
|
||||||
return reinterpret_cast<void *>(PyArray_GET_(m_ptr, data));
|
check_writeable();
|
||||||
|
return static_cast<void *>(PyArray_GET_(m_ptr, data) + offset_at(index...));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Byte offset from beginning of the array to a given index (full or partial).
|
||||||
|
/// May throw if the index would lead to out of bounds access.
|
||||||
|
template<typename... Ix> size_t offset_at(Ix&&... index) const {
|
||||||
|
if (sizeof...(index) > ndim())
|
||||||
|
fail_dim_check(sizeof...(index), "too many indices for an array");
|
||||||
|
return get_byte_offset(index...);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t offset_at() const { return 0; }
|
||||||
|
|
||||||
|
/// Item count from beginning of the array to a given index (full or partial).
|
||||||
|
/// May throw if the index would lead to out of bounds access.
|
||||||
|
template<typename... Ix> size_t index_at(Ix&&... index) const {
|
||||||
|
return offset_at(index...) / itemsize();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
template <typename T, typename SFINAE> friend struct detail::npy_format_descriptor;
|
template<typename, typename> friend struct detail::npy_format_descriptor;
|
||||||
|
|
||||||
|
void fail_dim_check(size_t dim, const std::string& msg) const {
|
||||||
|
throw index_error(msg + ": " + std::to_string(dim) +
|
||||||
|
" (ndim = " + std::to_string(ndim()) + ")");
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename... Ix> size_t get_byte_offset(Ix&&... index) const {
|
||||||
|
const size_t idx[] = { (size_t) index... };
|
||||||
|
if (!std::equal(idx + 0, idx + sizeof...(index), shape(), std::less<size_t>{})) {
|
||||||
|
auto mismatch = std::mismatch(idx + 0, idx + sizeof...(index), shape(), std::less<size_t>{});
|
||||||
|
throw index_error(std::string("index ") + std::to_string(*mismatch.first) +
|
||||||
|
" is out of bounds for axis " + std::to_string(mismatch.first - idx) +
|
||||||
|
" with size " + std::to_string(*mismatch.second));
|
||||||
|
}
|
||||||
|
return std::inner_product(idx + 0, idx + sizeof...(index), strides(), (size_t) 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t get_byte_offset() const { return 0; }
|
||||||
|
|
||||||
|
void check_writeable() const {
|
||||||
|
if (!writeable())
|
||||||
|
throw std::runtime_error("array is not writeable");
|
||||||
|
}
|
||||||
|
|
||||||
static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
|
static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
|
||||||
auto ndim = shape.size();
|
auto ndim = shape.size();
|
||||||
@ -382,23 +429,45 @@ public:
|
|||||||
|
|
||||||
array_t(const buffer_info& info) : array(info) { }
|
array_t(const buffer_info& info) : array(info) { }
|
||||||
|
|
||||||
array_t(const std::vector<size_t>& shape, const std::vector<size_t>& strides, T* ptr = nullptr)
|
array_t(const std::vector<size_t>& shape, const std::vector<size_t>& strides, const T* ptr = nullptr)
|
||||||
: array(shape, strides, ptr) { }
|
: array(shape, strides, ptr) { }
|
||||||
|
|
||||||
array_t(const std::vector<size_t>& shape, T* ptr = nullptr)
|
array_t(const std::vector<size_t>& shape, const T* ptr = nullptr)
|
||||||
: array(shape, ptr) { }
|
: array(shape, ptr) { }
|
||||||
|
|
||||||
array_t(size_t count, T* ptr = nullptr)
|
array_t(size_t count, const T* ptr = nullptr)
|
||||||
: array(count, ptr) { }
|
: array(count, ptr) { }
|
||||||
|
|
||||||
const T* data() const {
|
constexpr size_t itemsize() const {
|
||||||
return reinterpret_cast<const T *>(PyArray_GET_(m_ptr, data));
|
return sizeof(T);
|
||||||
}
|
}
|
||||||
|
|
||||||
T* mutable_data() {
|
template<typename... Ix> size_t index_at(Ix&... index) const {
|
||||||
if (!writeable())
|
return offset_at(index...) / itemsize();
|
||||||
pybind11_fail("NumPy: cannot get mutable data of a read-only array");
|
}
|
||||||
return reinterpret_cast<T *>(PyArray_GET_(m_ptr, data));
|
|
||||||
|
template<typename... Ix> const T* data(Ix&&... index) const {
|
||||||
|
return static_cast<const T*>(array::data(index...));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename... Ix> T* mutable_data(Ix&&... index) {
|
||||||
|
return static_cast<T*>(array::mutable_data(index...));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reference to element at a given index
|
||||||
|
template<typename... Ix> const T& at(Ix&&... index) const {
|
||||||
|
if (sizeof...(index) != ndim())
|
||||||
|
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
||||||
|
// not using offset_at() / index_at() here so as to avoid another dimension check
|
||||||
|
return *(static_cast<const T*>(array::data()) + get_byte_offset(index...) / itemsize());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mutable reference to element at a given index
|
||||||
|
template<typename... Ix> T& mutable_at(Ix&&... index) {
|
||||||
|
if (sizeof...(index) != ndim())
|
||||||
|
fail_dim_check(sizeof...(index), "index dimension mismatch");
|
||||||
|
// not using offset_at() / index_at() here so as to avoid another dimension check
|
||||||
|
return *(static_cast<T*>(array::mutable_data()) + get_byte_offset(index...) / itemsize());
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
|
static bool is_non_null(PyObject *ptr) { return ptr != nullptr; }
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user