Use memoryview for constructing array from buffer

This commit is contained in:
Ivan Smirnov 2016-06-19 14:50:06 +01:00
parent ea2755ccdc
commit a67c2b52e4

View File

@ -13,6 +13,7 @@
#include "complex.h" #include "complex.h"
#include <numeric> #include <numeric>
#include <algorithm> #include <algorithm>
#include <cstdlib>
#if defined(_MSC_VER) #if defined(_MSC_VER)
#pragma warning(push) #pragma warning(push)
@ -31,6 +32,7 @@ public:
API_PyArray_FromAny = 69, API_PyArray_FromAny = 69,
API_PyArray_NewCopy = 85, API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94, API_PyArray_NewFromDescr = 94,
API_PyArray_GetArrayParamsFromObject = 278,
NPY_C_CONTIGUOUS_ = 0x0001, NPY_C_CONTIGUOUS_ = 0x0001,
NPY_F_CONTIGUOUS_ = 0x0002, NPY_F_CONTIGUOUS_ = 0x0002,
@ -61,6 +63,7 @@ public:
DECL_NPY_API(PyArray_FromAny); DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_NewCopy); DECL_NPY_API(PyArray_NewCopy);
DECL_NPY_API(PyArray_NewFromDescr); DECL_NPY_API(PyArray_NewFromDescr);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
#undef DECL_NPY_API #undef DECL_NPY_API
return api; return api;
} }
@ -74,6 +77,8 @@ public:
PyObject *(*PyArray_NewCopy_)(PyObject *, int); PyObject *(*PyArray_NewCopy_)(PyObject *, int);
PyTypeObject *PyArray_Type_; PyTypeObject *PyArray_Type_;
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *); PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
Py_ssize_t *, PyObject **, PyObject *);
}; };
PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check_) PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check_)
@ -100,24 +105,22 @@ public:
} }
array(const buffer_info &info) { array(const buffer_info &info) {
API& api = lookup_api(); PyObject *arr = nullptr, *descr = nullptr;
if ((info.format.size() < 1) || (info.format.size() > 2)) int ndim = 0;
pybind11_fail("Unsupported buffer format!"); Py_ssize_t dims[32];
int fmt = (int) info.format[0];
if (info.format == "Zd") fmt = API::NPY_CDOUBLE_;
else if (info.format == "Zf") fmt = API::NPY_CFLOAT_;
PyObject *descr = api.PyArray_DescrFromType_(fmt); // allocate zeroed memory if it hasn't been provided
if (descr == nullptr) auto buf_info = info;
pybind11_fail("NumPy: unsupported buffer format '" + info.format + "'!"); if (!buf_info.ptr)
object tmp(api.PyArray_NewFromDescr_( buf_info.ptr = std::calloc(info.size, info.itemsize);
api.PyArray_Type_, descr, (int) info.ndim, (Py_intptr_t *) &info.shape[0], auto view = py::memoryview(buf_info);
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
if (info.ptr && tmp) API& api = lookup_api();
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false); auto res = api.PyArray_GetArrayParamsFromObject_(view.ptr(), nullptr, 1, &descr,
if (!tmp) &ndim, dims, &arr, nullptr);
pybind11_fail("NumPy: unable to create array!"); if (res < 0 || !arr || descr)
m_ptr = tmp.release().ptr(); pybind11_fail("NumPy: unable to convert buffer to an array");
m_ptr = arr;
} }
protected: protected: