From ba7a0fac739554a4e24d2c4bcb836922a50314a0 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 14 Apr 2022 09:53:16 -0500 Subject: [PATCH] Expand dtype accessors (#3868) * Added constructor based on typenum, based on PyArray_DescrFromType Added accessors for typenum, alignment, byteorder and flags fields of PyArray_Descr struct. * Added tests for new py::dtype constructor, and for accessors * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed the comment for alignment method * Update include/pybind11/numpy.h Co-authored-by: Aaron Gokaslan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Aaron Gokaslan --- include/pybind11/numpy.h | 24 ++++++++++++++++++++++++ tests/test_numpy_dtypes.cpp | 29 +++++++++++++++++++++++++++++ tests/test_numpy_dtypes.py | 8 +++++++- 3 files changed, 60 insertions(+), 1 deletion(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index d45fe428..d71f0890 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -562,6 +562,13 @@ public: m_ptr = from_args(std::move(args)).release().ptr(); } + explicit dtype(int typenum) + : object(detail::npy_api::get().PyArray_DescrFromType_(typenum), stolen_t{}) { + if (m_ptr == nullptr) { + throw error_already_set(); + } + } + /// This is essentially the same as calling numpy.dtype(args) in Python. static dtype from_args(object args) { PyObject *ptr = nullptr; @@ -596,6 +603,23 @@ public: return detail::array_descriptor_proxy(m_ptr)->type; } + /// type number of dtype. + ssize_t num() const { + // Note: The signature, `dtype::num` follows the naming of NumPy's public + // Python API (i.e., ``dtype.num``), rather than its internal + // C API (``PyArray_Descr::type_num``). + return detail::array_descriptor_proxy(m_ptr)->type_num; + } + + /// Single character for byteorder + char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; } + + /// Alignment of the data type + int alignment() const { return detail::array_descriptor_proxy(m_ptr)->alignment; } + + /// Flags for the array descriptor + char flags() const { return detail::array_descriptor_proxy(m_ptr)->flags; } + private: static object _dtype_from_pep3118() { static PyObject *obj = module_::import("numpy.core._internal") diff --git a/tests/test_numpy_dtypes.cpp b/tests/test_numpy_dtypes.cpp index dd5b123d..7de36f2f 100644 --- a/tests/test_numpy_dtypes.cpp +++ b/tests/test_numpy_dtypes.cpp @@ -291,6 +291,7 @@ py::list test_dtype_ctors() { list.append(py::dtype(names, formats, offsets, 20)); list.append(py::dtype(py::buffer_info((void *) 0, sizeof(unsigned int), "I", 1))); list.append(py::dtype(py::buffer_info((void *) 0, 0, "T{i:a:f:b:}", 1))); + list.append(py::dtype(py::detail::npy_api::NPY_DOUBLE_)); return list; } @@ -440,6 +441,34 @@ TEST_SUBMODULE(numpy_dtypes, m) { } return list; }); + m.def("test_dtype_num", [dtype_names]() { + py::list list; + for (const auto &dt_name : dtype_names) { + list.append(py::dtype(dt_name).num()); + } + return list; + }); + m.def("test_dtype_byteorder", [dtype_names]() { + py::list list; + for (const auto &dt_name : dtype_names) { + list.append(py::dtype(dt_name).byteorder()); + } + return list; + }); + m.def("test_dtype_alignment", [dtype_names]() { + py::list list; + for (const auto &dt_name : dtype_names) { + list.append(py::dtype(dt_name).alignment()); + } + return list; + }); + m.def("test_dtype_flags", [dtype_names]() { + py::list list; + for (const auto &dt_name : dtype_names) { + list.append(py::dtype(dt_name).flags()); + } + return list; + }); m.def("test_dtype_methods", []() { py::list list; auto dt1 = py::dtype::of(); diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py index 7df60583..fcfd587b 100644 --- a/tests/test_numpy_dtypes.py +++ b/tests/test_numpy_dtypes.py @@ -160,6 +160,7 @@ def test_dtype(simple_dtype): d1, np.dtype("uint32"), d2, + np.dtype("d"), ] assert m.test_dtype_methods() == [ @@ -175,8 +176,13 @@ def test_dtype(simple_dtype): np.zeros(1, m.trailing_padding_dtype()) ) + expected_chars = "bhilqBHILQefdgFDG?MmO" assert m.test_dtype_kind() == list("iiiiiuuuuuffffcccbMmO") - assert m.test_dtype_char_() == list("bhilqBHILQefdgFDG?MmO") + assert m.test_dtype_char_() == list(expected_chars) + assert m.test_dtype_num() == [np.dtype(ch).num for ch in expected_chars] + assert m.test_dtype_byteorder() == [np.dtype(ch).byteorder for ch in expected_chars] + assert m.test_dtype_alignment() == [np.dtype(ch).alignment for ch in expected_chars] + assert m.test_dtype_flags() == [chr(np.dtype(ch).flags) for ch in expected_chars] def test_recarray(simple_dtype, packed_dtype):