mirror of
https://github.com/RYDE-WORK/pybind11.git
synced 2026-02-04 14:33:21 +08:00
Use memoryview for constructing array from buffer
This commit is contained in:
parent
ea2755ccdc
commit
a67c2b52e4
@ -13,6 +13,7 @@
|
|||||||
#include "complex.h"
|
#include "complex.h"
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
#if defined(_MSC_VER)
|
#if defined(_MSC_VER)
|
||||||
#pragma warning(push)
|
#pragma warning(push)
|
||||||
@ -31,6 +32,7 @@ public:
|
|||||||
API_PyArray_FromAny = 69,
|
API_PyArray_FromAny = 69,
|
||||||
API_PyArray_NewCopy = 85,
|
API_PyArray_NewCopy = 85,
|
||||||
API_PyArray_NewFromDescr = 94,
|
API_PyArray_NewFromDescr = 94,
|
||||||
|
API_PyArray_GetArrayParamsFromObject = 278,
|
||||||
|
|
||||||
NPY_C_CONTIGUOUS_ = 0x0001,
|
NPY_C_CONTIGUOUS_ = 0x0001,
|
||||||
NPY_F_CONTIGUOUS_ = 0x0002,
|
NPY_F_CONTIGUOUS_ = 0x0002,
|
||||||
@ -61,6 +63,7 @@ public:
|
|||||||
DECL_NPY_API(PyArray_FromAny);
|
DECL_NPY_API(PyArray_FromAny);
|
||||||
DECL_NPY_API(PyArray_NewCopy);
|
DECL_NPY_API(PyArray_NewCopy);
|
||||||
DECL_NPY_API(PyArray_NewFromDescr);
|
DECL_NPY_API(PyArray_NewFromDescr);
|
||||||
|
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
|
||||||
#undef DECL_NPY_API
|
#undef DECL_NPY_API
|
||||||
return api;
|
return api;
|
||||||
}
|
}
|
||||||
@ -74,6 +77,8 @@ public:
|
|||||||
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
|
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
|
||||||
PyTypeObject *PyArray_Type_;
|
PyTypeObject *PyArray_Type_;
|
||||||
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
|
PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
|
||||||
|
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
|
||||||
|
Py_ssize_t *, PyObject **, PyObject *);
|
||||||
};
|
};
|
||||||
|
|
||||||
PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check_)
|
PYBIND11_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check_)
|
||||||
@ -100,24 +105,22 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
array(const buffer_info &info) {
|
array(const buffer_info &info) {
|
||||||
API& api = lookup_api();
|
PyObject *arr = nullptr, *descr = nullptr;
|
||||||
if ((info.format.size() < 1) || (info.format.size() > 2))
|
int ndim = 0;
|
||||||
pybind11_fail("Unsupported buffer format!");
|
Py_ssize_t dims[32];
|
||||||
int fmt = (int) info.format[0];
|
|
||||||
if (info.format == "Zd") fmt = API::NPY_CDOUBLE_;
|
|
||||||
else if (info.format == "Zf") fmt = API::NPY_CFLOAT_;
|
|
||||||
|
|
||||||
PyObject *descr = api.PyArray_DescrFromType_(fmt);
|
// allocate zeroed memory if it hasn't been provided
|
||||||
if (descr == nullptr)
|
auto buf_info = info;
|
||||||
pybind11_fail("NumPy: unsupported buffer format '" + info.format + "'!");
|
if (!buf_info.ptr)
|
||||||
object tmp(api.PyArray_NewFromDescr_(
|
buf_info.ptr = std::calloc(info.size, info.itemsize);
|
||||||
api.PyArray_Type_, descr, (int) info.ndim, (Py_intptr_t *) &info.shape[0],
|
auto view = py::memoryview(buf_info);
|
||||||
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr), false);
|
|
||||||
if (info.ptr && tmp)
|
API& api = lookup_api();
|
||||||
tmp = object(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */), false);
|
auto res = api.PyArray_GetArrayParamsFromObject_(view.ptr(), nullptr, 1, &descr,
|
||||||
if (!tmp)
|
&ndim, dims, &arr, nullptr);
|
||||||
pybind11_fail("NumPy: unable to create array!");
|
if (res < 0 || !arr || descr)
|
||||||
m_ptr = tmp.release().ptr();
|
pybind11_fail("NumPy: unable to convert buffer to an array");
|
||||||
|
m_ptr = arr;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user