1"""
2Various tests for synchronization primitives.
3"""
4
5import os
6import gc
7import sys
8import time
9from _thread import start_new_thread, TIMEOUT_MAX
10import threading
11import unittest
12import weakref
13
14from test import support
15from test.support import threading_helper
16
17
# Decorator: skip the decorated test unless support.has_fork_support is true.
requires_fork = unittest.skipUnless(
    support.has_fork_support,
    "platform doesn't support fork (no _at_fork_reinit method)",
)
21
22
23def _wait():
24    # A crude wait/yield function not relying on synchronization primitives.
25    time.sleep(0.01)
26
class Bunch(object):
    """
    A bunch of threads.

    Progress is tracked by appending thread identifiers to plain lists and
    polling their length with _wait(). No lock or event is used for this
    bookkeeping.
    """
    def __init__(self, f, n, wait_before_exit=False):
        """
        Construct a bunch of `n` threads running the same function `f`.
        If `wait_before_exit` is True, the threads won't terminate until
        do_finish() is called.
        """
        self.f = f
        self.n = n
        # Each worker appends its thread id here when it starts, and again
        # (to `finished`) once `f` has returned or raised.
        self.started = []
        self.finished = []
        self._can_exit = not wait_before_exit
        # Entered here and exited in wait_for_finished(), which waits for the
        # started threads to exit.
        self.wait_thread = threading_helper.wait_threads_exit()
        self.wait_thread.__enter__()

        def task():
            tid = threading.get_ident()
            self.started.append(tid)
            try:
                f()
            finally:
                self.finished.append(tid)
                # Optionally keep the thread alive until do_finish().
                while not self._can_exit:
                    _wait()

        try:
            for _ in range(n):
                start_new_thread(task, ())
        except BaseException:
            # Don't leave already-started threads spinning forever if
            # starting a later one failed; then propagate the error.
            self._can_exit = True
            raise

    def wait_for_started(self):
        """Block until all `n` threads have begun running `f`."""
        while len(self.started) < self.n:
            _wait()

    def wait_for_finished(self):
        """Block until all `n` threads have finished `f` and exited.

        This also leaves the wait_threads_exit() context entered in
        __init__.
        """
        while len(self.finished) < self.n:
            _wait()
        # Wait for threads exit
        self.wait_thread.__exit__(None, None, None)

    def do_finish(self):
        """Allow threads created with wait_before_exit=True to terminate."""
        self._can_exit = True
74
75
class BaseTestCase(unittest.TestCase):
    """Common base: thread bookkeeping via threading_helper, plus a loose
    timing assertion shared by the synchronization tests."""

    def setUp(self):
        # Whatever threading_setup() records is handed back to
        # threading_cleanup() in tearDown().
        self._threads = threading_helper.threading_setup()

    def tearDown(self):
        threading_helper.threading_cleanup(*self._threads)
        support.reap_children()

    def assertTimeout(self, actual, expected):
        """Assert that the measured duration `actual` is close to `expected`.

        Sleeping and time.monotonic() can be imprecise (especially under
        Windows), so only the window [0.6 * expected, 10 * expected) is
        enforced. The upper bound just guards against something insane.
        """
        lower = expected * 0.6
        upper = expected * 10.0
        self.assertGreaterEqual(actual, lower)
        self.assertLess(actual, upper)
91
92
class BaseLockTests(BaseTestCase):
    """
    Tests common to recursive and non-recursive locks.

    The lock factory under test is expected as ``self.locktype``; it is not
    defined in this class.
    """

    def test_constructor(self):
        subject = self.locktype()
        del subject

    def test_repr(self):
        subject = self.locktype()
        self.assertRegex(repr(subject), "<unlocked .* object (.*)?at .*>")
        del subject

    def test_locked_repr(self):
        subject = self.locktype()
        subject.acquire()
        self.assertRegex(repr(subject), "<locked .* object (.*)?at .*>")
        del subject

    def test_acquire_destroy(self):
        # Dropping a lock that is still held must not raise.
        subject = self.locktype()
        subject.acquire()
        del subject

    def test_acquire_release(self):
        subject = self.locktype()
        subject.acquire()
        subject.release()
        del subject

    def test_try_acquire(self):
        subject = self.locktype()
        got_it = subject.acquire(False)
        self.assertTrue(got_it)
        subject.release()

    def test_try_acquire_contended(self):
        # A non-blocking acquire from another thread fails while we hold it.
        subject = self.locktype()
        subject.acquire()
        outcome = []

        def worker():
            outcome.append(subject.acquire(False))

        Bunch(worker, 1).wait_for_finished()
        self.assertFalse(outcome[0])
        subject.release()

    def test_acquire_contended(self):
        subject = self.locktype()
        subject.acquire()
        nthreads = 5

        def worker():
            subject.acquire()
            subject.release()

        bunch = Bunch(worker, nthreads)
        bunch.wait_for_started()
        _wait()
        # Every worker must still be blocked on the held lock.
        self.assertEqual(len(bunch.finished), 0)
        subject.release()
        bunch.wait_for_finished()
        self.assertEqual(len(bunch.finished), nthreads)

    def test_with(self):
        subject = self.locktype()

        def worker():
            subject.acquire()
            subject.release()

        def use_with(err=None):
            with subject:
                if err is not None:
                    raise err

        use_with()
        # The lock was released on normal exit: another thread can take it.
        Bunch(worker, 1).wait_for_finished()
        self.assertRaises(TypeError, use_with, TypeError)
        # ... and also when the body raised.
        Bunch(worker, 1).wait_for_finished()

    def test_thread_leak(self):
        # Using the lock from a foreign (non-threading) thread must not leave
        # a Thread instance behind.
        subject = self.locktype()

        def worker():
            subject.acquire()
            subject.release()

        before = len(threading.enumerate())
        # Many threads, in the hope that existing thread ids are not recycled.
        Bunch(worker, 15).wait_for_finished()
        if len(threading.enumerate()) == before:
            return
        # A Thread may briefly stay registered after its target returned;
        # give it a little more time before judging (seen on a buildbot).
        time.sleep(0.4)
        self.assertEqual(before, len(threading.enumerate()))

    def test_timeout(self):
        subject = self.locktype()
        # A non-blocking acquire cannot take a timeout.
        self.assertRaises(ValueError, subject.acquire, False, 1)
        # Out-of-range timeout values.
        self.assertRaises(ValueError, subject.acquire, timeout=-100)
        self.assertRaises(OverflowError, subject.acquire, timeout=1e100)
        self.assertRaises(OverflowError, subject.acquire,
                          timeout=TIMEOUT_MAX + 1)
        # The maximum itself is accepted.
        subject.acquire(timeout=TIMEOUT_MAX)
        subject.release()
        start = time.monotonic()
        self.assertTrue(subject.acquire(timeout=5))
        elapsed = time.monotonic() - start
        # Sanity check: an uncontended acquire did not wait for the timeout.
        self.assertLess(elapsed, 5)
        observed = []

        def worker():
            begin = time.monotonic()
            observed.append(subject.acquire(timeout=0.5))
            observed.append(time.monotonic() - begin)

        Bunch(worker, 1).wait_for_finished()
        self.assertFalse(observed[0])
        self.assertTimeout(observed[1], 0.5)

    def test_weakref_exists(self):
        subject = self.locktype()
        self.assertIsNotNone(weakref.ref(subject)())

    def test_weakref_deleted(self):
        subject = self.locktype()
        tracker = weakref.ref(subject)
        del subject
        gc.collect()  # For PyPy or other GCs.
        self.assertIsNone(tracker())
