mirror of
https://github.com/RYDE-WORK/pybind11.git
synced 2026-02-05 06:53:36 +08:00
array: fix base handling
numpy arrays aren't currently properly setting base: by setting `->base` directly, the base doesn't follow what numpy expects and documents (that is, following chained array bases to the root array). This fixes the behaviour by using numpy's PyArray_SetBaseObject to set the base instead, and then updates the tests to reflect the fixed behaviour.
This commit is contained in:
parent
88fff9d189
commit
f86dddf7ba
@ -154,6 +154,7 @@ struct npy_api {
|
|||||||
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
|
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
|
||||||
Py_ssize_t *, PyObject **, PyObject *);
|
Py_ssize_t *, PyObject **, PyObject *);
|
||||||
PyObject *(*PyArray_Squeeze_)(PyObject *);
|
PyObject *(*PyArray_Squeeze_)(PyObject *);
|
||||||
|
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
|
||||||
private:
|
private:
|
||||||
enum functions {
|
enum functions {
|
||||||
API_PyArray_Type = 2,
|
API_PyArray_Type = 2,
|
||||||
@ -168,7 +169,8 @@ private:
|
|||||||
API_PyArray_DescrConverter = 174,
|
API_PyArray_DescrConverter = 174,
|
||||||
API_PyArray_EquivTypes = 182,
|
API_PyArray_EquivTypes = 182,
|
||||||
API_PyArray_GetArrayParamsFromObject = 278,
|
API_PyArray_GetArrayParamsFromObject = 278,
|
||||||
API_PyArray_Squeeze = 136
|
API_PyArray_Squeeze = 136,
|
||||||
|
API_PyArray_SetBaseObject = 282
|
||||||
};
|
};
|
||||||
|
|
||||||
static npy_api lookup() {
|
static npy_api lookup() {
|
||||||
@ -194,6 +196,7 @@ private:
|
|||||||
DECL_NPY_API(PyArray_EquivTypes);
|
DECL_NPY_API(PyArray_EquivTypes);
|
||||||
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
|
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
|
||||||
DECL_NPY_API(PyArray_Squeeze);
|
DECL_NPY_API(PyArray_Squeeze);
|
||||||
|
DECL_NPY_API(PyArray_SetBaseObject);
|
||||||
#undef DECL_NPY_API
|
#undef DECL_NPY_API
|
||||||
return api;
|
return api;
|
||||||
}
|
}
|
||||||
@ -365,7 +368,7 @@ public:
|
|||||||
pybind11_fail("NumPy: unable to create array!");
|
pybind11_fail("NumPy: unable to create array!");
|
||||||
if (ptr) {
|
if (ptr) {
|
||||||
if (base) {
|
if (base) {
|
||||||
detail::array_proxy(tmp.ptr())->base = base.inc_ref().ptr();
|
api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr());
|
||||||
} else {
|
} else {
|
||||||
tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
|
tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
|
||||||
}
|
}
|
||||||
@ -632,8 +635,8 @@ public:
|
|||||||
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
|
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Ensure that the argument is a NumPy array of the correct dtype.
|
/// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
|
||||||
/// In case of an error, nullptr is returned and the Python error is cleared.
|
/// it). In case of an error, nullptr is returned and the Python error is cleared.
|
||||||
static array_t ensure(handle h) {
|
static array_t ensure(handle h) {
|
||||||
auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
|
auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
|
||||||
if (!result)
|
if (!result)
|
||||||
|
|||||||
@ -165,7 +165,9 @@ def test_make_c_f_array():
|
|||||||
def test_wrap():
|
def test_wrap():
|
||||||
from pybind11_tests.array import wrap
|
from pybind11_tests.array import wrap
|
||||||
|
|
||||||
def assert_references(a, b):
|
def assert_references(a, b, base=None):
|
||||||
|
if base is None:
|
||||||
|
base = a
|
||||||
assert a is not b
|
assert a is not b
|
||||||
assert a.__array_interface__['data'][0] == b.__array_interface__['data'][0]
|
assert a.__array_interface__['data'][0] == b.__array_interface__['data'][0]
|
||||||
assert a.shape == b.shape
|
assert a.shape == b.shape
|
||||||
@ -177,7 +179,7 @@ def test_wrap():
|
|||||||
assert a.flags.updateifcopy == b.flags.updateifcopy
|
assert a.flags.updateifcopy == b.flags.updateifcopy
|
||||||
assert np.all(a == b)
|
assert np.all(a == b)
|
||||||
assert not b.flags.owndata
|
assert not b.flags.owndata
|
||||||
assert b.base is a
|
assert b.base is base
|
||||||
if a.flags.writeable and a.ndim == 2:
|
if a.flags.writeable and a.ndim == 2:
|
||||||
a[0, 0] = 1234
|
a[0, 0] = 1234
|
||||||
assert b[0, 0] == 1234
|
assert b[0, 0] == 1234
|
||||||
@ -201,13 +203,13 @@ def test_wrap():
|
|||||||
a2 = wrap(a1)
|
a2 = wrap(a1)
|
||||||
assert_references(a1, a2)
|
assert_references(a1, a2)
|
||||||
|
|
||||||
a1 = a1.transpose()
|
a1t = a1.transpose()
|
||||||
a2 = wrap(a1)
|
a2 = wrap(a1t)
|
||||||
assert_references(a1, a2)
|
assert_references(a1t, a2, a1)
|
||||||
|
|
||||||
a1 = a1.diagonal()
|
a1d = a1.diagonal()
|
||||||
a2 = wrap(a1)
|
a2 = wrap(a1d)
|
||||||
assert_references(a1, a2)
|
assert_references(a1d, a2, a1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.requires_numpy
|
@pytest.requires_numpy
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user