From c6a57c10d10a458db1dfd8430d817017629f8b0a Mon Sep 17 00:00:00 2001 From: Jason Rhinelander Date: Tue, 19 Sep 2017 22:12:46 -0300 Subject: [PATCH] Fix dtype string leak `PyArray_DescrConverter_` doesn't steal a reference to the argument, and so the passed arguments shouldn't be `.release()`d. --- include/pybind11/numpy.h | 2 +- tests/test_numpy_dtypes.cpp | 3 +++ tests/test_numpy_dtypes.py | 13 +++++++++++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 3755e890..6fd8fdf3 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -440,7 +440,7 @@ public: /// This is essentially the same as calling numpy.dtype(args) in Python. static dtype from_args(object args) { PyObject *ptr = nullptr; - if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr) + if (!detail::npy_api::get().PyArray_DescrConverter_(args.ptr(), &ptr) || !ptr) throw error_already_set(); return reinterpret_steal(ptr); } diff --git a/tests/test_numpy_dtypes.cpp b/tests/test_numpy_dtypes.cpp index ddec851f..266b1fe7 100644 --- a/tests/test_numpy_dtypes.cpp +++ b/tests/test_numpy_dtypes.cpp @@ -448,4 +448,7 @@ TEST_SUBMODULE(numpy_dtypes, m) { // test_register_dtype m.def("register_dtype", []() { PYBIND11_NUMPY_DTYPE(SimpleStruct, bool_, uint_, float_, ldbl_); }); + + // test_str_leak + m.def("dtype_wrapper", [](py::object d) { return py::dtype::from_args(std::move(d)); }); } diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py index 5f9a9540..4818ca4a 100644 --- a/tests/test_numpy_dtypes.py +++ b/tests/test_numpy_dtypes.py @@ -293,6 +293,19 @@ def test_register_dtype(): assert 'dtype is already registered' in str(excinfo.value) +@pytest.unsupported_on_pypy +def test_str_leak(): + from sys import getrefcount + fmt = "f4" + pytest.gc_collect() + start = getrefcount(fmt) + d = m.dtype_wrapper(fmt) + assert d is np.dtype("f4") + del d + pytest.gc_collect() + assert getrefcount(fmt) == start + + @pytest.requires_numpy def test_compare_buffer_info(): assert all(m.compare_buffer_info())