227
228
class LockTests(BaseLockTests):
    """
    Tests for non-recursive, weak locks: a lock may be released by a thread
    other than the one that acquired it.
    """
    def test_reacquire(self):
        # A second acquire() blocks until the lock has been released.
        subject = self.locktype()
        steps = []

        def worker():
            subject.acquire()
            steps.append(None)
            subject.acquire()
            steps.append(None)

        with threading_helper.wait_threads_exit():
            start_new_thread(worker, ())
            while not steps:
                _wait()
            _wait()
            # The worker must still be stuck in its second acquire().
            self.assertEqual(len(steps), 1)
            subject.release()
            while len(steps) == 1:
                _wait()
            self.assertEqual(len(steps), 2)

    def test_different_thread(self):
        # The main thread acquires; a worker thread releases.
        subject = self.locktype()
        subject.acquire()

        def worker():
            subject.release()

        Bunch(worker, 1).wait_for_finished()
        subject.acquire()
        subject.release()

    def test_state_after_timeout(self):
        # Issue #11618: after a timed-out (non-zero timeout) acquire the lock
        # must still be in a consistent state.
        subject = self.locktype()
        subject.acquire()
        self.assertFalse(subject.acquire(timeout=0.01))
        subject.release()
        self.assertFalse(subject.locked())
        self.assertTrue(subject.acquire(blocking=False))

    @requires_fork
    def test_at_fork_reinit(self):
        def exercise(lk):
            # The lock must keep working normally after _at_fork_reinit().
            lk.acquire()
            lk.release()

        # Reinitializing an unlocked lock.
        fresh = self.locktype()
        fresh._at_fork_reinit()
        exercise(fresh)

        # Reinitializing a held lock resets it to the unlocked state.
        held = self.locktype()
        held.acquire()
        held._at_fork_reinit()
        exercise(held)
295
296
class RLockTests(BaseLockTests):
    """
    Tests for recursive (reentrant) locks.
    """
    def _nested_acquire_release(self, lock):
        # Acquire twice, drop one level, re-acquire, then fully release.
        lock.acquire()
        lock.acquire()
        lock.release()
        lock.acquire()
        lock.release()
        lock.release()

    def test_reacquire(self):
        self._nested_acquire_release(self.locktype())

    def test_release_unacquired(self):
        # release() on a lock we do not hold raises RuntimeError, both before
        # any use and after a balanced acquire/release sequence.
        subject = self.locktype()
        self.assertRaises(RuntimeError, subject.release)
        self._nested_acquire_release(subject)
        self.assertRaises(RuntimeError, subject.release)

    def test_release_save_unacquired(self):
        # Same as above for the internal _release_save() method.
        subject = self.locktype()
        self.assertRaises(RuntimeError, subject._release_save)
        self._nested_acquire_release(subject)
        self.assertRaises(RuntimeError, subject._release_save)

    def test_different_thread(self):
        # Only the owning thread may release; the worker holds the lock and
        # is kept alive until do_finish().
        subject = self.locktype()

        def worker():
            subject.acquire()

        bunch = Bunch(worker, 1, wait_before_exit=True)
        try:
            self.assertRaises(RuntimeError, subject.release)
        finally:
            bunch.do_finish()
        bunch.wait_for_finished()

    def test__is_owned(self):
        subject = self.locktype()
        self.assertFalse(subject._is_owned())
        for _ in range(2):
            subject.acquire()
            self.assertTrue(subject._is_owned())
        seen = []

        def worker():
            seen.append(subject._is_owned())

        Bunch(worker, 1).wait_for_finished()
        # Ownership is per thread: the worker does not own our lock.
        self.assertFalse(seen[0])
        subject.release()
        self.assertTrue(subject._is_owned())
        subject.release()
        self.assertFalse(subject._is_owned())
