diff --git a/include/pybind11/embed.h b/include/pybind11/embed.h index 6e777830..7b5d7cd2 100644 --- a/include/pybind11/embed.h +++ b/include/pybind11/embed.h @@ -12,6 +12,9 @@ #include "pybind11.h" #include "eval.h" +#include +#include + #if defined(PYPY_VERSION) # error Embedding the interpreter is not supported with PyPy #endif @@ -83,29 +86,106 @@ struct embedded_module { } }; +struct wide_char_arg_deleter { + void operator()(wchar_t *ptr) const { +#if PY_VERSION_HEX >= 0x030500f0 + // API docs: https://docs.python.org/3/c-api/sys.html#c.Py_DecodeLocale + PyMem_RawFree(ptr); +#else + delete[] ptr; +#endif + } +}; + +inline wchar_t *widen_chars(const char *safe_arg) { +#if PY_VERSION_HEX >= 0x030500f0 + wchar_t *widened_arg = Py_DecodeLocale(safe_arg, nullptr); +#else + wchar_t *widened_arg = nullptr; +# if defined(HAVE_BROKEN_MBSTOWCS) && HAVE_BROKEN_MBSTOWCS + size_t count = strlen(safe_arg); +# else + size_t count = mbstowcs(nullptr, safe_arg, 0); +# endif + if (count != static_cast(-1)) { + widened_arg = new wchar_t[count + 1]; + mbstowcs(widened_arg, safe_arg, count + 1); + } +#endif + return widened_arg; +} + +/// Python 2.x/3.x-compatible version of `PySys_SetArgv` +inline void set_interpreter_argv(int argc, const char *const *argv, bool add_program_dir_to_path) { + // Before it was special-cased in python 3.8, passing an empty or null argv + // caused a segfault, so we have to reimplement the special case ourselves. + bool special_case = (argv == nullptr || argc <= 0); + + const char *const empty_argv[]{"\0"}; + const char *const *safe_argv = special_case ? empty_argv : argv; + if (special_case) + argc = 1; + + auto argv_size = static_cast(argc); +#if PY_MAJOR_VERSION >= 3 + // SetArgv* on python 3 takes wchar_t, so we have to convert. + std::unique_ptr widened_argv(new wchar_t *[argv_size]); + std::vector> widened_argv_entries; + widened_argv_entries.reserve(argv_size); + for (size_t ii = 0; ii < argv_size; ++ii) { + widened_argv_entries.emplace_back(widen_chars(safe_argv[ii])); + if (!widened_argv_entries.back()) { + // A null here indicates a character-encoding failure or the python + // interpreter out of memory. Give up. + return; + } + widened_argv[ii] = widened_argv_entries.back().get(); + } + + auto pysys_argv = widened_argv.get(); +#else + // python 2.x + std::vector strings{safe_argv, safe_argv + argv_size}; + std::vector char_strings{argv_size}; + for (std::size_t i = 0; i < argv_size; ++i) + char_strings[i] = &strings[i][0]; + char **pysys_argv = char_strings.data(); +#endif + + PySys_SetArgvEx(argc, pysys_argv, static_cast(add_program_dir_to_path)); +} + PYBIND11_NAMESPACE_END(detail) /** \rst Initialize the Python interpreter. No other pybind11 or CPython API functions can be called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The - optional parameter can be used to skip the registration of signal handlers (see the - `Python documentation`_ for details). Calling this function again after the interpreter - has already been initialized is a fatal error. + optional `init_signal_handlers` parameter can be used to skip the registration of + signal handlers (see the `Python documentation`_ for details). Calling this function + again after the interpreter has already been initialized is a fatal error. If initializing the Python interpreter fails, then the program is terminated. (This is controlled by the CPython runtime and is an exception to pybind11's normal behavior of throwing exceptions on errors.) + The remaining optional parameters, `argc`, `argv`, and `add_program_dir_to_path` are + used to populate ``sys.argv`` and ``sys.path``. + See the |PySys_SetArgvEx documentation|_ for details. + .. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx + .. |PySys_SetArgvEx documentation| replace:: ``PySys_SetArgvEx`` documentation + .. _PySys_SetArgvEx documentation: https://docs.python.org/3/c-api/init.html#c.PySys_SetArgvEx \endrst */ -inline void initialize_interpreter(bool init_signal_handlers = true) { +inline void initialize_interpreter(bool init_signal_handlers = true, + int argc = 0, + const char *const *argv = nullptr, + bool add_program_dir_to_path = true) { if (Py_IsInitialized() != 0) pybind11_fail("The interpreter is already running"); Py_InitializeEx(init_signal_handlers ? 1 : 0); - // Make .py files in the working directory available by default - module_::import("sys").attr("path").cast().append("."); + detail::set_interpreter_argv(argc, argv, add_program_dir_to_path); } /** \rst @@ -167,6 +247,8 @@ inline void finalize_interpreter() { Scope guard version of `initialize_interpreter` and `finalize_interpreter`. This a move-only guard and only a single instance can exist. + See `initialize_interpreter` for a discussion of its constructor arguments. + .. code-block:: cpp #include @@ -178,8 +260,11 @@ inline void finalize_interpreter() { \endrst */ class scoped_interpreter { public: - scoped_interpreter(bool init_signal_handlers = true) { - initialize_interpreter(init_signal_handlers); + scoped_interpreter(bool init_signal_handlers = true, + int argc = 0, + const char *const *argv = nullptr, + bool add_program_dir_to_path = true) { + initialize_interpreter(init_signal_handlers, argc, argv, add_program_dir_to_path); } scoped_interpreter(const scoped_interpreter &) = delete; diff --git a/tests/test_embed/test_interpreter.cpp b/tests/test_embed/test_interpreter.cpp index b40ff481..78b64be6 100644 --- a/tests/test_embed/test_interpreter.cpp +++ b/tests/test_embed/test_interpreter.cpp @@ -23,6 +23,7 @@ public: std::string the_message() const { return message; } virtual int the_answer() const = 0; + virtual std::string argv0() const = 0; private: std::string message; @@ -32,6 +33,7 @@ class PyWidget final : public Widget { using Widget::Widget; int the_answer() const override { PYBIND11_OVERRIDE_PURE(int, Widget, the_answer); } + std::string argv0() const override { PYBIND11_OVERRIDE_PURE(std::string, Widget, argv0); } }; PYBIND11_EMBEDDED_MODULE(widget_module, m) { @@ -299,3 +301,25 @@ TEST_CASE("Reload module from file") { result = module_.attr("test")().cast(); REQUIRE(result == 2); } + +TEST_CASE("sys.argv gets initialized properly") { + py::finalize_interpreter(); + { + py::scoped_interpreter default_scope; + auto module = py::module::import("test_interpreter"); + auto py_widget = module.attr("DerivedWidget")("The question"); + const auto &cpp_widget = py_widget.cast(); + REQUIRE(cpp_widget.argv0().empty()); + } + + { + char *argv[] = {strdup("a.out")}; + py::scoped_interpreter argv_scope(true, 1, argv); + free(argv[0]); + auto module = py::module::import("test_interpreter"); + auto py_widget = module.attr("DerivedWidget")("The question"); + const auto &cpp_widget = py_widget.cast(); + REQUIRE(cpp_widget.argv0() == "a.out"); + } + py::initialize_interpreter(); +} diff --git a/tests/test_embed/test_interpreter.py b/tests/test_embed/test_interpreter.py index 6174ede4..5ab55a4b 100644 --- a/tests/test_embed/test_interpreter.py +++ b/tests/test_embed/test_interpreter.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +import sys + from widget_module import Widget @@ -8,3 +10,6 @@ class DerivedWidget(Widget): def the_answer(self): return 42 + + def argv0(self): + return sys.argv[0]