mirror of
https://github.com/RYDE-WORK/pybind11.git
synced 2026-02-04 22:46:15 +08:00
Accept abitrary containers and iterators for shape/strides
This adds support for constructing `buffer_info` and `array`s using arbitrary containers or iterator pairs instead of requiring a vector. This is primarily needed by PR #782 (which makes strides signed to properly support negative strides, and will likely also make shape and itemsize to avoid mixed integer issues), but also needs to preserve backwards compatibility with 2.1 and earlier which accepts the strides parameter as a vector of size_t's. Rather than adding nearly duplicate constructors for each stride-taking constructor, it seems nicer to simply allow any type of container (or iterator pairs). This works by replacing the existing vector arguments with a new `detail::any_container` class that handles implicit conversion of arbitrary containers into a vector of the desired type. It can also be explicitly instantiated with a pair of iterators (e.g. by passing {begin, end} instead of the container).
This commit is contained in:
parent
dbb4c5b531
commit
5f38386293
@ -7,7 +7,7 @@
|
|||||||
BSD-style license that can be found in the LICENSE file.
|
BSD-style license that can be found in the LICENSE file.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
|
||||||
@ -26,25 +26,22 @@ struct buffer_info {
|
|||||||
buffer_info() { }
|
buffer_info() { }
|
||||||
|
|
||||||
buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t ndim,
|
buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t ndim,
|
||||||
const std::vector<size_t> &shape, const std::vector<size_t> &strides)
|
detail::any_container<size_t> shape_in, detail::any_container<size_t> strides_in)
|
||||||
: ptr(ptr), itemsize(itemsize), size(1), format(format),
|
: ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
|
||||||
ndim(ndim), shape(shape), strides(strides) {
|
shape(std::move(shape_in)), strides(std::move(strides_in)) {
|
||||||
|
if (ndim != shape.size() || ndim != strides.size())
|
||||||
|
pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
|
||||||
for (size_t i = 0; i < ndim; ++i)
|
for (size_t i = 0; i < ndim; ++i)
|
||||||
size *= shape[i];
|
size *= shape[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t size)
|
buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t size)
|
||||||
: buffer_info(ptr, itemsize, format, 1, std::vector<size_t> { size },
|
: buffer_info(ptr, itemsize, format, 1, { size }, { itemsize }) { }
|
||||||
std::vector<size_t> { itemsize }) { }
|
|
||||||
|
|
||||||
explicit buffer_info(Py_buffer *view, bool ownview = true)
|
explicit buffer_info(Py_buffer *view, bool ownview_in = true)
|
||||||
: ptr(view->buf), itemsize((size_t) view->itemsize), size(1), format(view->format),
|
: buffer_info(view->buf, (size_t) view->itemsize, view->format, (size_t) view->ndim,
|
||||||
ndim((size_t) view->ndim), shape((size_t) view->ndim), strides((size_t) view->ndim), view(view), ownview(ownview) {
|
{view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) {
|
||||||
for (size_t i = 0; i < (size_t) view->ndim; ++i) {
|
ownview = ownview_in;
|
||||||
shape[i] = (size_t) view->shape[i];
|
|
||||||
strides[i] = (size_t) view->strides[i];
|
|
||||||
size *= shape[i];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
buffer_info(const buffer_info &) = delete;
|
buffer_info(const buffer_info &) = delete;
|
||||||
|
|||||||
@ -490,6 +490,12 @@ struct is_instantiation<Class, Class<Us...>> : std::true_type { };
|
|||||||
/// Check if T is std::shared_ptr<U> where U can be anything
|
/// Check if T is std::shared_ptr<U> where U can be anything
|
||||||
template <typename T> using is_shared_ptr = is_instantiation<std::shared_ptr, T>;
|
template <typename T> using is_shared_ptr = is_instantiation<std::shared_ptr, T>;
|
||||||
|
|
||||||
|
/// Check if T looks like an input iterator
|
||||||
|
template <typename T, typename = void> struct is_input_iterator : std::false_type {};
|
||||||
|
template <typename T>
|
||||||
|
struct is_input_iterator<T, void_t<decltype(*std::declval<T &>()), decltype(++std::declval<T &>())>>
|
||||||
|
: std::true_type {};
|
||||||
|
|
||||||
/// Ignore that a variable is unused in compiler warnings
|
/// Ignore that a variable is unused in compiler warnings
|
||||||
inline void ignore_unused(const int *) { }
|
inline void ignore_unused(const int *) { }
|
||||||
|
|
||||||
@ -651,4 +657,46 @@ static constexpr auto const_ = std::true_type{};
|
|||||||
|
|
||||||
#endif // overload_cast
|
#endif // overload_cast
|
||||||
|
|
||||||
|
NAMESPACE_BEGIN(detail)
|
||||||
|
|
||||||
|
// Adaptor for converting arbitrary container arguments into a vector; implicitly convertible from
|
||||||
|
// any standard container (or C-style array) supporting std::begin/std::end.
|
||||||
|
template <typename T>
|
||||||
|
class any_container {
|
||||||
|
std::vector<T> v;
|
||||||
|
public:
|
||||||
|
any_container() = default;
|
||||||
|
|
||||||
|
// Can construct from a pair of iterators
|
||||||
|
template <typename It, typename = enable_if_t<is_input_iterator<It>::value>>
|
||||||
|
any_container(It first, It last) : v(first, last) { }
|
||||||
|
|
||||||
|
// Implicit conversion constructor from any arbitrary container type with values convertible to T
|
||||||
|
template <typename Container, typename = enable_if_t<std::is_convertible<decltype(*std::begin(std::declval<const Container &>())), T>::value>>
|
||||||
|
any_container(const Container &c) : any_container(std::begin(c), std::end(c)) { }
|
||||||
|
|
||||||
|
// initializer_list's aren't deducible, so don't get matched by the above template; we need this
|
||||||
|
// to explicitly allow implicit conversion from one:
|
||||||
|
template <typename TIn, typename = enable_if_t<std::is_convertible<TIn, T>::value>>
|
||||||
|
any_container(const std::initializer_list<TIn> &c) : any_container(c.begin(), c.end()) { }
|
||||||
|
|
||||||
|
// Avoid copying if given an rvalue vector of the correct type.
|
||||||
|
any_container(std::vector<T> &&v) : v(std::move(v)) { }
|
||||||
|
|
||||||
|
// Moves the vector out of an rvalue any_container
|
||||||
|
operator std::vector<T> &&() && { return std::move(v); }
|
||||||
|
|
||||||
|
// Dereferencing obtains a reference to the underlying vector
|
||||||
|
std::vector<T> &operator*() { return v; }
|
||||||
|
const std::vector<T> &operator*() const { return v; }
|
||||||
|
|
||||||
|
// -> lets you call methods on the underlying vector
|
||||||
|
std::vector<T> *operator->() { return &v; }
|
||||||
|
const std::vector<T> *operator->() const { return &v; }
|
||||||
|
};
|
||||||
|
|
||||||
|
NAMESPACE_END(detail)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
NAMESPACE_END(pybind11)
|
NAMESPACE_END(pybind11)
|
||||||
|
|||||||
@ -201,18 +201,13 @@ template <typename Type_> struct EigenProps {
|
|||||||
// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
|
// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
|
||||||
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
|
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
|
||||||
constexpr size_t elem_size = sizeof(typename props::Scalar);
|
constexpr size_t elem_size = sizeof(typename props::Scalar);
|
||||||
std::vector<size_t> shape, strides;
|
array a;
|
||||||
if (props::vector) {
|
if (props::vector)
|
||||||
shape.push_back(src.size());
|
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
|
||||||
strides.push_back(elem_size * src.innerStride());
|
else
|
||||||
}
|
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
|
||||||
else {
|
src.data(), base);
|
||||||
shape.push_back(src.rows());
|
|
||||||
shape.push_back(src.cols());
|
|
||||||
strides.push_back(elem_size * src.rowStride());
|
|
||||||
strides.push_back(elem_size * src.colStride());
|
|
||||||
}
|
|
||||||
array a(std::move(shape), std::move(strides), src.data(), base);
|
|
||||||
if (!writeable)
|
if (!writeable)
|
||||||
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
|
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
|
||||||
|
|
||||||
|
|||||||
@ -455,12 +455,18 @@ public:
|
|||||||
|
|
||||||
array() : array(0, static_cast<const double *>(nullptr)) {}
|
array() : array(0, static_cast<const double *>(nullptr)) {}
|
||||||
|
|
||||||
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
|
using ShapeContainer = detail::any_container<Py_intptr_t>;
|
||||||
const std::vector<size_t> &strides, const void *ptr = nullptr,
|
using StridesContainer = detail::any_container<Py_intptr_t>;
|
||||||
handle base = handle()) {
|
|
||||||
auto& api = detail::npy_api::get();
|
// Constructs an array taking shape/strides from arbitrary container types
|
||||||
auto ndim = shape.size();
|
array(const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides,
|
||||||
if (shape.size() != strides.size())
|
const void *ptr = nullptr, handle base = handle()) {
|
||||||
|
|
||||||
|
if (strides->empty())
|
||||||
|
strides = default_strides(*shape, dt.itemsize());
|
||||||
|
|
||||||
|
auto ndim = shape->size();
|
||||||
|
if (ndim != strides->size())
|
||||||
pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
|
pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
|
||||||
auto descr = dt;
|
auto descr = dt;
|
||||||
|
|
||||||
@ -474,10 +480,9 @@ public:
|
|||||||
flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
|
flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto &api = detail::npy_api::get();
|
||||||
auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
|
auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
|
||||||
api.PyArray_Type_, descr.release().ptr(), (int) ndim,
|
api.PyArray_Type_, descr.release().ptr(), (int) ndim, shape->data(), strides->data(),
|
||||||
reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(shape.data())),
|
|
||||||
reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(strides.data())),
|
|
||||||
const_cast<void *>(ptr), flags, nullptr));
|
const_cast<void *>(ptr), flags, nullptr));
|
||||||
if (!tmp)
|
if (!tmp)
|
||||||
pybind11_fail("NumPy: unable to create array!");
|
pybind11_fail("NumPy: unable to create array!");
|
||||||
@ -491,27 +496,24 @@ public:
|
|||||||
m_ptr = tmp.release().ptr();
|
m_ptr = tmp.release().ptr();
|
||||||
}
|
}
|
||||||
|
|
||||||
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
|
array(const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr, handle base = handle())
|
||||||
const void *ptr = nullptr, handle base = handle())
|
: array(dt, std::move(shape), {}, ptr, base) { }
|
||||||
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
|
|
||||||
|
|
||||||
array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
|
array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
|
||||||
handle base = handle())
|
handle base = handle())
|
||||||
: array(dt, std::vector<size_t>{ count }, ptr, base) { }
|
: array(dt, ShapeContainer{{ count }}, ptr, base) { }
|
||||||
|
|
||||||
template<typename T> array(const std::vector<size_t>& shape,
|
|
||||||
const std::vector<size_t>& strides,
|
|
||||||
const T* ptr, handle base = handle())
|
|
||||||
: array(pybind11::dtype::of<T>(), shape, strides, (const void *) ptr, base) { }
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
array(const std::vector<size_t> &shape, const T *ptr,
|
array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
|
||||||
handle base = handle())
|
: array(pybind11::dtype::of<T>(), std::move(shape), std::move(strides), ptr, base) { }
|
||||||
: array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
|
|
||||||
|
template <typename T>
|
||||||
|
array(ShapeContainer shape, const T *ptr, handle base = handle())
|
||||||
|
: array(std::move(shape), {}, ptr, base) { }
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
array(size_t count, const T *ptr, handle base = handle())
|
array(size_t count, const T *ptr, handle base = handle())
|
||||||
: array(std::vector<size_t>{ count }, ptr, base) { }
|
: array({{ count }}, ptr, base) { }
|
||||||
|
|
||||||
explicit array(const buffer_info &info)
|
explicit array(const buffer_info &info)
|
||||||
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
|
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
|
||||||
@ -673,9 +675,9 @@ protected:
|
|||||||
throw std::domain_error("array is not writeable");
|
throw std::domain_error("array is not writeable");
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
|
static std::vector<Py_intptr_t> default_strides(const std::vector<Py_intptr_t>& shape, size_t itemsize) {
|
||||||
auto ndim = shape.size();
|
auto ndim = shape.size();
|
||||||
std::vector<size_t> strides(ndim);
|
std::vector<Py_intptr_t> strides(ndim);
|
||||||
if (ndim) {
|
if (ndim) {
|
||||||
std::fill(strides.begin(), strides.end(), itemsize);
|
std::fill(strides.begin(), strides.end(), itemsize);
|
||||||
for (size_t i = 0; i < ndim - 1; i++)
|
for (size_t i = 0; i < ndim - 1; i++)
|
||||||
@ -731,14 +733,11 @@ public:
|
|||||||
|
|
||||||
explicit array_t(const buffer_info& info) : array(info) { }
|
explicit array_t(const buffer_info& info) : array(info) { }
|
||||||
|
|
||||||
array_t(const std::vector<size_t> &shape,
|
array_t(ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr, handle base = handle())
|
||||||
const std::vector<size_t> &strides, const T *ptr = nullptr,
|
: array(std::move(shape), std::move(strides), ptr, base) { }
|
||||||
handle base = handle())
|
|
||||||
: array(shape, strides, ptr, base) { }
|
|
||||||
|
|
||||||
explicit array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
|
explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
|
||||||
handle base = handle())
|
: array(std::move(shape), ptr, base) { }
|
||||||
: array(shape, ptr, base) { }
|
|
||||||
|
|
||||||
explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
|
explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
|
||||||
: array(count, ptr, base) { }
|
: array(count, ptr, base) { }
|
||||||
|
|||||||
@ -13,7 +13,6 @@
|
|||||||
#include <pybind11/stl.h>
|
#include <pybind11/stl.h>
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
using arr = py::array;
|
using arr = py::array;
|
||||||
using arr_t = py::array_t<uint16_t, 0>;
|
using arr_t = py::array_t<uint16_t, 0>;
|
||||||
@ -119,8 +118,8 @@ test_initializer numpy_array([](py::module &m) {
|
|||||||
sm.def("wrap", [](py::array a) {
|
sm.def("wrap", [](py::array a) {
|
||||||
return py::array(
|
return py::array(
|
||||||
a.dtype(),
|
a.dtype(),
|
||||||
std::vector<size_t>(a.shape(), a.shape() + a.ndim()),
|
{a.shape(), a.shape() + a.ndim()},
|
||||||
std::vector<size_t>(a.strides(), a.strides() + a.ndim()),
|
{a.strides(), a.strides() + a.ndim()},
|
||||||
a.data(),
|
a.data(),
|
||||||
a
|
a
|
||||||
);
|
);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user