mirror of
https://github.com/RYDE-WORK/pybind11.git
synced 2026-02-04 14:33:21 +08:00
Move register_dtype() outside of the template
(avoid code bloat if possible)
This commit is contained in:
parent
f95fda0eb2
commit
cc8ff16547
@ -81,14 +81,18 @@ struct numpy_type_info {
|
|||||||
struct numpy_internals {
|
struct numpy_internals {
|
||||||
std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
|
std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
|
||||||
|
|
||||||
template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
|
numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
|
||||||
auto it = registered_dtypes.find(std::type_index(typeid(T)));
|
auto it = registered_dtypes.find(std::type_index(tinfo));
|
||||||
if (it != registered_dtypes.end())
|
if (it != registered_dtypes.end())
|
||||||
return &(it->second);
|
return &(it->second);
|
||||||
if (throw_if_missing)
|
if (throw_if_missing)
|
||||||
pybind11_fail(std::string("NumPy type info missing for ") + typeid(T).name());
|
pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
|
||||||
|
return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
|
inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
|
||||||
@ -686,34 +690,25 @@ struct field_descriptor {
|
|||||||
dtype descr;
|
dtype descr;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
inline PYBIND11_NOINLINE void register_structured_dtype(
|
||||||
struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
|
const std::initializer_list<field_descriptor>& fields,
|
||||||
static PYBIND11_DESCR name() { return _("struct"); }
|
const std::type_info& tinfo, size_t itemsize,
|
||||||
|
bool (*direct_converter)(PyObject *, void *&))
|
||||||
static pybind11::dtype dtype() {
|
{
|
||||||
return object(dtype_ptr(), true);
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::string format() {
|
|
||||||
static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
|
|
||||||
return format_str;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void register_dtype(std::initializer_list<field_descriptor> fields) {
|
|
||||||
auto& numpy_internals = get_numpy_internals();
|
auto& numpy_internals = get_numpy_internals();
|
||||||
if (numpy_internals.get_type_info<T>(false))
|
if (numpy_internals.get_type_info(tinfo, false))
|
||||||
pybind11_fail("NumPy: dtype is already registered");
|
pybind11_fail("NumPy: dtype is already registered");
|
||||||
|
|
||||||
list names, formats, offsets;
|
list names, formats, offsets;
|
||||||
for (auto field : fields) {
|
for (auto field : fields) {
|
||||||
if (!field.descr)
|
if (!field.descr)
|
||||||
pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
|
pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
|
||||||
field.name + "` @ " + typeid(T).name());
|
field.name + "` @ " + tinfo.name());
|
||||||
names.append(PYBIND11_STR_TYPE(field.name));
|
names.append(PYBIND11_STR_TYPE(field.name));
|
||||||
formats.append(field.descr);
|
formats.append(field.descr);
|
||||||
offsets.append(pybind11::int_(field.offset));
|
offsets.append(pybind11::int_(field.offset));
|
||||||
}
|
}
|
||||||
auto dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr();
|
auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr();
|
||||||
|
|
||||||
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
|
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
|
||||||
// not encoded explicitly into the format string. This will supposedly
|
// not encoded explicitly into the format string. This will supposedly
|
||||||
@ -735,22 +730,40 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
|
|||||||
oss << '=' << field.format << ':' << field.name << ':';
|
oss << '=' << field.format << ':' << field.name << ':';
|
||||||
offset = field.offset + field.size;
|
offset = field.offset + field.size;
|
||||||
}
|
}
|
||||||
if (sizeof(T) > offset)
|
if (itemsize > offset)
|
||||||
oss << (sizeof(T) - offset) << 'x';
|
oss << (itemsize - offset) << 'x';
|
||||||
oss << '}';
|
oss << '}';
|
||||||
auto format_str = oss.str();
|
auto format_str = oss.str();
|
||||||
|
|
||||||
// Sanity check: verify that NumPy properly parses our buffer format string
|
// Sanity check: verify that NumPy properly parses our buffer format string
|
||||||
auto& api = npy_api::get();
|
auto& api = npy_api::get();
|
||||||
auto arr = array(buffer_info(nullptr, sizeof(T), format_str, 1));
|
auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
|
||||||
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
|
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
|
||||||
pybind11_fail("NumPy: invalid buffer descriptor!");
|
pybind11_fail("NumPy: invalid buffer descriptor!");
|
||||||
|
|
||||||
auto tindex = std::type_index(typeid(T));
|
auto tindex = std::type_index(tinfo);
|
||||||
numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str };
|
numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str };
|
||||||
get_internals().direct_conversions[tindex].push_back(direct_converter);
|
get_internals().direct_conversions[tindex].push_back(direct_converter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
|
||||||
|
static PYBIND11_DESCR name() { return _("struct"); }
|
||||||
|
|
||||||
|
static pybind11::dtype dtype() {
|
||||||
|
return object(dtype_ptr(), true);
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string format() {
|
||||||
|
static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
|
||||||
|
return format_str;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
|
||||||
|
register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
|
||||||
|
sizeof(T), &direct_converter);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static PyObject* dtype_ptr() {
|
static PyObject* dtype_ptr() {
|
||||||
static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
|
static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user