NumPy dtypes are now shared across extensions

This commit is contained in:
Ivan Smirnov 2016-10-31 13:52:32 +00:00
parent a743ead455
commit 2184f6d4d6
3 changed files with 54 additions and 27 deletions

View File

@ -323,6 +323,7 @@ struct internals {
std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache; std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
std::unordered_map<std::type_index, std::vector<bool (*)(PyObject *, void *&)>> direct_conversions; std::unordered_map<std::type_index, std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators; std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators;
std::unordered_map<std::string, void *> shared_data;
#if defined(WITH_THREAD) #if defined(WITH_THREAD)
decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x
PyInterpreterState *istate = nullptr; PyInterpreterState *istate = nullptr;

View File

@ -21,6 +21,7 @@
#include <initializer_list> #include <initializer_list>
#include <functional> #include <functional>
#include <utility> #include <utility>
#include <typeindex>
#if defined(_MSC_VER) #if defined(_MSC_VER)
# pragma warning(push) # pragma warning(push)
@ -72,6 +73,39 @@ struct PyVoidScalarObject_Proxy {
PyObject *base; PyObject *base;
}; };
// Per-type record stored in numpy_internals::registered_dtypes for a C++ type
// that has been registered as a NumPy structured dtype.
// NOTE: kept as a plain aggregate; register_dtype fills it positionally with
// `{ dtype_ptr, format_str }`, so the member order matters.
struct numpy_type_info {
// Python object describing the dtype. register_dtype obtains it through
// `release().ptr()`, so the reference appears to be handed over to this record
// and is not released anywhere in the visible code -- TODO confirm ownership.
PyObject* dtype_ptr;
// Buffer format string ("T{...}") that register_dtype builds for the type.
std::string format_str;
};
// Registry of NumPy dtype information, keyed by the C++ type it describes.
struct numpy_internals {
    // Maps std::type_index(typeid(T)) to the dtype record registered for T.
    std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;

