diff --git a/include/pybind11/detail/class.h b/include/pybind11/detail/class.h index b1916fcd..ffdfefe7 100644 --- a/include/pybind11/detail/class.h +++ b/include/pybind11/detail/class.h @@ -586,6 +586,9 @@ inline PyObject* make_new_python_type(const type_record &rec) { type->tp_as_number = &heap_type->as_number; type->tp_as_sequence = &heap_type->as_sequence; type->tp_as_mapping = &heap_type->as_mapping; +#if PY_VERSION_HEX >= 0x03050000 + type->tp_as_async = &heap_type->as_async; +#endif /* Flags */ type->tp_flags |= Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index fb6776f2..765c47ad 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -26,6 +26,7 @@ endif() # Full set of test files (you can override these; see below) set(PYBIND11_TEST_FILES + test_async.cpp test_buffers.cpp test_builtin_casters.cpp test_call_policies.cpp @@ -71,6 +72,13 @@ if (PYBIND11_TEST_OVERRIDE) set(PYBIND11_TEST_FILES ${PYBIND11_TEST_OVERRIDE}) endif() +# Skip test_async for Python < 3.5 +list(FIND PYBIND11_TEST_FILES test_async.cpp PYBIND11_TEST_FILES_ASYNC_I) +if((PYBIND11_TEST_FILES_ASYNC_I GREATER -1) AND ("${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR}" VERSION_LESS 3.5)) + message(STATUS "Skipping test_async because Python version ${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR} < 3.5") + list(REMOVE_AT PYBIND11_TEST_FILES ${PYBIND11_TEST_FILES_ASYNC_I}) +endif() + string(REPLACE ".cpp" ".py" PYBIND11_PYTEST_FILES "${PYBIND11_TEST_FILES}") # Contains the set of test files that require pybind11_cross_module_tests to be diff --git a/tests/conftest.py b/tests/conftest.py index 55d9d0d5..57f681c6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,11 @@ _unicode_marker = re.compile(r'u(\'[^\']*\')') _long_marker = re.compile(r'([0-9])L') _hexadecimal = re.compile(r'0x[0-9a-fA-F]+') +# test_async.py requires support for async and await +collect_ignore = [] +if sys.version_info[:2] < (3, 5): + collect_ignore.append("test_async.py") + def _strip_and_dedent(s): """For triple-quote strings""" diff --git a/tests/test_async.cpp b/tests/test_async.cpp new file mode 100644 index 00000000..f0ad0d53 --- /dev/null +++ b/tests/test_async.cpp @@ -0,0 +1,26 @@ +/* + tests/test_async.cpp -- __await__ support + + Copyright (c) 2019 Google Inc. + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#include "pybind11_tests.h" + +TEST_SUBMODULE(async_module, m) { + struct DoesNotSupportAsync {}; + py::class_(m, "DoesNotSupportAsync") + .def(py::init<>()); + struct SupportsAsync {}; + py::class_(m, "SupportsAsync") + .def(py::init<>()) + .def("__await__", [](const SupportsAsync& self) -> py::object { + static_cast(self); + py::object loop = py::module::import("asyncio.events").attr("get_event_loop")(); + py::object f = loop.attr("create_future")(); + f.attr("set_result")(5); + return f.attr("__await__")(); + }); +} diff --git a/tests/test_async.py b/tests/test_async.py new file mode 100644 index 00000000..e1c959d6 --- /dev/null +++ b/tests/test_async.py @@ -0,0 +1,23 @@ +import asyncio +import pytest +from pybind11_tests import async_module as m + + +@pytest.fixture +def event_loop(): + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +async def get_await_result(x): + return await x + + +def test_await(event_loop): + assert 5 == event_loop.run_until_complete(get_await_result(m.SupportsAsync())) + + +def test_await_missing(event_loop): + with pytest.raises(TypeError): + event_loop.run_until_complete(get_await_result(m.DoesNotSupportAsync()))