Make npy_format_descriptor backwards-compat

The typenum for non-structured types is still accessible at ::value,
and the dtype object for all types is accessible at ::dtype().
This commit is contained in:
Ivan Smirnov 2016-06-26 16:19:18 +01:00
parent 95e9b12322
commit 40eadfeb73

View File

@ -96,7 +96,7 @@ public:
template <typename Type> array(size_t size, const Type *ptr) { template <typename Type> array(size_t size, const Type *ptr) {
API& api = lookup_api(); API& api = lookup_api();
PyObject *descr = detail::npy_format_descriptor<Type>::descr().release().ptr(); PyObject *descr = detail::npy_format_descriptor<Type>::dtype().release().ptr();
Py_intptr_t shape = (Py_intptr_t) size; Py_intptr_t shape = (Py_intptr_t) size;
object tmp = object(api.PyArray_NewFromDescr_( object tmp = object(api.PyArray_NewFromDescr_(
api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false); api.PyArray_Type_, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr), false);
@ -147,7 +147,7 @@ public:
if (ptr == nullptr) if (ptr == nullptr)
return nullptr; return nullptr;
API &api = lookup_api(); API &api = lookup_api();
PyObject *descr = detail::npy_format_descriptor<T>::descr().release().ptr(); PyObject *descr = detail::npy_format_descriptor<T>::dtype().release().ptr();
PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0, API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr); PyObject *result = api.PyArray_FromAny_(ptr, descr, 0, 0, API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
if (!result) if (!result)
PyErr_Clear(); PyErr_Clear();
@ -171,7 +171,7 @@ template <typename T> struct format_descriptor
template <typename T> template <typename T>
object dtype_of() { object dtype_of() {
return detail::npy_format_descriptor<T>::descr(); return detail::npy_format_descriptor<T>::dtype();
} }
NAMESPACE_BEGIN(detail) NAMESPACE_BEGIN(detail)
@ -182,10 +182,11 @@ private:
array::API::NPY_BYTE_, array::API::NPY_UBYTE_, array::API::NPY_SHORT_, array::API::NPY_USHORT_, array::API::NPY_BYTE_, array::API::NPY_UBYTE_, array::API::NPY_SHORT_, array::API::NPY_USHORT_,
array::API::NPY_INT_, array::API::NPY_UINT_, array::API::NPY_LONGLONG_, array::API::NPY_ULONGLONG_ }; array::API::NPY_INT_, array::API::NPY_UINT_, array::API::NPY_LONGLONG_, array::API::NPY_ULONGLONG_ };
public: public:
static int typenum() { return values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)]; } enum { value = values[detail::log2(sizeof(T)) * 2 + (std::is_unsigned<T>::value ? 1 : 0)] };
static object descr() { static object dtype() {
if (auto ptr = array::lookup_api().PyArray_DescrFromType_(typenum())) return object(ptr, true); if (auto ptr = array::lookup_api().PyArray_DescrFromType_(value))
else pybind11_fail("Unsupported buffer format!"); return object(ptr, true);
pybind11_fail("Unsupported buffer format!");
} }
template <typename T2 = T, typename std::enable_if<std::is_signed<T2>::value, int>::type = 0> template <typename T2 = T, typename std::enable_if<std::is_signed<T2>::value, int>::type = 0>
static PYBIND11_DESCR name() { return _("int") + _<sizeof(T)*8>(); } static PYBIND11_DESCR name() { return _("int") + _<sizeof(T)*8>(); }
@ -196,10 +197,11 @@ template <typename T> constexpr const int npy_format_descriptor<
T, typename std::enable_if<std::is_integral<T>::value>::type>::values[8]; T, typename std::enable_if<std::is_integral<T>::value>::type>::values[8];
#define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \ #define DECL_FMT(Type, NumPyName, Name) template<> struct npy_format_descriptor<Type> { \
static int typenum() { return array::API::NumPyName; } \ enum { value = array::API::NumPyName }; \
static object descr() { \ static object dtype() { \
if (auto ptr = array::lookup_api().PyArray_DescrFromType_(typenum())) return object(ptr, true); \ if (auto ptr = array::lookup_api().PyArray_DescrFromType_(value)) \
else pybind11_fail("Unsupported buffer format!"); \ return object(ptr, true); \
pybind11_fail("Unsupported buffer format!"); \
} \ } \
static PYBIND11_DESCR name() { return _(Name); } } static PYBIND11_DESCR name() { return _(Name); } }
DECL_FMT(float, NPY_FLOAT_, "float32"); DECL_FMT(float, NPY_FLOAT_, "float32");
@ -225,10 +227,10 @@ template <typename T> struct npy_format_descriptor
{ {
static PYBIND11_DESCR name() { return _("user-defined"); } static PYBIND11_DESCR name() { return _("user-defined"); }
static object descr() { static object dtype() {
if (!descr_()) if (!dtype_())
pybind11_fail("NumPy: unsupported buffer format!"); pybind11_fail("NumPy: unsupported buffer format!");
return object(descr_(), true); return object(dtype_(), true);
} }
static const char* format() { static const char* format() {
@ -249,11 +251,11 @@ template <typename T> struct npy_format_descriptor
args["names"] = names; args["names"] = names;
args["offsets"] = offsets; args["offsets"] = offsets;
args["formats"] = formats; args["formats"] = formats;
if (!api.PyArray_DescrConverter_(args.release().ptr(), &descr_()) || !descr_()) if (!api.PyArray_DescrConverter_(args.release().ptr(), &dtype_()) || !dtype_())
pybind11_fail("NumPy: failed to create structured dtype"); pybind11_fail("NumPy: failed to create structured dtype");
auto np = module::import("numpy"); auto np = module::import("numpy");
auto empty = (object) np.attr("empty"); auto empty = (object) np.attr("empty");
if (auto arr = (object) empty(int_(0), object(descr(), true))) if (auto arr = (object) empty(int_(0), dtype()))
if (auto view = PyMemoryView_FromObject(arr.ptr())) if (auto view = PyMemoryView_FromObject(arr.ptr()))
if (auto info = PyMemoryView_GET_BUFFER(view)) { if (auto info = PyMemoryView_GET_BUFFER(view)) {
std::strncpy(format_(), info->format, 4096); std::strncpy(format_(), info->format, 4096);
@ -263,14 +265,14 @@ template <typename T> struct npy_format_descriptor
} }
private: private:
static inline PyObject*& descr_() { static PyObject *ptr = nullptr; return ptr; } static inline PyObject*& dtype_() { static PyObject *ptr = nullptr; return ptr; }
static inline char* format_() { static char s[4096]; return s; } static inline char* format_() { static char s[4096]; return s; }
}; };
#define PB11_IMPL_FIELD_DESCRIPTOR(Type, Field) \ #define PB11_IMPL_FIELD_DESCRIPTOR(Type, Field) \
::pybind11::detail::field_descriptor { \ ::pybind11::detail::field_descriptor { \
#Field, offsetof(Type, Field), \ #Field, offsetof(Type, Field), \
::pybind11::detail::npy_format_descriptor<decltype(static_cast<Type*>(0)->Field)>::descr() \ ::pybind11::detail::npy_format_descriptor<decltype(static_cast<Type*>(0)->Field)>::dtype() \
} }
// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro // The main idea of this macro is borrowed from https://github.com/swansontec/map-macro