diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 19965ba8..49489ceb 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -91,9 +91,7 @@ public: template array(size_t size, const Type *ptr) { API& api = lookup_api(); - PyObject *descr = api.PyArray_DescrFromType_(detail::npy_format_descriptor::typenum()); - if (descr == nullptr) - pybind11_fail("NumPy: unsupported buffer format!"); + PyObject *descr = detail::npy_format_descriptor::descr(); Py_intptr_t shape = (Py_intptr_t) size; object tmp = object(api.PyArray_NewFromDescr_( api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false); @@ -128,6 +126,8 @@ protected: static API api = API::lookup(); return api; } + + template friend struct detail::npy_format_descriptor; }; template class array_t : public array { @@ -140,7 +140,7 @@ public: if (ptr == nullptr) return nullptr; API &api = lookup_api(); - PyObject *descr = api.PyArray_DescrFromType_(detail::npy_format_descriptor::typenum()); + PyObject *descr = detail::npy_format_descriptor::descr(); PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0, API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr); if (!result) PyErr_Clear(); @@ -158,6 +158,10 @@ private: array::API::NPY_INT_, array::API::NPY_UINT_, array::API::NPY_LONGLONG_, array::API::NPY_ULONGLONG_ }; public: static int typenum() { return values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned::value ? 1 : 0)]; } + static PyObject* descr() { + if (auto obj = array::lookup_api().PyArray_DescrFromType_(typenum())) return obj; + else pybind11_fail("Unsupported buffer format!"); + } template ::value, int>::type = 0> static PYBIND11_DESCR name() { return _("int") + _(); } template ::value, int>::type = 0> @@ -167,7 +171,11 @@ template constexpr const int npy_format_descriptor< T, typename std::enable_if::value>::type>::values[8]; #define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor { \ - static int typenum() { return array::API::NumPyName; } \ + static int typenum() { return array::API::NumPyName; } \ + static PyObject* descr() { \ + if (auto obj = array::lookup_api().PyArray_DescrFromType_(typenum())) return obj; \ + else pybind11_fail("Unsupported buffer format!"); \ + } \ static PYBIND11_DESCR name() { return _(Name); } } DECL_FMT(float, NPY_FLOAT_, "float32"); DECL_FMT(double, NPY_DOUBLE_, "float64");