From d04a9b1d9e6a0aa3601972f1291d181377759d14 Mon Sep 17 00:00:00 2001
From: Borja Zarco <bzarco@google.com>
Date: Fri, 8 Dec 2017 17:35:59 -0500
Subject: [PATCH] Use `PyGILState_GetThisThreadState` when using
 gil_scoped_acquire.

This avoids GIL deadlocking when pybind11 tries to acquire the GIL in a thread that already acquired it using standard Python API (e.g. when running from a Python thread).
---
 include/pybind11/pybind11.h |  9 +++++
 tests/CMakeLists.txt        |  1 +
 tests/test_gil_scoped.cpp   | 43 ++++++++++++++++++++
 tests/test_gil_scoped.py    | 80 +++++++++++++++++++++++++++++++++++++
 4 files changed, 133 insertions(+)
 create mode 100644 tests/test_gil_scoped.cpp
 create mode 100644 tests/test_gil_scoped.py

diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h
index e986d007a..5f5f5c82a 100644
--- include/pybind11/pybind11.h
+++ include/pybind11/pybind11.h
@@ -1749,6 +1749,15 @@ class gil_scoped_acquire {
         auto const &internals = detail::get_internals();
         tstate = (PyThreadState *) PYBIND11_TLS_GET_VALUE(internals.tstate);
 
+        if (!tstate) {
+            /* Check if the GIL was acquired using the PyGILState_* API instead (e.g. if
+               calling from a Python thread). Since we use a different key, this ensures
+               we don't create a new thread state and deadlock in PyEval_AcquireThread
+               below. Note we don't save this state with internals.tstate, since we don't
+               create it we would fail to clear it (its reference count should be > 0). */
+            tstate = PyGILState_GetThisThreadState();
+        }
+
         if (!tstate) {
             tstate = PyThreadState_New(internals.istate);
             #if !defined(NDEBUG)
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index b5e0b526e..a31d5b8b7 100644
--- tests/CMakeLists.txt
+++ tests/CMakeLists.txt
@@ -40,6 +40,7 @@ set(PYBIND11_TEST_FILES
   test_eval.cpp
   test_exceptions.cpp
   test_factory_constructors.cpp
+  test_gil_scoped.cpp
   test_iostream.cpp
   test_kwargs_and_defaults.cpp
   test_local_bindings.cpp
diff --git a/tests/test_gil_scoped.cpp b/tests/test_gil_scoped.cpp
new file mode 100644
index 000000000..a94b7a2e2
--- /dev/null
+++ tests/test_gil_scoped.cpp
@@ -0,0 +1,43 @@
+/*
+    tests/test_gil_scoped.cpp -- acquire and release gil
+
+    Copyright (c) 2017 Borja Zarco (Google LLC) <bzarco@google.com>
+
+    All rights reserved. Use of this source code is governed by a
+    BSD-style license that can be found in the LICENSE file.
+*/
+
+#include "pybind11_tests.h"
+#include <pybind11/functional.h>
+
+
+class VirtClass  {
+public:
+    virtual void virtual_func() {}
+    virtual void pure_virtual_func() = 0;
+};
+
+class PyVirtClass : public VirtClass {
+    void virtual_func() override {
+        PYBIND11_OVERLOAD(void, VirtClass, virtual_func,);
+    }
+    void pure_virtual_func() override {
+        PYBIND11_OVERLOAD_PURE(void, VirtClass, pure_virtual_func,);
+    }
+};
+
+TEST_SUBMODULE(gil_scoped, m) {
+  py::class_<VirtClass, PyVirtClass>(m, "VirtClass")
+      .def(py::init<>())
+      .def("virtual_func", &VirtClass::virtual_func)
+      .def("pure_virtual_func", &VirtClass::pure_virtual_func);
+
+    m.def("test_callback_py_obj",
+          [](py::object func) { func(); });
+    m.def("test_callback_std_func",
+          [](const std::function<void()> &func) { func(); });
+    m.def("test_callback_virtual_func",
+          [](VirtClass &virt) { virt.virtual_func(); });
+    m.def("test_callback_pure_virtual_func",
+          [](VirtClass &virt) { virt.pure_virtual_func(); });
+}
diff --git a/tests/test_gil_scoped.py b/tests/test_gil_scoped.py
new file mode 100644
index 000000000..5e702431f
--- /dev/null
+++ tests/test_gil_scoped.py
@@ -0,0 +1,80 @@
+import multiprocessing
+import threading
+from pybind11_tests import gil_scoped as m
+
+
+def _run_in_process(target, *args, **kwargs):
+    """Runs target in process and returns its exitcode after 1s (None if still alive)."""
+    process = multiprocessing.Process(target=target, args=args, kwargs=kwargs)
+    process.daemon = True
+    try:
+        process.start()
+        # Do not need to wait much, 1s should be more than enough.
+        process.join(timeout=1)
+        return process.exitcode
+    finally:
+        if process.is_alive():
+            process.terminate()
+
+
+def _python_to_cpp_to_python():
+    """Calls different C++ functions that come back to Python."""
+    class ExtendedVirtClass(m.VirtClass):
+        def virtual_func(self):
+            pass
+
+        def pure_virtual_func(self):
+            pass
+
+    extended = ExtendedVirtClass()
+    m.test_callback_py_obj(lambda: None)
+    m.test_callback_std_func(lambda: None)
+    m.test_callback_virtual_func(extended)
+    m.test_callback_pure_virtual_func(extended)
+
+
+def _python_to_cpp_to_python_from_threads(num_threads, parallel=False):
+    """Calls different C++ functions that come back to Python, from Python threads."""
+    threads = []
+    for _ in range(num_threads):
+        thread = threading.Thread(target=_python_to_cpp_to_python)
+        thread.daemon = True
+        thread.start()
+        if parallel:
+            threads.append(thread)
+        else:
+            thread.join()
+    for thread in threads:
+        thread.join()
+
+
+def test_python_to_cpp_to_python_from_thread():
+    """Makes sure there is no GIL deadlock when running in a thread.
+
+    It runs in a separate process to be able to stop and assert if it deadlocks.
+    """
+    assert _run_in_process(_python_to_cpp_to_python_from_threads, 1) == 0
+
+
+def test_python_to_cpp_to_python_from_thread_multiple_parallel():
+    """Makes sure there is no GIL deadlock when running in a thread multiple times in parallel.
+
+    It runs in a separate process to be able to stop and assert if it deadlocks.
+    """
+    assert _run_in_process(_python_to_cpp_to_python_from_threads, 8, parallel=True) == 0
+
+
+def test_python_to_cpp_to_python_from_thread_multiple_sequential():
+    """Makes sure there is no GIL deadlock when running in a thread multiple times sequentially.
+
+    It runs in a separate process to be able to stop and assert if it deadlocks.
+    """
+    assert _run_in_process(_python_to_cpp_to_python_from_threads, 8, parallel=False) == 0
+
+
+def test_python_to_cpp_to_python_from_process():
+    """Makes sure there is no GIL deadlock when using processes.
+
+    This test is for completion, but it was never an issue.
+    """
+    assert _run_in_process(_python_to_cpp_to_python) == 0