mirror of
https://github.com/RYDE-WORK/pybind11.git
synced 2026-01-30 19:23:12 +08:00
Improve consistency of array and array_t with regard to other pytypes
* `array_t(const object &)` now throws on error * `array_t::ensure()` is intended for casters —- old constructor is deprecated * `array` and `array_t` get default constructors (empty array) * `array` gets a converting constructor * `py::isinstance<array_T<T>>()` checks the type (but not flags) There is only one special thing which must remain: `array_t` gets its own `type_caster` specialization which uses `ensure` instead of a simple check.
This commit is contained in:
parent
c7ac16bb2e
commit
4de271027d
@ -54,7 +54,7 @@ struct type_caster<Type, enable_if_t<is_eigen_dense<Type>::value && !is_eigen_re
|
|||||||
static constexpr bool isVector = Type::IsVectorAtCompileTime;
|
static constexpr bool isVector = Type::IsVectorAtCompileTime;
|
||||||
|
|
||||||
bool load(handle src, bool) {
|
bool load(handle src, bool) {
|
||||||
array_t<Scalar> buf(src, true);
|
auto buf = array_t<Scalar>::ensure(src);
|
||||||
if (!buf)
|
if (!buf)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
|
|||||||
@ -305,7 +305,7 @@ private:
|
|||||||
|
|
||||||
class array : public buffer {
|
class array : public buffer {
|
||||||
public:
|
public:
|
||||||
PYBIND11_OBJECT_DEFAULT(array, buffer, detail::npy_api::get().PyArray_Check_)
|
PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
|
||||||
|
|
||||||
enum {
|
enum {
|
||||||
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
|
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
|
||||||
@ -313,6 +313,8 @@ public:
|
|||||||
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
|
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
|
||||||
};
|
};
|
||||||
|
|
||||||
|
array() : array(0, static_cast<const double *>(nullptr)) {}
|
||||||
|
|
||||||
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
|
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
|
||||||
const std::vector<size_t> &strides, const void *ptr = nullptr,
|
const std::vector<size_t> &strides, const void *ptr = nullptr,
|
||||||
handle base = handle()) {
|
handle base = handle()) {
|
||||||
@ -478,10 +480,12 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Ensure that the argument is a NumPy array
|
/// Ensure that the argument is a NumPy array
|
||||||
static array ensure(object input, int ExtraFlags = 0) {
|
/// In case of an error, nullptr is returned and the Python error is cleared.
|
||||||
auto& api = detail::npy_api::get();
|
static array ensure(handle h, int ExtraFlags = 0) {
|
||||||
return reinterpret_steal<array>(api.PyArray_FromAny_(
|
auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
|
||||||
input.release().ptr(), nullptr, 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr));
|
if (!result)
|
||||||
|
PyErr_Clear();
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@ -520,8 +524,6 @@ protected:
|
|||||||
return strides;
|
return strides;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
|
||||||
|
|
||||||
template<typename... Ix> void check_dimensions(Ix... index) const {
|
template<typename... Ix> void check_dimensions(Ix... index) const {
|
||||||
check_dimensions_impl(size_t(0), shape(), size_t(index)...);
|
check_dimensions_impl(size_t(0), shape(), size_t(index)...);
|
||||||
}
|
}
|
||||||
@ -536,15 +538,31 @@ protected:
|
|||||||
}
|
}
|
||||||
check_dimensions_impl(axis + 1, shape + 1, index...);
|
check_dimensions_impl(axis + 1, shape + 1, index...);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create array from any object -- always returns a new reference
|
||||||
|
static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
|
||||||
|
if (ptr == nullptr)
|
||||||
|
return nullptr;
|
||||||
|
return detail::npy_api::get().PyArray_FromAny_(
|
||||||
|
ptr, nullptr, 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
|
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
|
||||||
public:
|
public:
|
||||||
array_t() : array() { }
|
array_t() : array(0, static_cast<const T *>(nullptr)) {}
|
||||||
|
array_t(handle h, borrowed_t) : array(h, borrowed) { }
|
||||||
|
array_t(handle h, stolen_t) : array(h, stolen) { }
|
||||||
|
|
||||||
array_t(handle h, bool is_borrowed) : array(h, is_borrowed) { m_ptr = ensure_(m_ptr); }
|
PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
|
||||||
|
array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen) {
|
||||||
|
if (!m_ptr) PyErr_Clear();
|
||||||
|
if (!is_borrowed) Py_XDECREF(h.ptr());
|
||||||
|
}
|
||||||
|
|
||||||
array_t(const object &o) : array(o) { m_ptr = ensure_(m_ptr); }
|
array_t(const object &o) : array(raw_array_t(o.ptr()), stolen) {
|
||||||
|
if (!m_ptr) throw error_already_set();
|
||||||
|
}
|
||||||
|
|
||||||
explicit array_t(const buffer_info& info) : array(info) { }
|
explicit array_t(const buffer_info& info) : array(info) { }
|
||||||
|
|
||||||
@ -590,17 +608,30 @@ public:
|
|||||||
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
|
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
|
||||||
}
|
}
|
||||||
|
|
||||||
static PyObject *ensure_(PyObject *ptr) {
|
/// Ensure that the argument is a NumPy array of the correct dtype.
|
||||||
if (ptr == nullptr)
|
/// In case of an error, nullptr is returned and the Python error is cleared.
|
||||||
return nullptr;
|
static array_t ensure(handle h) {
|
||||||
auto& api = detail::npy_api::get();
|
auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
|
||||||
PyObject *result = api.PyArray_FromAny_(ptr, pybind11::dtype::of<T>().release().ptr(), 0, 0,
|
|
||||||
detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
|
|
||||||
if (!result)
|
if (!result)
|
||||||
PyErr_Clear();
|
PyErr_Clear();
|
||||||
Py_DECREF(ptr);
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool _check(handle h) {
|
||||||
|
const auto &api = detail::npy_api::get();
|
||||||
|
return api.PyArray_Check_(h.ptr())
|
||||||
|
&& api.PyArray_EquivTypes_(PyArray_GET_(h.ptr(), descr), dtype::of<T>().ptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/// Create array from any object -- always returns a new reference
|
||||||
|
static PyObject *raw_array_t(PyObject *ptr) {
|
||||||
|
if (ptr == nullptr)
|
||||||
|
return nullptr;
|
||||||
|
return detail::npy_api::get().PyArray_FromAny_(
|
||||||
|
ptr, dtype::of<T>().release().ptr(), 0, 0,
|
||||||
|
detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -631,7 +662,7 @@ struct pyobject_caster<array_t<T, ExtraFlags>> {
|
|||||||
using type = array_t<T, ExtraFlags>;
|
using type = array_t<T, ExtraFlags>;
|
||||||
|
|
||||||
bool load(handle src, bool /* convert */) {
|
bool load(handle src, bool /* convert */) {
|
||||||
value = type(src, true);
|
value = type::ensure(src);
|
||||||
return static_cast<bool>(value);
|
return static_cast<bool>(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -126,4 +126,28 @@ test_initializer numpy_array([](py::module &m) {
|
|||||||
);
|
);
|
||||||
|
|
||||||
sm.def("function_taking_uint64", [](uint64_t) { });
|
sm.def("function_taking_uint64", [](uint64_t) { });
|
||||||
|
|
||||||
|
sm.def("isinstance_untyped", [](py::object yes, py::object no) {
|
||||||
|
return py::isinstance<py::array>(yes) && !py::isinstance<py::array>(no);
|
||||||
|
});
|
||||||
|
|
||||||
|
sm.def("isinstance_typed", [](py::object o) {
|
||||||
|
return py::isinstance<py::array_t<double>>(o) && !py::isinstance<py::array_t<int>>(o);
|
||||||
|
});
|
||||||
|
|
||||||
|
sm.def("default_constructors", []() {
|
||||||
|
return py::dict(
|
||||||
|
"array"_a=py::array(),
|
||||||
|
"array_t<int32>"_a=py::array_t<std::int32_t>(),
|
||||||
|
"array_t<double>"_a=py::array_t<double>()
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
sm.def("converting_constructors", [](py::object o) {
|
||||||
|
return py::dict(
|
||||||
|
"array"_a=py::array(o),
|
||||||
|
"array_t<int32>"_a=py::array_t<std::int32_t>(o),
|
||||||
|
"array_t<double>"_a=py::array_t<double>(o)
|
||||||
|
);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@ -245,3 +245,30 @@ def test_cast_numpy_int64_to_uint64():
|
|||||||
from pybind11_tests.array import function_taking_uint64
|
from pybind11_tests.array import function_taking_uint64
|
||||||
function_taking_uint64(123)
|
function_taking_uint64(123)
|
||||||
function_taking_uint64(np.uint64(123))
|
function_taking_uint64(np.uint64(123))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.requires_numpy
|
||||||
|
def test_isinstance():
|
||||||
|
from pybind11_tests.array import isinstance_untyped, isinstance_typed
|
||||||
|
|
||||||
|
assert isinstance_untyped(np.array([1, 2, 3]), "not an array")
|
||||||
|
assert isinstance_typed(np.array([1.0, 2.0, 3.0]))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.requires_numpy
|
||||||
|
def test_constructors():
|
||||||
|
from pybind11_tests.array import default_constructors, converting_constructors
|
||||||
|
|
||||||
|
defaults = default_constructors()
|
||||||
|
for a in defaults.values():
|
||||||
|
assert a.size == 0
|
||||||
|
assert defaults["array"].dtype == np.array([]).dtype
|
||||||
|
assert defaults["array_t<int32>"].dtype == np.int32
|
||||||
|
assert defaults["array_t<double>"].dtype == np.float64
|
||||||
|
|
||||||
|
results = converting_constructors([1, 2, 3])
|
||||||
|
for a in results.values():
|
||||||
|
np.testing.assert_array_equal(a, [1, 2, 3])
|
||||||
|
assert results["array"].dtype == np.int_
|
||||||
|
assert results["array_t<int32>"].dtype == np.int32
|
||||||
|
assert results["array_t<double>"].dtype == np.float64
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user