362
363
class EventTests(BaseTestCase):
    """
    Tests for Event objects; the factory is expected as ``self.eventtype``.
    """

    def test_is_set(self):
        evt = self.eventtype()
        self.assertFalse(evt.is_set())
        # set() and clear() are each idempotent.
        for _ in range(2):
            evt.set()
            self.assertTrue(evt.is_set())
        for _ in range(2):
            evt.clear()
            self.assertFalse(evt.is_set())

    def _check_notify(self, evt):
        # A single set() must wake every waiting thread.
        nthreads = 5
        first = []
        second = []

        def worker():
            first.append(evt.wait())
            second.append(evt.wait())

        bunch = Bunch(worker, nthreads)
        bunch.wait_for_started()
        _wait()
        self.assertEqual(len(first), 0)
        evt.set()
        bunch.wait_for_finished()
        self.assertEqual(first, [True] * nthreads)
        self.assertEqual(second, [True] * nthreads)

    def test_notify(self):
        evt = self.eventtype()
        self._check_notify(evt)
        # Once more, after an explicit set()/clear() cycle.
        evt.set()
        evt.clear()
        self._check_notify(evt)

    def test_timeout(self):
        evt = self.eventtype()
        immediate = []
        timed = []
        nthreads = 5

        def worker():
            immediate.append(evt.wait(0.0))
            begin = time.monotonic()
            flag = evt.wait(0.5)
            timed.append((flag, time.monotonic() - begin))

        Bunch(worker, nthreads).wait_for_finished()
        self.assertEqual(immediate, [False] * nthreads)
        for flag, elapsed in timed:
            self.assertFalse(flag)
            self.assertTimeout(elapsed, 0.5)
        # With the event set, every wait() returns True.
        immediate.clear()
        timed.clear()
        evt.set()
        Bunch(worker, nthreads).wait_for_finished()
        self.assertEqual(immediate, [True] * nthreads)
        for flag, _elapsed in timed:
            self.assertTrue(flag)

    def test_set_and_clear(self):
        # Issue #13502: wait() must return True even if the event is cleared
        # again before the waiting thread gets to run.
        evt = self.eventtype()
        outcome = []
        timeout = 0.250
        nthreads = 5

        def worker():
            outcome.append(evt.wait(timeout * 4))

        bunch = Bunch(worker, nthreads)
        bunch.wait_for_started()
        time.sleep(timeout)
        evt.set()
        evt.clear()
        bunch.wait_for_finished()
        self.assertEqual(outcome, [True] * nthreads)

    @requires_fork
    def test_at_fork_reinit(self):
        # The internal condition must still wrap a non-reentrant lock after
        # _at_fork_reinit(): a nested non-blocking acquire fails.
        evt = self.eventtype()

        def assert_cond_lock_not_reentrant():
            with evt._cond:
                self.assertFalse(evt._cond.acquire(False))

        assert_cond_lock_not_reentrant()
        evt._at_fork_reinit()
        assert_cond_lock_not_reentrant()

    def test_repr(self):
        evt = self.eventtype()
        self.assertRegex(repr(evt), r"<\w+\.Event at .*: unset>")
        evt.set()
        self.assertRegex(repr(evt), r"<\w+\.Event at .*: set>")