    // Looks up the record registered for T.
    //
    // Returns a pointer to the stored record when T is registered. When it is
    // not, calls pybind11_fail with a message naming typeid(T) if
    // `throw_if_missing` is true; otherwise returns nullptr.
    template <typename T>
    numpy_type_info *get_type_info(bool throw_if_missing = true) {
        const auto entry = registered_dtypes.find(std::type_index(typeid(T)));
        if (entry == registered_dtypes.end()) {
            if (throw_if_missing)
                pybind11_fail(std::string("NumPy type info missing for ") + typeid(T).name());
            return nullptr;
        }
        return &(entry->second);
    }
};
// Returns the numpy_internals instance published in the `shared_data` map of
// pybind11's internals under the key "numpy_internals", creating and
// publishing it on first use.
//
// The object is allocated with `new` and never deleted here: the raw pointer
// stays in `shared_data` so that later lookups of the same key find the same
// instance.
inline PYBIND11_NOINLINE numpy_internals* load_numpy_internals() {
    // operator[] performs a single lookup and value-initializes a missing
    // entry to nullptr, so "absent" and "present but null" are handled alike
    // and a null pointer is never returned to the caller.
    void *&slot = detail::get_internals().shared_data["numpy_internals"];
    // Allocate only once the slot exists: if the insertion above throws there
    // is nothing to leak, and if `new` throws the slot stays null so a later
    // call simply retries.
    if (!slot)
        slot = new numpy_internals();
    return static_cast<numpy_internals *>(slot);
}
// Accessor for the numpy_internals registry. load_numpy_internals() is
// invoked only when the function-local static is initialized (the first call);
// subsequent calls return the cached instance.
inline numpy_internals& get_numpy_internals() {
    static numpy_internals &instance = *load_numpy_internals();
    return instance;
}
struct npy_api { struct npy_api {
enum constants { enum constants {
NPY_C_CONTIGUOUS_ = 0x0001, NPY_C_CONTIGUOUS_ = 0x0001,
@ -661,30 +695,29 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
static PYBIND11_DESCR name() { return _("struct"); } static PYBIND11_DESCR name() { return _("struct"); }
static pybind11::dtype dtype() { static pybind11::dtype dtype() {
if (!dtype_ptr) return object(dtype_ptr(), true);
pybind11_fail("NumPy: unsupported buffer format!");
return object(dtype_ptr, true);
} }
static std::string format() { static std::string format() {
if (!dtype_ptr) static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
pybind11_fail("NumPy: unsupported buffer format!");
return format_str; return format_str;
} }
static void register_dtype(std::initializer_list<field_descriptor> fields) { static void register_dtype(std::initializer_list<field_descriptor> fields) {
if (dtype_ptr) auto& numpy_internals = get_numpy_internals();
if (numpy_internals.get_type_info<T>(false))
pybind11_fail("NumPy: dtype is already registered"); pybind11_fail("NumPy: dtype is already registered");
list names, formats, offsets; list names, formats, offsets;
for (auto field : fields) { for (auto field : fields) {
if (!field.descr) if (!field.descr)
pybind11_fail("NumPy: unsupported field dtype"); pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
field.name + "` @ " + typeid(T).name());
names.append(PYBIND11_STR_TYPE(field.name)); names.append(PYBIND11_STR_TYPE(field.name));
formats.append(field.descr); formats.append(field.descr);
offsets.append(pybind11::int_(field.offset)); offsets.append(pybind11::int_(field.offset));
} }
dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr(); auto dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr();
// There is an existing bug in NumPy (as of v1.11): trailing bytes are // There is an existing bug in NumPy (as of v1.11): trailing bytes are
// not encoded explicitly into the format string. This will supposedly // not encoded explicitly into the format string. This will supposedly
@ -695,9 +728,7 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
// strings and will just do it ourselves. // strings and will just do it ourselves.
std::vector<field_descriptor> ordered_fields(fields); std::vector<field_descriptor> ordered_fields(fields);
std::sort(ordered_fields.begin(), ordered_fields.end(), std::sort(ordered_fields.begin(), ordered_fields.end(),
[](const field_descriptor &a, const field_descriptor &b) { [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
return a.offset < b.offset;
});
size_t offset = 0; size_t offset = 0;
std::ostringstream oss; std::ostringstream oss;
oss << "T{"; oss << "T{";
@ -711,44 +742,39 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
if (sizeof(T) > offset) if (sizeof(T) > offset)
oss << (sizeof(T) - offset) << 'x'; oss << (sizeof(T) - offset) << 'x';
oss << '}'; oss << '}';
format_str = oss.str(); auto format_str = oss.str();
// Sanity check: verify that NumPy properly parses our buffer format string // Sanity check: verify that NumPy properly parses our buffer format string
auto& api = npy_api::get(); auto& api = npy_api::get();
auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1)); auto arr = array(buffer_info(nullptr, sizeof(T), format_str, 1));
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
pybind11_fail("NumPy: invalid buffer descriptor!"); pybind11_fail("NumPy: invalid buffer descriptor!");
register_direct_converter(); auto tindex = std::type_index(typeid(T));
numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str };
get_internals().direct_conversions[tindex].push_back(direct_converter);
} }
private: private:
static std::string format_str; static PyObject* dtype_ptr() {
static PyObject* dtype_ptr; static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
return ptr;
}
static bool direct_converter(PyObject *obj, void*& value) { static bool direct_converter(PyObject *obj, void*& value) {
auto& api = npy_api::get(); auto& api = npy_api::get();
if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_)) if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
return false; return false;
if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) { if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) {
if (api.PyArray_EquivTypes_(dtype_ptr, descr.ptr())) { if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
value = ((PyVoidScalarObject_Proxy *) obj)->obval; value = ((PyVoidScalarObject_Proxy *) obj)->obval;
return true; return true;
} }
} }
return false; return false;
} }
static void register_direct_converter() {
get_internals().direct_conversions[std::type_index(typeid(T))].push_back(direct_converter);
}
}; };
template <typename T>
std::string npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>>::format_str;
template <typename T>
PyObject* npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>>::dtype_ptr = nullptr;
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \ #define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
::pybind11::detail::field_descriptor { \ ::pybind11::detail::field_descriptor { \
Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \ Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \

View File

@ -18,7 +18,7 @@ def test_format_descriptors():
with pytest.raises(RuntimeError) as excinfo: with pytest.raises(RuntimeError) as excinfo:
get_format_unbound() get_format_unbound()
assert 'unsupported buffer format' in str(excinfo.value) assert re.match('^NumPy type info missing for .*UnboundStruct.*$', str(excinfo.value))
assert print_format_descriptors() == [ assert print_format_descriptors() == [
"T{=?:x:3x=I:y:=f:z:}", "T{=?:x:3x=I:y:=f:z:}",