array: fix base handling

numpy arrays aren't currently properly setting base: by setting `->base`
directly, the base doesn't follow what numpy expects and documents (that
is, following chained array bases to the root array).

This fixes the behaviour by using numpy's PyArray_SetBaseObject to set
the base instead, and then updates the tests to reflect the fixed
behaviour.
This commit is contained in:
Jason Rhinelander 2017-01-16 20:22:00 -05:00 committed by Wenzel Jakob
parent 88fff9d189
commit f86dddf7ba
2 changed files with 17 additions and 12 deletions

View File

@ -154,6 +154,7 @@ struct npy_api {
int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *, int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
Py_ssize_t *, PyObject **, PyObject *); Py_ssize_t *, PyObject **, PyObject *);
PyObject *(*PyArray_Squeeze_)(PyObject *); PyObject *(*PyArray_Squeeze_)(PyObject *);
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
private: private:
enum functions { enum functions {
API_PyArray_Type = 2, API_PyArray_Type = 2,
@ -168,7 +169,8 @@ private:
API_PyArray_DescrConverter = 174, API_PyArray_DescrConverter = 174,
API_PyArray_EquivTypes = 182, API_PyArray_EquivTypes = 182,
API_PyArray_GetArrayParamsFromObject = 278, API_PyArray_GetArrayParamsFromObject = 278,
API_PyArray_Squeeze = 136 API_PyArray_Squeeze = 136,
API_PyArray_SetBaseObject = 282
}; };
static npy_api lookup() { static npy_api lookup() {
@ -194,6 +196,7 @@ private:
DECL_NPY_API(PyArray_EquivTypes); DECL_NPY_API(PyArray_EquivTypes);
DECL_NPY_API(PyArray_GetArrayParamsFromObject); DECL_NPY_API(PyArray_GetArrayParamsFromObject);
DECL_NPY_API(PyArray_Squeeze); DECL_NPY_API(PyArray_Squeeze);
DECL_NPY_API(PyArray_SetBaseObject);
#undef DECL_NPY_API #undef DECL_NPY_API
return api; return api;
} }
@ -365,7 +368,7 @@ public:
pybind11_fail("NumPy: unable to create array!"); pybind11_fail("NumPy: unable to create array!");
if (ptr) { if (ptr) {
if (base) { if (base) {
detail::array_proxy(tmp.ptr())->base = base.inc_ref().ptr(); api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr());
} else { } else {
tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */)); tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
} }
@ -632,8 +635,8 @@ public:
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize()); return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
} }
/// Ensure that the argument is a NumPy array of the correct dtype. /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
/// In case of an error, nullptr is returned and the Python error is cleared. /// it). In case of an error, nullptr is returned and the Python error is cleared.
static array_t ensure(handle h) { static array_t ensure(handle h) {
auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr())); auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
if (!result) if (!result)

View File

@ -165,7 +165,9 @@ def test_make_c_f_array():
def test_wrap(): def test_wrap():
from pybind11_tests.array import wrap from pybind11_tests.array import wrap
def assert_references(a, b): def assert_references(a, b, base=None):
if base is None:
base = a
assert a is not b assert a is not b
assert a.__array_interface__['data'][0] == b.__array_interface__['data'][0] assert a.__array_interface__['data'][0] == b.__array_interface__['data'][0]
assert a.shape == b.shape assert a.shape == b.shape
@ -177,7 +179,7 @@ def test_wrap():
assert a.flags.updateifcopy == b.flags.updateifcopy assert a.flags.updateifcopy == b.flags.updateifcopy
assert np.all(a == b) assert np.all(a == b)
assert not b.flags.owndata assert not b.flags.owndata
assert b.base is a assert b.base is base
if a.flags.writeable and a.ndim == 2: if a.flags.writeable and a.ndim == 2:
a[0, 0] = 1234 a[0, 0] = 1234
assert b[0, 0] == 1234 assert b[0, 0] == 1234
@ -201,13 +203,13 @@ def test_wrap():
a2 = wrap(a1) a2 = wrap(a1)
assert_references(a1, a2) assert_references(a1, a2)
a1 = a1.transpose() a1t = a1.transpose()
a2 = wrap(a1) a2 = wrap(a1t)
assert_references(a1, a2) assert_references(a1t, a2, a1)
a1 = a1.diagonal() a1d = a1.diagonal()
a2 = wrap(a1) a2 = wrap(a1d)
assert_references(a1, a2) assert_references(a1d, a2, a1)
@pytest.requires_numpy @pytest.requires_numpy