463
464
class ConditionTests(BaseTestCase):
    """
    Tests for condition variables; the factory is expected as
    ``self.condtype``.
    """

    def test_acquire(self):
        cond = self.condtype()
        # By default we have an RLock: the condition can be acquired multiple
        # times.
        cond.acquire()
        cond.acquire()
        cond.release()
        cond.release()
        # With an explicit Lock, acquiring the condition acquires that lock.
        lock = threading.Lock()
        cond = self.condtype(lock)
        cond.acquire()
        self.assertFalse(lock.acquire(False))
        cond.release()
        self.assertTrue(lock.acquire(False))
        self.assertFalse(cond.acquire(False))
        lock.release()
        with cond:
            self.assertFalse(lock.acquire(False))

    def test_unacquired_wait(self):
        cond = self.condtype()
        self.assertRaises(RuntimeError, cond.wait)

    def test_unacquired_notify(self):
        cond = self.condtype()
        self.assertRaises(RuntimeError, cond.notify)

    def _check_notify(self, cond):
        # Note that this test is sensitive to timing.  If the worker threads
        # don't execute in a timely fashion, the main thread may think they
        # are further along than they are.  The main thread therefore issues
        # _wait() statements to try to make sure that it doesn't race ahead
        # of the workers.
        # Secondly, this test assumes that condition variables are not subject
        # to spurious wakeups.  The absence of spurious wakeups is an
        # implementation detail of Condition Variables in current CPython, but
        # in general, not a guaranteed property of condition variables as a
        # programming construct.  In particular, it is possible that this can
        # no longer be conveniently guaranteed should their implementation
        # ever change.
        N = 5
        ready = []
        results1 = []
        results2 = []
        phase_num = 0
        def f():
            cond.acquire()
            ready.append(phase_num)
            result = cond.wait()
            cond.release()
            results1.append((result, phase_num))
            cond.acquire()
            ready.append(phase_num)
            result = cond.wait()
            cond.release()
            results2.append((result, phase_num))
        b = Bunch(f, N)
        b.wait_for_started()
        # first wait, to ensure all workers settle into cond.wait() before
        # we continue. See issues #8799 and #30727.
        while len(ready) < N:
            _wait()
        ready.clear()
        self.assertEqual(results1, [])
        # Notify 3 threads at first
        cond.acquire()
        cond.notify(3)
        _wait()
        phase_num = 1
        cond.release()
        while len(results1) < 3:
            _wait()
        self.assertEqual(results1, [(True, 1)] * 3)
        self.assertEqual(results2, [])
        # make sure all awakened workers settle into cond.wait()
        while len(ready) < 3:
            _wait()
        # Notify N threads: they might be in their first or second wait
        cond.acquire()
        cond.notify(N)
        _wait()
        phase_num = 2
        cond.release()
        while len(results1) + len(results2) < N + 3:
            _wait()
        self.assertEqual(results1, [(True, 1)] * 3 + [(True, 2)] * (N - 3))
        self.assertEqual(results2, [(True, 2)] * 3)
        # make sure all workers settle into cond.wait()
        while len(ready) < N:
            _wait()
        # Notify all threads: they are all in their second wait
        cond.acquire()
        cond.notify_all()
        _wait()
        phase_num = 3
        cond.release()
        while len(results2) < N:
            _wait()
        self.assertEqual(results1, [(True, 1)] * 3 + [(True, 2)] * (N - 3))
        self.assertEqual(results2, [(True, 2)] * 3 + [(True, 3)] * (N - 3))
        b.wait_for_finished()

    def test_notify(self):
        cond = self.condtype()
        self._check_notify(cond)
        # A second time, to check internal state is still ok.
        self._check_notify(cond)

    def test_timeout(self):
        cond = self.condtype()
        results = []
        N = 5
        def f():
            cond.acquire()
            t1 = time.monotonic()
            result = cond.wait(0.5)
            t2 = time.monotonic()
            cond.release()
            results.append((t2 - t1, result))
        Bunch(f, N).wait_for_finished()
        self.assertEqual(len(results), N)
        for dt, result in results:
            self.assertTimeout(dt, 0.5)
            # Note that conceptually (that's the condition variable protocol)
            # a wait() may succeed even if no one notifies us and before any
            # timeout occurs.  Spurious wakeups can occur.
            # This makes it hard to verify the result value.
            # In practice, this implementation has no spurious wakeups.
            self.assertFalse(result)

    def test_waitfor(self):
        cond = self.condtype()
        state = 0
        # The assertions below run in a worker thread, where a failure would
        # not propagate to this test; record that they all passed instead.
        success = []
        def f():
            with cond:
                result = cond.wait_for(lambda: state == 4)
                self.assertTrue(result)
                self.assertEqual(state, 4)
                success.append(None)
        b = Bunch(f, 1)
        b.wait_for_started()
        for i in range(4):
            time.sleep(0.01)
            with cond:
                state += 1
                cond.notify()
        b.wait_for_finished()
        self.assertEqual(len(success), 1)

    def test_waitfor_timeout(self):
        cond = self.condtype()
        state = 0
        # Worker-side assertions don't propagate; track their success here.
        success = []
        def f():
            with cond:
                dt = time.monotonic()
                result = cond.wait_for(lambda: state == 4, timeout=0.1)
                dt = time.monotonic() - dt
                self.assertFalse(result)
                self.assertTimeout(dt, 0.1)
                success.append(None)
        b = Bunch(f, 1)
        b.wait_for_started()
        # Only increment 3 times, so state == 4 is never reached.
        for i in range(3):
            time.sleep(0.01)
            with cond:
                state += 1
                cond.notify()
        b.wait_for_finished()
        self.assertEqual(len(success), 1)
