From ba7678016cf340a3ee733ca3fe941667448eab67 Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Fri, 7 Oct 2016 11:19:25 +0200 Subject: [PATCH] numpy.h: added array::squeeze() method --- include/pybind11/numpy.h | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 996bb7c6..1125fd71 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -109,6 +109,7 @@ struct npy_api { bool (*PyArray_EquivTypes_) (PyObject *, PyObject *); int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *, Py_ssize_t *, PyObject **, PyObject *); + PyObject *(*PyArray_Squeeze_)(PyObject *); private: enum functions { API_PyArray_Type = 2, @@ -121,6 +122,7 @@ private: API_PyArray_DescrConverter = 174, API_PyArray_EquivTypes = 182, API_PyArray_GetArrayParamsFromObject = 278, + API_PyArray_Squeeze = 136 }; static npy_api lookup() { @@ -143,6 +145,7 @@ private: DECL_NPY_API(PyArray_DescrConverter); DECL_NPY_API(PyArray_EquivTypes); DECL_NPY_API(PyArray_GetArrayParamsFromObject); + DECL_NPY_API(PyArray_Squeeze); #undef DECL_NPY_API return api; } @@ -380,6 +383,12 @@ public: return offset_at(index...) / itemsize(); } + /// Return a new view with all of the dimensions of length 1 removed + array squeeze() { + auto& api = detail::npy_api::get(); + return array(api.PyArray_Squeeze_(m_ptr), false); + } + protected: template friend struct detail::npy_format_descriptor; @@ -601,7 +610,7 @@ struct npy_format_descriptor::value>> { // strings and will just do it ourselves. std::vector ordered_fields(fields); std::sort(ordered_fields.begin(), ordered_fields.end(), - [](const field_descriptor& a, const field_descriptor &b) { + [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; }); size_t offset = 0;