From f86dddf7ba45810c97dc2cf8e8cb6ee1e8477307 Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Mon, 16 Jan 2017 20:22:00 -0500 Subject: [PATCH] array: fix base handling numpy arrays aren't currently properly setting base: by setting `->base` directly, the base doesn't follow what numpy expects and documents (that is, following chained array bases to the root array). This fixes the behaviour by using numpy's PyArray_SetBaseObject to set the base instead, and then updates the tests to reflect the fixed behaviour. --- include/pybind11/numpy.h | 11 +++++++---- tests/test_numpy_array.py | 18 ++++++++++-------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 7a79aa88..f2c4d9ce 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -154,6 +154,7 @@ struct npy_api { int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *, Py_ssize_t *, PyObject **, PyObject *); PyObject *(*PyArray_Squeeze_)(PyObject *); + int (*PyArray_SetBaseObject_)(PyObject *, PyObject *); private: enum functions { API_PyArray_Type = 2, @@ -168,7 +169,8 @@ private: API_PyArray_DescrConverter = 174, API_PyArray_EquivTypes = 182, API_PyArray_GetArrayParamsFromObject = 278, - API_PyArray_Squeeze = 136 + API_PyArray_Squeeze = 136, + API_PyArray_SetBaseObject = 282 }; static npy_api lookup() { @@ -194,6 +196,7 @@ private: DECL_NPY_API(PyArray_EquivTypes); DECL_NPY_API(PyArray_GetArrayParamsFromObject); DECL_NPY_API(PyArray_Squeeze); + DECL_NPY_API(PyArray_SetBaseObject); #undef DECL_NPY_API return api; } @@ -365,7 +368,7 @@ public: pybind11_fail("NumPy: unable to create array!"); if (ptr) { if (base) { - detail::array_proxy(tmp.ptr())->base = base.inc_ref().ptr(); + api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr()); } else { tmp = reinterpret_steal(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */)); } @@ -632,8 +635,8 @@ public: return *(static_cast(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize()); } - /// Ensure that the argument is a NumPy array of the correct dtype. - /// In case of an error, nullptr is returned and the Python error is cleared. + /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert + /// it). In case of an error, nullptr is returned and the Python error is cleared. static array_t ensure(handle h) { auto result = reinterpret_steal(raw_array_t(h.ptr())); if (!result) diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index 25defe7d..0094d761 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -165,7 +165,9 @@ def test_make_c_f_array(): def test_wrap(): from pybind11_tests.array import wrap - def assert_references(a, b): + def assert_references(a, b, base=None): + if base is None: + base = a assert a is not b assert a.__array_interface__['data'][0] == b.__array_interface__['data'][0] assert a.shape == b.shape @@ -177,7 +179,7 @@ def test_wrap(): assert a.flags.updateifcopy == b.flags.updateifcopy assert np.all(a == b) assert not b.flags.owndata - assert b.base is a + assert b.base is base if a.flags.writeable and a.ndim == 2: a[0, 0] = 1234 assert b[0, 0] == 1234 @@ -201,13 +203,13 @@ def test_wrap(): a2 = wrap(a1) assert_references(a1, a2) - a1 = a1.transpose() - a2 = wrap(a1) - assert_references(a1, a2) + a1t = a1.transpose() + a2 = wrap(a1t) + assert_references(a1t, a2, a1) - a1 = a1.diagonal() - a2 = wrap(a1) - assert_references(a1, a2) + a1d = a1.diagonal() + a2 = wrap(a1d) + assert_references(a1d, a2, a1) @pytest.requires_numpy