638
639
class BaseSemaphoreTests(BaseTestCase):
    """
    Common tests for {bounded, unbounded} semaphore objects.

    The semaphore factory is expected as ``self.semtype``; it is not defined
    in this class.
    """

    def test_constructor(self):
        self.assertRaises(ValueError, self.semtype, value=-1)
        self.assertRaises(ValueError, self.semtype, value=-sys.maxsize)

    def test_acquire(self):
        sem = self.semtype(1)
        sem.acquire()
        sem.release()
        sem = self.semtype(2)
        sem.acquire()
        sem.acquire()
        sem.release()
        sem.release()

    def test_acquire_destroy(self):
        sem = self.semtype()
        sem.acquire()
        del sem

    def test_acquire_contended(self):
        # 7 permits, one taken by the main thread: the N workers (2 acquires
        # each, 2 * N = 20 in total) get 6, then 7, then 6 more, and the last
        # one needs the final release().
        sem = self.semtype(7)
        sem.acquire()
        N = 10
        sem_results = []
        results1 = []
        results2 = []
        phase_num = 0
        def f():
            sem_results.append(sem.acquire())
            results1.append(phase_num)
            sem_results.append(sem.acquire())
            results2.append(phase_num)
        b = Bunch(f, N)
        b.wait_for_started()
        while len(results1) + len(results2) < 6:
            _wait()
        self.assertEqual(results1 + results2, [0] * 6)
        phase_num = 1
        for i in range(7):
            sem.release()
        while len(results1) + len(results2) < 13:
            _wait()
        self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7)
        phase_num = 2
        for i in range(6):
            sem.release()
        while len(results1) + len(results2) < 19:
            _wait()
        self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7 + [2] * 6)
        # The semaphore is still locked
        self.assertFalse(sem.acquire(False))
        # Final release, to let the last thread finish
        sem.release()
        b.wait_for_finished()
        self.assertEqual(sem_results, [True] * (6 + 7 + 6 + 1))

    def test_multirelease(self):
        # Same scenario as test_acquire_contended, but using release(n).
        sem = self.semtype(7)
        sem.acquire()
        N = 10
        results1 = []
        results2 = []
        phase_num = 0
        def f():
            sem.acquire()
            results1.append(phase_num)
            sem.acquire()
            results2.append(phase_num)
        b = Bunch(f, N)
        b.wait_for_started()
        while len(results1) + len(results2) < 6:
            _wait()
        self.assertEqual(results1 + results2, [0] * 6)
        phase_num = 1
        sem.release(7)
        while len(results1) + len(results2) < 13:
            _wait()
        self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7)
        phase_num = 2
        sem.release(6)
        while len(results1) + len(results2) < 19:
            _wait()
        self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7 + [2] * 6)
        # The semaphore is still locked
        self.assertFalse(sem.acquire(False))
        # Final release, to let the last thread finish
        sem.release()
        b.wait_for_finished()

    def test_try_acquire(self):
        sem = self.semtype(2)
        self.assertTrue(sem.acquire(False))
        self.assertTrue(sem.acquire(False))
        self.assertFalse(sem.acquire(False))
        sem.release()
        self.assertTrue(sem.acquire(False))

    def test_try_acquire_contended(self):
        # 4 permits minus the one held here leaves 3 successes among the
        # 5 * 2 = 10 non-blocking attempts.
        sem = self.semtype(4)
        sem.acquire()
        results = []
        def f():
            results.append(sem.acquire(False))
            results.append(sem.acquire(False))
        Bunch(f, 5).wait_for_finished()
        # There can be a thread switch between acquiring the semaphore and
        # appending the result, therefore results will not necessarily be
        # ordered.
        self.assertEqual(sorted(results), [False] * 7 + [True] * 3)

    def test_acquire_timeout(self):
        sem = self.semtype(2)
        self.assertRaises(ValueError, sem.acquire, False, timeout=1.0)
        self.assertTrue(sem.acquire(timeout=0.005))
        self.assertTrue(sem.acquire(timeout=0.005))
        self.assertFalse(sem.acquire(timeout=0.005))
        sem.release()
        self.assertTrue(sem.acquire(timeout=0.005))
        t = time.monotonic()
        self.assertFalse(sem.acquire(timeout=0.5))
        dt = time.monotonic() - t
        self.assertTimeout(dt, 0.5)

    def test_default_value(self):
        # The default initial value is 1.
        sem = self.semtype()
        sem.acquire()
        def f():
            sem.acquire()
            sem.release()
        b = Bunch(f, 1)
        b.wait_for_started()
        _wait()
        self.assertFalse(b.finished)
        sem.release()
        b.wait_for_finished()

    def test_with(self):
        sem = self.semtype(2)
        def _with(err=None):
            with sem:
                self.assertTrue(sem.acquire(False))
                sem.release()
                with sem:
                    self.assertFalse(sem.acquire(False))
                    if err:
                        raise err
        _with()
        self.assertTrue(sem.acquire(False))
        sem.release()
        self.assertRaises(TypeError, _with, TypeError)
        self.assertTrue(sem.acquire(False))
        sem.release()
