@@ -2546,82 +2546,189 @@ toggle_reftrace_printer(PyObject *ob, PyObject *arg)
25462546 Py_RETURN_NONE ;
25472547}
25482548
2549- static PyObject *
2550- test_interp_refcount (PyObject * self , PyObject * unused )
2549+ static PyInterpreterRef
2550+ get_strong_ref (void )
2551+ {
2552+ PyInterpreterRef ref ;
2553+ if (PyInterpreterRef_Get (& ref ) < 0 ) {
2554+ Py_FatalError ("strong reference should not have failed" );
2555+ }
2556+ return ref ;
2557+ }
2558+
2559+ static void
2560+ test_interp_ref_common (void )
25512561{
25522562 PyInterpreterState * interp = PyInterpreterState_Get ();
2553- PyInterpreterRef ref1 ;
2554- PyInterpreterRef ref2 ;
2555-
2556- // Reference counts are technically 0 by default
2557- assert (_PyInterpreterState_Refcount (interp ) == 0 );
2558- ref1 = PyInterpreterRef_Get ();
2559- assert (_PyInterpreterState_Refcount (interp ) == 1 );
2560- ref2 = PyInterpreterRef_Get ();
2561- assert (_PyInterpreterState_Refcount (interp ) == 2 );
2562- PyInterpreterRef_Close (ref1 );
2563- assert (_PyInterpreterState_Refcount (interp ) == 1 );
2564- PyInterpreterRef_Close (ref2 );
2565- assert (_PyInterpreterState_Refcount (interp ) == 0 );
2566-
2567- ref1 = PyInterpreterRef_Get ();
2568- ref2 = PyInterpreterRef_Dup (ref1 );
2569- assert (_PyInterpreterState_Refcount (interp ) == 2 );
2570- assert (PyInterpreterRef_AsInterpreter (ref1 ) == interp );
2571- assert (PyInterpreterRef_AsInterpreter (ref2 ) == interp );
2572- PyInterpreterRef_Close (ref1 );
2573- PyInterpreterRef_Close (ref2 );
2574- assert (_PyInterpreterState_Refcount (interp ) == 0 );
2563+ PyInterpreterRef ref = get_strong_ref ();
2564+ assert (PyInterpreterRef_AsInterpreter (ref ) == interp );
25752565
2576- Py_RETURN_NONE ;
2566+ PyInterpreterRef ref_2 = PyInterpreterRef_Dup (ref );
2567+ assert (PyInterpreterRef_AsInterpreter (ref_2 ) == interp );
2568+
2569+ // We can close the references in any order
2570+ PyInterpreterRef_Close (ref );
2571+ PyInterpreterRef_Close (ref_2 );
25772572}
25782573
25792574static PyObject *
2580- test_interp_weak_ref (PyObject * self , PyObject * unused )
2575+ test_interpreter_refs (PyObject * self , PyObject * unused )
25812576{
2582- PyInterpreterState * interp = PyInterpreterState_Get ();
2583- PyInterpreterWeakRef wref = PyInterpreterWeakRef_Get ();
2584- assert (_PyInterpreterState_Refcount (interp ) == 0 );
2577+ // Test the main interpreter
2578+ test_interp_ref_common ();
25852579
2586- PyInterpreterRef ref ;
2587- int res = PyInterpreterWeakRef_AsStrong (wref , & ref );
2588- assert (res == 0 );
2589- assert (PyInterpreterRef_AsInterpreter (ref ) == interp );
2590- assert (_PyInterpreterState_Refcount (interp ) == 1 );
2591- PyInterpreterWeakRef_Close (wref );
2592- PyInterpreterRef_Close (ref );
2580+ // Test a (legacy) subinterpreter
2581+ PyThreadState * save_tstate = PyThreadState_Swap (NULL );
2582+ PyThreadState * interp_tstate = Py_NewInterpreter ();
2583+ test_interp_ref_common ();
2584+ Py_EndInterpreter (interp_tstate );
2585+
2586+ // Test an isolated subinterpreter
2587+ PyInterpreterConfig config = {
2588+ .gil = PyInterpreterConfig_OWN_GIL ,
2589+ .check_multi_interp_extensions = 1
2590+ };
25932591
2592+ PyThreadState * isolated_interp_tstate ;
2593+ PyStatus status = Py_NewInterpreterFromConfig (& isolated_interp_tstate , & config );
2594+ if (PyStatus_Exception (status )) {
2595+ PyErr_SetString (PyExc_RuntimeError , "interpreter creation failed" );
2596+ return NULL ;
2597+ }
2598+
2599+ test_interp_ref_common ();
2600+ Py_EndInterpreter (isolated_interp_tstate );
2601+ PyThreadState_Swap (save_tstate );
25942602 Py_RETURN_NONE ;
25952603}
25962604
25972605static PyObject *
2598- test_interp_ensure (PyObject * self , PyObject * unused )
2606+ test_thread_state_ensure_nested (PyObject * self , PyObject * unused )
25992607{
2600- PyInterpreterState * interp = PyInterpreterState_Get ();
2601- PyInterpreterRef ref = PyInterpreterRef_Get ();
2608+ PyInterpreterRef ref = get_strong_ref ();
26022609 PyThreadState * save_tstate = PyThreadState_Swap (NULL );
2603- PyThreadState * tstate = Py_NewInterpreter ();
2604- PyInterpreterRef sub_ref = PyInterpreterRef_Get ();
2605- PyInterpreterState * subinterp = PyThreadState_GetInterpreter (tstate );
2610+ assert (PyGILState_GetThisThreadState () == save_tstate );
26062611
26072612 for (int i = 0 ; i < 10 ; ++ i ) {
2608- int res = PyThreadState_Ensure (ref );
2609- assert (res == 0 );
2610- assert (PyInterpreterState_Get () == interp );
2613+ // Test reactivation of the detached tstate.
2614+ if (PyThreadState_Ensure (ref ) < 0 ) {
2615+ PyInterpreterRef_Close (ref );
2616+ return PyErr_NoMemory ();
2617+ }
2618+
2619+ // No new thread state should've been created.
2620+ assert (PyThreadState_Get () == save_tstate );
2621+ PyThreadState_Release ();
26112622 }
26122623
2624+ assert (PyThreadState_GetUnchecked () == NULL );
2625+
2626+ // Similarly, test ensuring with deep nesting and *then* releasing.
2627+ // If the (detached) gilstate matches the interpreter, then it shouldn't
2628+ // create a new thread state.
26132629 for (int i = 0 ; i < 10 ; ++ i ) {
2614- int res = PyThreadState_Ensure (sub_ref );
2615- assert (res == 0 );
2616- assert (PyInterpreterState_Get () == subinterp );
2630+ if (PyThreadState_Ensure (ref ) < 0 ) {
2631+ // This will technically leak other thread states, but it doesn't
2632+ // matter because this is a test.
2633+ PyInterpreterRef_Close (ref );
2634+ return PyErr_NoMemory ();
2635+ }
2636+
2637+ assert (PyThreadState_Get () == save_tstate );
26172638 }
26182639
2619- for (int i = 0 ; i < 20 ; ++ i ) {
2640+ for (int i = 0 ; i < 10 ; ++ i ) {
2641+ assert (PyThreadState_Get () == save_tstate );
26202642 PyThreadState_Release ();
26212643 }
26222644
2645+ assert (PyThreadState_GetUnchecked () == NULL );
2646+ PyInterpreterRef_Close (ref );
2647+ PyThreadState_Swap (save_tstate );
2648+ Py_RETURN_NONE ;
2649+ }
2650+
2651+ static PyObject *
2652+ test_thread_state_ensure_crossinterp (PyObject * self , PyObject * unused )
2653+ {
2654+ PyInterpreterRef ref = get_strong_ref ();
2655+ PyThreadState * save_tstate = PyThreadState_Swap (NULL );
2656+ PyThreadState * interp_tstate = Py_NewInterpreter ();
2657+ if (interp_tstate == NULL ) {
2658+ PyInterpreterRef_Close (ref );
2659+ return PyErr_NoMemory ();
2660+ }
2661+
2662+ /* This should create a new thread state for the calling interpreter, *not*
2663+ reactivate the old one. In a real-world scenario, this would arise in
2664+ something like this:
2665+
2666+ def some_func():
2667+ import something
2668+ # This re-enters the main interpreter, but we
2669+ # shouldn't have access to prior thread-locals.
2670+ something.call_something()
2671+
2672+ interp = interpreters.create()
2673+ interp.exec(some_func)
2674+ */
2675+ if (PyThreadState_Ensure (ref ) < 0 ) {
2676+ PyInterpreterRef_Close (ref );
2677+ return PyErr_NoMemory ();
2678+ }
2679+
2680+ PyThreadState * ensured_tstate = PyThreadState_Get ();
2681+ assert (ensured_tstate != save_tstate );
2682+ assert (PyInterpreterState_Get () == PyInterpreterRef_AsInterpreter (ref ));
2683+ assert (PyGILState_GetThisThreadState () == ensured_tstate );
2684+
2685+ // Now though, we should reactivate the thread state
2686+ if (PyThreadState_Ensure (ref ) < 0 ) {
2687+ PyInterpreterRef_Close (ref );
2688+ return PyErr_NoMemory ();
2689+ }
2690+
2691+ assert (PyThreadState_Get () == ensured_tstate );
2692+ PyThreadState_Release ();
2693+
2694+ // Ensure that we're restoring the prior thread state
2695+ PyThreadState_Release ();
2696+ assert (PyThreadState_Get () == interp_tstate );
2697+ assert (PyGILState_GetThisThreadState () == interp_tstate );
2698+
2699+ PyThreadState_Swap (interp_tstate );
2700+ Py_EndInterpreter (interp_tstate );
2701+
26232702 PyInterpreterRef_Close (ref );
2624- PyInterpreterRef_Close (sub_ref );
2703+ PyThreadState_Swap (save_tstate );
2704+ Py_RETURN_NONE ;
2705+ }
2706+
2707+ static PyObject *
2708+ test_weak_interpreter_ref_after_shutdown (PyObject * self , PyObject * unused )
2709+ {
2710+ PyThreadState * save_tstate = PyThreadState_Swap (NULL );
2711+ PyInterpreterWeakRef wref ;
2712+ PyThreadState * interp_tstate = Py_NewInterpreter ();
2713+ if (interp_tstate == NULL ) {
2714+ return PyErr_NoMemory ();
2715+ }
2716+
2717+ int res = PyInterpreterWeakRef_Get (& wref );
2718+ (void )res ;
2719+ assert (res == 0 );
2720+
2721+ // As a sanity check, ensure that the weakref actually works
2722+ PyInterpreterRef ref ;
2723+ res = PyInterpreterWeakRef_AsStrong (wref , & ref );
2724+ assert (res == 0 );
2725+ PyInterpreterRef_Close (ref );
2726+
2727+ // Now, destroy the interpreter and try to acquire a weak reference.
2728+ // It should fail.
2729+ Py_EndInterpreter (interp_tstate );
2730+ res = PyInterpreterWeakRef_AsStrong (wref , & ref );
2731+ assert (res == -1 );
26252732
26262733 PyThreadState_Swap (save_tstate );
26272734 Py_RETURN_NONE ;
@@ -2721,9 +2828,10 @@ static PyMethodDef TestMethods[] = {
27212828 {"test_atexit" , test_atexit , METH_NOARGS },
27222829 {"code_offset_to_line" , _PyCFunction_CAST (code_offset_to_line ), METH_FASTCALL },
27232830 {"toggle_reftrace_printer" , toggle_reftrace_printer , METH_O },
2724- {"test_interp_refcount" , test_interp_refcount , METH_NOARGS },
2725- {"test_interp_weak_ref" , test_interp_weak_ref , METH_NOARGS },
2726- {"test_interp_ensure" , test_interp_ensure , METH_NOARGS },
2831+ {"test_interpreter_refs" , test_interpreter_refs , METH_NOARGS },
2832+ {"test_thread_state_ensure_nested" , test_thread_state_ensure_nested , METH_NOARGS },
2833+ {"test_thread_state_ensure_crossinterp" , test_thread_state_ensure_crossinterp , METH_NOARGS },
2834+ {"test_weak_interpreter_ref_after_shutdown" , test_weak_interpreter_ref_after_shutdown , METH_NOARGS },
27272835 {NULL , NULL } /* sentinel */
27282836};
27292837
0 commit comments