Skip to content

Commit 6313b3d

Browse files
committed
Add iter_locked
1 parent 4f2f780 commit 6313b3d

5 files changed

Lines changed: 260 additions & 4 deletions

File tree

Doc/library/threading.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,20 @@ This module defines the following functions:
144144
of the result, even when terminated.
145145

146146

147+
.. function:: locked_iter(iterable)
148+
149+
Make an iterator thread-safe.
150+
151+
Roughly equivalent to::
152+
153+
class locked_iter(Iterator):
154+
def __init__(self, it):
155+
self._it = iter(it)
156+
self._lock = Lock()
157+
def __next__(self):
158+
with self._lock:
159+
return next(self._it)
160+
147161
.. function:: main_thread()
148162

149163
Return the main :class:`Thread` object. In normal conditions, the
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import unittest
2+
from threading import Thread, Barrier, iter_locked
3+
from test.support import threading_helper
4+
5+
6+
threading_helper.requires_working_threading(module=True)
7+
8+
class non_atomic_iterator:
9+
10+
def __init__(self, it):
11+
self.it = iter(it)
12+
13+
def __iter__(self):
14+
return self
15+
16+
def __next__(self):
17+
a = next(self.it)
18+
b = next(self.it)
19+
return a, b
20+
21+
def count():
22+
i = 0
23+
while True:
24+
i += 1
25+
yield i
26+
27+
class iter_lockedThreading(unittest.TestCase):
28+
29+
@threading_helper.reap_threads
30+
def test_iter_locked(self):
31+
number_of_threads = 10
32+
number_of_iterations = 10
33+
barrier = Barrier(number_of_threads)
34+
def work(it):
35+
while True:
36+
try:
37+
a, b = next(it)
38+
assert a + 1 == b
39+
except StopIteration:
40+
break
41+
42+
data = tuple(range(400))
43+
for it in range(number_of_iterations):
44+
iter_locked_iterator = iter_locked(non_atomic_iterator(data,))
45+
worker_threads = []
46+
for ii in range(number_of_threads):
47+
worker_threads.append(
48+
Thread(target=work, args=[iter_locked_iterator]))
49+
50+
with threading_helper.start_threads(worker_threads):
51+
pass
52+
53+
barrier.reset()
54+
55+
@threading_helper.reap_threads
56+
def test_iter_locked_generator(self):
57+
number_of_threads = 5
58+
number_of_iterations = 4
59+
barrier = Barrier(number_of_threads)
60+
def work(it):
61+
barrier.wait()
62+
for _ in range(1_000):
63+
try:
64+
next(it)
65+
except StopIteration:
66+
break
67+
68+
for it in range(number_of_iterations):
69+
generator = iter_locked(count())
70+
worker_threads = []
71+
for ii in range(number_of_threads):
72+
worker_threads.append(
73+
Thread(target=work, args=[generator]))
74+
75+
with threading_helper.start_threads(worker_threads):
76+
pass
77+
78+
barrier.reset()
79+
80+
if __name__ == "__main__":
81+
unittest.main()

Lib/test/test_threading.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2416,6 +2416,20 @@ def run_last():
24162416
self.assertIn("RuntimeError: can't register atexit after shutdown",
24172417
err.decode())
24182418

2419+
class LockedIterTests(unittest.TestCase):
2420+
2421+
def test_locked_iter(self):
2422+
for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
2423+
for g in (G, I, Ig, S, L, R):
2424+
seq = list(g(s))
2425+
expected = seq
2426+
actual = list(serialize(g(s)))
2427+
self.assertEqual(actual, expected)
2428+
self.assertRaises(TypeError, serialize, X(s))
2429+
self.assertRaises(TypeError, serialize, N(s))
2430+
self.assertRaises(ZeroDivisionError, list, serialize(E(s)))
2431+
for arg in [1, True, sys]:
2432+
self.assertRaises(TypeError, serialize, arg)
24192433

24202434
if __name__ == "__main__":
24212435
unittest.main()

Modules/_threadmodule.c

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
# include <signal.h> // SIGINT
2020
#endif
2121

22-
#include "clinic/_threadmodule.c.h"
2322

2423
// ThreadError is just an alias to PyExc_RuntimeError
2524
#define ThreadError PyExc_RuntimeError
@@ -30,6 +29,7 @@ static struct PyModuleDef thread_module;
3029
// Module state
3130
typedef struct {
3231
PyTypeObject *excepthook_type;
32+
PyTypeObject *iter_locked_type;
3333
PyTypeObject *lock_type;
3434
PyTypeObject *local_type;
3535
PyTypeObject *local_dummy_type;
@@ -48,6 +48,17 @@ get_thread_state(PyObject *module)
4848
return (thread_module_state *)state;
4949
}
5050

51+
static inline thread_module_state *
52+
find_state_by_type(PyTypeObject *tp)
53+
{
54+
PyObject *mod = PyType_GetModuleByDef(tp, &thread_module);
55+
assert(mod != NULL);
56+
return get_thread_state(mod);
57+
}
58+
59+
#define clinic_state() (find_state_by_type(type))
60+
#include "clinic/_threadmodule.c.h"
61+
#undef clinic_state
5162

5263
#ifdef MS_WINDOWS
5364
typedef HRESULT (WINAPI *PF_GET_THREAD_DESCRIPTION)(HANDLE, PCWSTR*);
@@ -59,8 +70,10 @@ static PF_SET_THREAD_DESCRIPTION pSetThreadDescription = NULL;
5970

6071
/*[clinic input]
6172
module _thread
73+
class _thread.iter_locked "iter_locked_object *" "clinic_state()->iter_locked_type"
74+
6275
[clinic start generated code]*/
63-
/*[clinic end generated code: output=da39a3ee5e6b4b0d input=be8dbe5cc4b16df7]*/
76+
/*[clinic end generated code: output=da39a3ee5e6b4b0d input=6c78d729dec7bf7e]*/
6477

6578

6679
// _ThreadHandle type
@@ -731,6 +744,99 @@ static PyType_Spec ThreadHandle_Type_spec = {
731744
ThreadHandle_Type_slots,
732745
};
733746

747+
/* iter_locked object **************************************************************/
748+
749+
typedef struct {
750+
PyObject_HEAD
751+
PyObject *it;
752+
} iter_locked_object;
753+
754+
#define iter_locked_object_CAST(op) ((iter_locked_object *)(op))
755+
756+
/*[clinic input]
757+
@classmethod
758+
_thread.iter_locked.__new__
759+
iterable: object
760+
/
761+
Make an iterator thread-safe.
762+
[clinic start generated code]*/
763+
764+
static PyObject *
765+
_thread_iter_locked_impl(PyTypeObject *type, PyObject *iterable)
766+
/*[clinic end generated code: output=4a8ad5a25f7c09ba input=ae6124177726e809]*/
767+
{
768+
/* Get iterator. */
769+
PyObject *it = PyObject_GetIter(iterable);
770+
if (it == NULL)
771+
return NULL;
772+
773+
iter_locked_object *lz = (iter_locked_object *)type->tp_alloc(type, 0);
774+
lz->it = it;
775+
776+
return (PyObject *)lz;
777+
}
778+
779+
static void
780+
iter_locked_dealloc(PyObject *op)
781+
{
782+
iter_locked_object *lz = iter_locked_object_CAST(op);
783+
PyTypeObject *tp = Py_TYPE(lz);
784+
PyObject_GC_UnTrack(lz);
785+
Py_XDECREF(lz->it);
786+
tp->tp_free(lz);
787+
Py_DECREF(tp);
788+
}
789+
790+
static int
791+
iter_locked_traverse(PyObject *op, visitproc visit, void *arg)
792+
{
793+
iter_locked_object *lz = iter_locked_object_CAST(op);
794+
Py_VISIT(Py_TYPE(lz));
795+
Py_VISIT(lz->it);
796+
return 0;
797+
}
798+
799+
static PyObject *
800+
iter_locked_next(PyObject *op)
801+
{
802+
iter_locked_object *lz = iter_locked_object_CAST(op);
803+
PyObject *result = NULL;
804+
805+
Py_BEGIN_CRITICAL_SECTION(lz->it); // or lock on op?
806+
PyObject *it = lz->it;
807+
if (it != NULL) {
808+
result = PyIter_Next(lz->it);
809+
if (result == NULL) {
810+
/* Note: StopIteration is already cleared by PyIter_Next() */
811+
if (PyErr_Occurred())
812+
return NULL;
813+
Py_CLEAR(lz->it);
814+
}
815+
}
816+
Py_END_CRITICAL_SECTION();
817+
return result;
818+
}
819+
820+
static PyType_Slot iter_locked_slots[] = {
821+
{Py_tp_dealloc, iter_locked_dealloc},
822+
{Py_tp_getattro, PyObject_GenericGetAttr},
823+
{Py_tp_doc, (void *)_thread_iter_locked__doc__},
824+
{Py_tp_traverse, iter_locked_traverse},
825+
{Py_tp_iter, PyObject_SelfIter},
826+
{Py_tp_iternext, iter_locked_next},
827+
{Py_tp_new, _thread_iter_locked},
828+
{Py_tp_free, PyObject_GC_Del},
829+
{0, NULL},
830+
};
831+
832+
static PyType_Spec iter_locked_spec = {
833+
.name = "threading.iter_locked",
834+
.basicsize = sizeof(iter_locked_object),
835+
.flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE |
836+
Py_TPFLAGS_IMMUTABLETYPE),
837+
.slots = iter_locked_slots,
838+
};
839+
734840
/* Lock objects */
735841

736842
typedef struct {
@@ -2631,6 +2737,15 @@ thread_module_exec(PyObject *module)
26312737
return -1;
26322738
}
26332739

2740+
// iter_locked
2741+
state->iter_locked_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &iter_locked_spec, NULL);
2742+
if (state->iter_locked_type == NULL) {
2743+
return -1;
2744+
}
2745+
if (PyModule_AddType(module, state->iter_locked_type) < 0) {
2746+
return -1;
2747+
}
2748+
26342749
// Lock
26352750
state->lock_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &lock_type_spec, NULL);
26362751
if (state->lock_type == NULL) {
@@ -2739,6 +2854,7 @@ thread_module_traverse(PyObject *module, visitproc visit, void *arg)
27392854
{
27402855
thread_module_state *state = get_thread_state(module);
27412856
Py_VISIT(state->excepthook_type);
2857+
Py_CLEAR(state->iter_locked_type);
27422858
Py_VISIT(state->lock_type);
27432859
Py_VISIT(state->local_type);
27442860
Py_VISIT(state->local_dummy_type);
@@ -2751,6 +2867,7 @@ thread_module_clear(PyObject *module)
27512867
{
27522868
thread_module_state *state = get_thread_state(module);
27532869
Py_CLEAR(state->excepthook_type);
2870+
Py_CLEAR(state->iter_locked_type);
27542871
Py_CLEAR(state->lock_type);
27552872
Py_CLEAR(state->local_type);
27562873
Py_CLEAR(state->local_dummy_type);

Modules/clinic/_threadmodule.c.h

Lines changed: 32 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)