797
class SemaphoreTests(BaseSemaphoreTests):
    """
    Tests for unbounded semaphores.
    """

    def test_release_unacquired(self):
        # An unbounded semaphore may be released past its initial value;
        # the extra release() makes a second acquire() succeed.
        counter = self.semtype(1)
        counter.release()
        for _ in range(2):
            counter.acquire()
        counter.release()

    def test_repr(self):
        counter = self.semtype(3)
        self.assertRegex(repr(counter), r"<\w+\.Semaphore at .*: value=3>")
        counter.acquire()
        self.assertRegex(repr(counter), r"<\w+\.Semaphore at .*: value=2>")
        for _ in range(2):
            counter.release()
        self.assertRegex(repr(counter), r"<\w+\.Semaphore at .*: value=4>")
819
820
class BoundedSemaphoreTests(BaseSemaphoreTests):
    """
    Tests for bounded semaphores.
    """

    def test_release_unacquired(self):
        # release() may not raise the counter above its initial value,
        # neither on a fresh semaphore nor after a balanced acquire/release.
        bounded = self.semtype()
        self.assertRaises(ValueError, bounded.release)
        bounded.acquire()
        bounded.release()
        self.assertRaises(ValueError, bounded.release)

    def test_repr(self):
        bounded = self.semtype(3)
        self.assertRegex(repr(bounded),
                         r"<\w+\.BoundedSemaphore at .*: value=3/3>")
        bounded.acquire()
        self.assertRegex(repr(bounded),
                         r"<\w+\.BoundedSemaphore at .*: value=2/3>")
839
840
class BarrierTests(BaseTestCase):
    """
    Tests for Barrier objects; the factory is expected as
    ``self.barriertype``.
    """
    N = 5
    defaultTimeout = 2.0

    def setUp(self):
        # Keep BaseTestCase's thread bookkeeping; overriding setUp/tearDown
        # without calling super() would silently skip it.
        super().setUp()
        self.barrier = self.barriertype(self.N, timeout=self.defaultTimeout)

    def tearDown(self):
        # Break the barrier first so no helper thread stays blocked in it.
        self.barrier.abort()
        super().tearDown()

    def run_threads(self, f):
        """Run `f` in N-1 helper threads and in the calling thread.

        Exceptions raised by `f` in a helper thread (e.g. failed assertions)
        are collected and the first one is re-raised here once all helpers
        have finished. Otherwise they would escape the worker thread without
        failing the test.
        """
        errors = []
        def worker():
            try:
                f()
            except BaseException as exc:
                errors.append(exc)
        b = Bunch(worker, self.N-1)
        f()
        b.wait_for_finished()
        if errors:
            raise errors[0]

    def multipass(self, results, n):
        m = self.barrier.parties
        self.assertEqual(m, self.N)
        for i in range(n):
            results[0].append(True)
            self.assertEqual(len(results[1]), i * m)
            self.barrier.wait()
            results[1].append(True)
            self.assertEqual(len(results[0]), (i + 1) * m)
            self.barrier.wait()
        self.assertEqual(self.barrier.n_waiting, 0)
        self.assertFalse(self.barrier.broken)

    def test_barrier(self, passes=1):
        """
        Test that a barrier is passed in lockstep
        """
        results = [[],[]]
        def f():
            self.multipass(results, passes)
        self.run_threads(f)

    def test_barrier_10(self):
        """
        Test that a barrier works for 10 consecutive runs
        """
        self.test_barrier(10)

    def test_wait_return(self):
        """
        test the return value from barrier.wait
        """
        results = []
        def f():
            r = self.barrier.wait()
            results.append(r)

        self.run_threads(f)
        # Each party gets a distinct index in range(N).
        self.assertEqual(sum(results), sum(range(self.N)))

    def test_action(self):
        """
        Test the 'action' callback
        """
        results = []
        def action():
            results.append(True)
        barrier = self.barriertype(self.N, action)
        def f():
            barrier.wait()
            self.assertEqual(len(results), 1)

        self.run_threads(f)

    def test_abort(self):
        """
        Test that an abort will put the barrier in a broken state
        """
        results1 = []
        results2 = []
        def f():
            try:
                i = self.barrier.wait()
                if i == self.N//2:
                    raise RuntimeError
                self.barrier.wait()
                results1.append(True)
            except threading.BrokenBarrierError:
                results2.append(True)
            except RuntimeError:
                self.barrier.abort()

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertTrue(self.barrier.broken)

    def test_reset(self):
        """
        Test that a 'reset' on a barrier frees the waiting threads
        """
        results1 = []
        results2 = []
        results3 = []
        def f():
            i = self.barrier.wait()
            if i == self.N//2:
                # Wait until the other threads are all in the barrier.
                while self.barrier.n_waiting < self.N-1:
                    time.sleep(0.001)
                self.barrier.reset()
            else:
                try:
                    self.barrier.wait()
                    results1.append(True)
                except threading.BrokenBarrierError:
                    results2.append(True)
            # Now, pass the barrier again
            self.barrier.wait()
            results3.append(True)

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertEqual(len(results3), self.N)

    def test_abort_and_reset(self):
        """
        Test that a barrier can be reset after being broken.
        """
        results1 = []
        results2 = []
        results3 = []
        barrier2 = self.barriertype(self.N)
        def f():
            try:
                i = self.barrier.wait()
                if i == self.N//2:
                    raise RuntimeError
                self.barrier.wait()
                results1.append(True)
            except threading.BrokenBarrierError:
                results2.append(True)
            except RuntimeError:
                self.barrier.abort()
            # Synchronize and reset the barrier.  Must synchronize first so
            # that everyone has left it when we reset, and after so that no
            # one enters it before the reset.
            if barrier2.wait() == self.N//2:
                self.barrier.reset()
            barrier2.wait()
            self.barrier.wait()
            results3.append(True)

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertEqual(len(results3), self.N)

    def test_timeout(self):
        """
        Test wait(timeout)
        """
        def f():
            i = self.barrier.wait()
            if i == self.N // 2:
                # One thread is late!
                time.sleep(1.0)
            # Default timeout is 2.0, so this is shorter.
            self.assertRaises(threading.BrokenBarrierError,
                              self.barrier.wait, 0.5)
        self.run_threads(f)

    def test_default_timeout(self):
        """
        Test the barrier's default timeout
        """
        # create a barrier with a low default timeout
        barrier = self.barriertype(self.N, timeout=0.3)
        def f():
            i = barrier.wait()
            if i == self.N // 2:
                # One thread is later than the default timeout of 0.3s.
                time.sleep(1.0)
            self.assertRaises(threading.BrokenBarrierError, barrier.wait)
        self.run_threads(f)

    def test_single_thread(self):
        b = self.barriertype(1)
        b.wait()
        b.wait()

    def test_repr(self):
        b = self.barriertype(3)
        self.assertRegex(repr(b), r"<\w+\.Barrier at .*: waiters=0/3>")
        def f():
            b.wait(3)
        bunch = Bunch(f, 2)
        bunch.wait_for_started()
        # Poll until both helpers are actually blocked in wait() rather than
        # relying on a fixed sleep; give up after a while and let the
        # assertion below report the failure.
        deadline = time.monotonic() + 2.0
        while b.n_waiting < 2 and time.monotonic() < deadline:
            _wait()
        self.assertRegex(repr(b), r"<\w+\.Barrier at .*: waiters=2/3>")
        b.wait(3)
        bunch.wait_for_finished()
        self.assertRegex(repr(b), r"<\w+\.Barrier at .*: waiters=0/3>")
        b.abort()
        self.assertRegex(repr(b), r"<\w+\.Barrier at .*: broken>")
1047