diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 8e8abe3d..19965ba8 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -13,6 +13,7 @@ #include "complex.h" #include #include +#include #if defined(_MSC_VER) #pragma warning(push) @@ -31,6 +32,7 @@ public: API_PyArray_FromAny = 69, API_PyArray_NewCopy = 85, API_PyArray_NewFromDescr = 94, + API_PyArray_GetArrayParamsFromObject = 278, NPY_C_CONTIGUOUS_ = 0x0001, NPY_F_CONTIGUOUS_ = 0x0002, @@ -61,6 +63,7 @@ public: DECL_NPY_API(PyArray_FromAny); DECL_NPY_API(PyArray_NewCopy); DECL_NPY_API(PyArray_NewFromDescr); + DECL_NPY_API(PyArray_GetArrayParamsFromObject); #undef DECL_NPY_API return api; } @@ -74,6 +77,8 @@ public: PyObject *(*PyArray_NewCopy_)(PyObject *, int); PyTypeObject *PyArray_Type_; PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *); + int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *, + Py_ssize_t *, PyObject **, PyObject *); }; PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check_) @@ -100,24 +105,22 @@ public: } array(const buffer_info &info) { - API& api = lookup_api(); - if ((info.format.size() < 1) || (info.format.size() > 2)) - pybind11_fail("Unsupported buffer format!"); - int fmt = (int) info.format[0]; - if (info.format == "Zd") fmt = API::NPY_CDOUBLE_; - else if (info.format == "Zf") fmt = API::NPY_CFLOAT_; + PyObject *arr = nullptr, *descr = nullptr; + int ndim = 0; + Py_ssize_t dims[32]; - PyObject *descr = api.PyArray_DescrFromType_(fmt); - if (descr == nullptr) - pybind11_fail("NumPy: unsupported buffer format '" + info.format + "'!"); - object tmp(api.PyArray_NewFromDescr_( - api.PyArray_Type_, descr, (int) info.ndim, (Py_intptr_t *) &info.shape[0], - (Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false); - if (info.ptr && tmp) - tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false); - if (!tmp) - pybind11_fail("NumPy: unable to create array!"); - m_ptr = tmp.release().ptr(); + // allocate zeroed memory if it hasn't been provided + auto buf_info = info; + if (!buf_info.ptr) + buf_info.ptr = std::calloc(info.size, info.itemsize); + auto view = py::memoryview(buf_info); + + API& api = lookup_api(); + auto res = api.PyArray_GetArrayParamsFromObject_(view.ptr(), nullptr, 1, &descr, + &ndim, dims, &arr, nullptr); + if (res < 0 || !arr || descr) + pybind11_fail("NumPy: unable to convert buffer to an array"); + m_ptr = arr; } protected: