1import _thread
2import contextlib
3import functools
4import sys
5import threading
6import time
7import unittest
8
9from test import support
10
11
12#=======================================================================
13# Threading support to prevent reporting refleaks when running regrtest.py -R
14
# NOTE: we use _thread._count() rather than threading.enumerate() (or the
# moral equivalent thereof) because a threading.Thread object is still alive
# until its _bootstrap() method has returned, even after it has been
# unregistered from the threading module.
# _thread._count(), on the other hand, only gets decremented *after* the
# _bootstrap() method has returned, which gives us reliable thread counts
# at the end of a test run.
22
23
def threading_setup():
    """Snapshot the current threading state for threading_cleanup().

    Returns a ``(count, dangling)`` tuple: the value of _thread._count()
    and a copy of ``threading._dangling`` taken right after it.
    """
    thread_count = _thread._count()
    dangling_snapshot = threading._dangling.copy()
    return thread_count, dangling_snapshot
26
27
def threading_cleanup(*original_values):
    """Wait for threads started since threading_setup() to go away.

    *original_values* is the ``(count, dangling)`` tuple returned by
    threading_setup().  The state is considered clean once
    _thread._count() is no higher than the saved count and every Thread
    object in ``threading._dangling`` was already present at setup time.
    A lower count, or fewer dangling threads, is not a leak: threads that
    predate the snapshot are allowed to exit in the meantime.

    The check is retried up to _MAX_COUNT times, sleeping 10 ms and running
    support.gc_collect() between attempts.  Only if the state is still not
    clean after the last attempt is support.environment_altered set and a
    warning (plus one line per dangling thread) printed.
    """
    _MAX_COUNT = 100
    orig_count, orig_dangling = original_values

    for attempt in range(_MAX_COUNT):
        count = _thread._count()
        # Snapshot the set so the check and the report see the same threads.
        dangling_threads = list(threading._dangling)
        new_threads = [thread for thread in dangling_threads
                       if thread not in orig_dangling]
        if count <= orig_count and not new_threads:
            return
        if attempt == _MAX_COUNT - 1:
            break

        # Don't hold references to threads while waiting and collecting.
        dangling_threads = new_threads = None
        time.sleep(0.01)
        support.gc_collect()

    # Still not clean after every retry: report the final state.
    support.environment_altered = True
    support.print_warning(f"threading_cleanup() failed to cleanup "
                          f"{count - orig_count} threads "
                          f"(count: {count}, "
                          f"dangling: {len(dangling_threads)})")
    for thread in dangling_threads:
        support.print_warning(f"Dangling thread: {thread!r}")
53
54
def reap_threads(func):
    """Decorator checking that *func* does not leave threads behind.

    The threading state is captured with threading_setup() before *func*
    runs, and threading_cleanup() is called afterwards -- even when *func*
    raises (e.g. a failing test) -- so that lingering threads are waited
    for and reported.  Positional and keyword arguments are forwarded to
    *func* unchanged, and its return value is passed back.
    """
    @functools.wraps(func)
    def decorator(*args, **kwargs):
        key = threading_setup()
        try:
            return func(*args, **kwargs)
        finally:
            threading_cleanup(*key)
    return decorator
67
68
@contextlib.contextmanager
def wait_threads_exit(timeout=None):
    """
    bpo-31234: Context manager waiting for every thread created inside the
    with statement to exit.

    Exit is detected through _thread._count(), i.e. it indirectly waits for
    threads to leave the internal t_bootstrap() C function of _thread.

    threading_setup()/threading_cleanup() only warn when a test leaves
    threads running in the background.  This helper is for threads started
    with _thread.start_new_thread(), which offers no way to wait for the
    thread to exit (unlike threading.Thread.join()).

    *timeout* defaults to support.SHORT_TIMEOUT.  If the thread count has
    not dropped back to its value at entry within *timeout* seconds,
    AssertionError is raised.
    """
    if timeout is None:
        timeout = support.SHORT_TIMEOUT
    baseline = _thread._count()
    try:
        yield
    finally:
        began = time.monotonic()
        give_up_at = began + timeout
        while (current := _thread._count()) > baseline:
            if time.monotonic() > give_up_at:
                elapsed = time.monotonic() - began
                raise AssertionError(
                    f"wait_threads() failed to cleanup {current - baseline} "
                    f"threads after {elapsed:.1f} seconds "
                    f"(count: {current}, old count: {baseline})")
            time.sleep(0.010)
            support.gc_collect()
104
105
def join_thread(thread, timeout=None):
    """Join *thread*, failing loudly if it does not finish in time.

    *timeout* defaults to support.SHORT_TIMEOUT.  Raises AssertionError if
    the thread is still alive once join() has returned.
    """
    wait = support.SHORT_TIMEOUT if timeout is None else timeout
    thread.join(wait)
    if not thread.is_alive():
        return
    raise AssertionError(f"failed to join the thread in {wait:.1f} seconds")
116
117
@contextlib.contextmanager
def start_threads(threads, unlock=None):
    """Context manager starting *threads* on entry and joining them on exit.

    Threads are started in order.  If starting one fails, the exception is
    re-raised, but the threads that did start are still joined.  On exit,
    *unlock* (if given) is called first so that threads waiting on it can
    finish.  Then the started threads are joined for up to 15 periods of
    60 seconds each.  If any thread is still alive after that, a faulthandler
    traceback dump is written to sys.stdout when possible, and
    AssertionError is raised.
    """
    try:
        import faulthandler
    except ImportError:
        # The traceback dump is diagnostics only; guard in case faulthandler
        # is unavailable in this interpreter.
        faulthandler = None
    threads = list(threads)
    started = []
    try:
        try:
            for t in threads:
                t.start()
                started.append(t)
        except BaseException:
            if support.verbose:
                print("Can't start %d threads, only %d threads started" %
                      (len(threads), len(started)))
            raise
        yield
    finally:
        try:
            if unlock:
                unlock()
            endtime = time.monotonic()
            for timeout in range(1, 16):
                endtime += 60
                for t in started:
                    t.join(max(endtime - time.monotonic(), 0.01))
                started = [t for t in started if t.is_alive()]
                if not started:
                    break
                if support.verbose:
                    print('Unable to join %d threads during a period of '
                          '%d minutes' % (len(started), timeout))
        finally:
            started = [t for t in started if t.is_alive()]
            if started:
                if faulthandler is not None:
                    try:
                        faulthandler.dump_traceback(sys.stdout)
                    except (AttributeError, OSError, RuntimeError, ValueError):
                        # sys.stdout may have been replaced by an object with
                        # no usable file descriptor (e.g. io.StringIO).  The
                        # dump is best-effort and must not hide the failure.
                        pass
                raise AssertionError('Unable to join %d threads' % len(started))
154
155
class catch_threading_exception:
    """
    Context manager catching threading.Thread exception using
    threading.excepthook.

    Attributes set when an exception is caught:

    * exc_type
    * exc_value
    * exc_traceback
    * thread

    See threading.excepthook() documentation for these attributes.

    These attributes are reset to None when the context manager is entered
    and deleted at the context manager exit, so an instance can be reused.

    Usage:

        with threading_helper.catch_threading_exception() as cm:
            # code spawning a thread which raises an exception
            ...

            # check the thread exception, use cm attributes:
            # exc_type, exc_value, exc_traceback, thread
            ...

        # exc_type, exc_value, exc_traceback, thread attributes of cm no longer
        # exists at this point
        # (to avoid reference cycles)
    """

    def __init__(self):
        self._clear()
        self._old_hook = None

    def _clear(self):
        # Reset the attributes describing the caught exception.
        self.exc_type = None
        self.exc_value = None
        self.exc_traceback = None
        self.thread = None

    def _hook(self, args):
        # Installed as threading.excepthook: record the exception details;
        # each call overwrites the previous one.
        self.exc_type = args.exc_type
        self.exc_value = args.exc_value
        self.exc_traceback = args.exc_traceback
        self.thread = args.thread

    def __enter__(self):
        # A previous __exit__ deleted the attributes; recreate them so a
        # reused instance behaves like a fresh one.
        self._clear()
        self._old_hook = threading.excepthook
        threading.excepthook = self._hook
        return self

    def __exit__(self, *exc_info):
        threading.excepthook = self._old_hook
        # Delete (rather than keep) the references to the exception,
        # traceback and thread to avoid reference cycles.
        del self.exc_type
        del self.exc_value
        del self.exc_traceback
        del self.thread
211
212
213def _can_start_thread() -> bool:
214    """Detect whether Python can start new threads.
215
216    Some WebAssembly platforms do not provide a working pthread
217    implementation. Thread support is stubbed and any attempt
218    to create a new thread fails.
219
220    - wasm32-wasi does not have threading.
221    - wasm32-emscripten can be compiled with or without pthread
222      support (-s USE_PTHREADS / __EMSCRIPTEN_PTHREADS__).
223    """
224    if sys.platform == "emscripten":
225        return sys._emscripten_info.pthreads
226    elif sys.platform == "wasi":
227        return False
228    else:
229        # assume all other platforms have working thread support.
230        return True
231
232can_start_thread = _can_start_thread()
233
def requires_working_threading(*, module=False):
    """Skip tests, classes or whole modules needing working threads.

    With module=False (the default), return a unittest.skipUnless()
    decorator for a test function or class.  With module=True, raise
    unittest.SkipTest immediately when threads cannot be started (meant
    to be called at module import time); otherwise return None.
    """
    reason = "requires threading support"
    if not module:
        return unittest.skipUnless(can_start_thread, reason)
    if not can_start_thread:
        raise unittest.SkipTest(reason)
245