1"""Unit tests for contextlib.py, and other context managers."""
2
3import io
4import os
5import sys
6import tempfile
7import threading
8import traceback
9import unittest
10from contextlib import *  # Tests __all__
11from test import support
12from test.support import os_helper
13import weakref
14
15
class TestAbstractContextManager(unittest.TestCase):
    """Tests for AbstractContextManager's default __enter__ and subclass hook."""

    def test_enter(self):
        # The inherited __enter__ must hand back the manager itself.
        class DefaultEnter(AbstractContextManager):
            def __exit__(self, *args):
                super().__exit__(*args)

        cm = DefaultEnter()
        entered = cm.__enter__()
        self.assertIs(entered, cm)

    def test_exit_is_abstract(self):
        # __exit__ is abstract: a subclass that omits it cannot be created.
        class MissingExit(AbstractContextManager):
            pass

        self.assertRaises(TypeError, MissingExit)

    def test_structural_subclassing(self):
        class ManagerFromScratch:
            def __enter__(self):
                return self
            def __exit__(self, exc_type, exc_value, traceback):
                return None

        class DefaultEnter(AbstractContextManager):
            def __exit__(self, *args):
                super().__exit__(*args)

        class NoEnter(ManagerFromScratch):
            __enter__ = None

        class NoExit(ManagerFromScratch):
            __exit__ = None

        # Defining both protocol methods makes a class a (virtual) subclass...
        for candidate in (ManagerFromScratch, DefaultEnter):
            self.assertTrue(issubclass(candidate, AbstractContextManager))
        # ...while setting either one to None explicitly opts out.
        for candidate in (NoEnter, NoExit):
            self.assertFalse(issubclass(candidate, AbstractContextManager))
58
class ContextManagerTestCase(unittest.TestCase):
    """Tests for the @contextmanager decorator and the managers it builds."""

    def test_contextmanager_plain(self):
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_traceback(self):
        # An exception escaping the with-block should carry a traceback that
        # contains only this test's frame, not contextlib internals.
        @contextmanager
        def f():
            yield

        try:
            with f():
                1/0
        except ZeroDivisionError as e:
            frames = traceback.extract_tb(e.__traceback__)
        else:
            # Without this, a suppressed exception would surface as a
            # confusing NameError on ``frames`` below.
            self.fail('ZeroDivisionError was suppressed')

        self.assertEqual(len(frames), 1)
        self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
        self.assertEqual(frames[0].line, '1/0')

        # Repeat with RuntimeError (which goes through a different code path)
        class RuntimeErrorSubclass(RuntimeError):
            pass

        try:
            with f():
                raise RuntimeErrorSubclass(42)
        except RuntimeErrorSubclass as e:
            frames = traceback.extract_tb(e.__traceback__)
        else:
            self.fail('RuntimeErrorSubclass was suppressed')

        self.assertEqual(len(frames), 1)
        self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
        self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')

        class StopIterationSubclass(StopIteration):
            pass

        for stop_exc in (
            StopIteration('spam'),
            StopIterationSubclass('spam'),
        ):
            with self.subTest(type=type(stop_exc)):
                try:
                    with f():
                        raise stop_exc
                except type(stop_exc) as e:
                    self.assertIs(e, stop_exc)
                    frames = traceback.extract_tb(e.__traceback__)
                else:
                    self.fail(f'{stop_exc} was suppressed')

                self.assertEqual(len(frames), 1)
                self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
                self.assertEqual(frames[0].line, 'raise stop_exc')

    def test_contextmanager_no_reraise(self):
        @contextmanager
        def whee():
            yield
        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields again after the exception is thrown into
        # it is a protocol violation and must be reported as RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except:
                yield
        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_except_stopiter(self):
        @contextmanager
        def woohoo():
            yield

        class StopIterationSubclass(StopIteration):
            pass

        for stop_exc in (StopIteration('spam'), StopIterationSubclass('spam')):
            with self.subTest(type=type(stop_exc)):
                try:
                    with woohoo():
                        raise stop_exc
                except Exception as ex:
                    self.assertIs(ex, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    def test_contextmanager_except_pep479(self):
        code = """\
from __future__ import generator_stop
from contextlib import contextmanager
@contextmanager
def woohoo():
    yield
"""
        # Use a dedicated namespace rather than shadowing the builtin locals().
        namespace = {}
        exec(code, namespace, namespace)
        woohoo = namespace['woohoo']

        stop_exc = StopIteration('spam')
        try:
            with woohoo():
                raise stop_exc
        except Exception as ex:
            self.assertIs(ex, stop_exc)
        else:
            self.fail('StopIteration was suppressed')

    def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(self):
        @contextmanager
        def test_issue29692():
            try:
                yield
            except Exception as exc:
                raise RuntimeError('issue29692:Chained') from exc
        try:
            with test_issue29692():
                raise ZeroDivisionError
        except Exception as ex:
            self.assertIs(type(ex), RuntimeError)
            self.assertEqual(ex.args[0], 'issue29692:Chained')
            self.assertIsInstance(ex.__cause__, ZeroDivisionError)
        else:
            self.fail('Expected RuntimeError, but no exception was raised')

        try:
            with test_issue29692():
                raise StopIteration('issue29692:Unchained')
        except Exception as ex:
            self.assertIs(type(ex), StopIteration)
            self.assertEqual(ex.args[0], 'issue29692:Unchained')
            self.assertIsNone(ex.__cause__)
        else:
            self.fail('Expected StopIteration, but no exception was raised')

    def _create_contextmanager_attribs(self):
        # Build a @contextmanager function whose underlying generator carries
        # a custom attribute and a docstring, for metadata-propagation tests.
        def attribs(**kw):
            def decorate(func):
                for k,v in kw.items():
                    setattr(func,k,v)
                return func
            return decorate
        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        return baz

    def test_contextmanager_attribs(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__,'baz')
        self.assertEqual(baz.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    @support.requires_docstrings
    def test_instance_docstring_given_cm_docstring(self):
        baz = self._create_contextmanager_attribs()(None)
        self.assertEqual(baz.__doc__, "Whee!")

    def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        @contextmanager
        def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)
        with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))

    def test_nokeepref(self):
        class A:
            pass

        @contextmanager
        def woohoo(a, b):
            a = weakref.ref(a)
            b = weakref.ref(b)
            # Allow test to work with a non-refcounted GC
            support.gc_collect()
            self.assertIsNone(a())
            self.assertIsNone(b())
            yield

        with woohoo(A(), b=A()):
            pass

    def test_param_errors(self):
        @contextmanager
        def woohoo(a, *, b):
            yield

        with self.assertRaises(TypeError):
            woohoo()
        with self.assertRaises(TypeError):
            woohoo(3, 5)
        with self.assertRaises(TypeError):
            woohoo(b=3)

    def test_recursive(self):
        depth = 0
        @contextmanager
        def woohoo():
            nonlocal depth
            before = depth
            depth += 1
            yield
            depth -= 1
            self.assertEqual(depth, before)

        @woohoo()
        def recursive():
            if depth < 10:
                recursive()

        recursive()
        self.assertEqual(depth, 0)
325
326
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing."""

    @staticmethod
    def _closeable(log):
        """Return a fresh object whose close() appends 1 to *log*."""
        class C:
            def close(self):
                log.append(1)
        return C()

    @support.requires_docstrings
    def test_instance_docs(self):
        # Issue 19330: ensure context manager instances have good docstrings
        expected_doc = closing.__doc__
        instance = closing(None)
        self.assertEqual(instance.__doc__, expected_doc)

    def test_closing(self):
        closed = []
        target = self._closeable(closed)
        self.assertEqual(closed, [])
        with closing(target) as bound:
            self.assertEqual(target, bound)
        self.assertEqual(closed, [1])

    def test_closing_error(self):
        # close() must still run when the body raises.
        closed = []
        target = self._closeable(closed)
        self.assertEqual(closed, [])
        with self.assertRaises(ZeroDivisionError):
            with closing(target) as bound:
                self.assertEqual(target, bound)
                1 / 0
        self.assertEqual(closed, [1])
359
360
class NullcontextTestCase(unittest.TestCase):
    """Tests for contextlib.nullcontext."""

    def test_nullcontext(self):
        class C:
            pass
        c = C()
        # The enter_result given to the constructor is what ``as`` binds.
        with nullcontext(c) as c_in:
            self.assertIs(c_in, c)

    def test_nullcontext_default(self):
        # With no argument, enter_result defaults to None.
        with nullcontext() as c_in:
            self.assertIsNone(c_in)

    def test_nullcontext_does_not_suppress(self):
        # nullcontext is a no-op: exceptions from the body must propagate.
        with self.assertRaises(ZeroDivisionError):
            with nullcontext():
                1 / 0
368
369
370class FileContextTestCase(unittest.TestCase):
371
372    def testWithOpen(self):
373        tfn = tempfile.mktemp()
374        try:
375            f = None
376            with open(tfn, "w", encoding="utf-8") as f:
377                self.assertFalse(f.closed)
378                f.write("Booh\n")
379            self.assertTrue(f.closed)
380            f = None
381            with self.assertRaises(ZeroDivisionError):
382                with open(tfn, "r", encoding="utf-8") as f:
383                    self.assertFalse(f.closed)
384                    self.assertEqual(f.read(), "Booh\n")
385                    1 / 0
386            self.assertTrue(f.closed)
387        finally:
388            os_helper.unlink(tfn)
389
class LockContextTestCase(unittest.TestCase):
    """Tests that threading synchronization primitives work with ``with``."""

    def boilerPlate(self, lock, locked):
        """Check that *lock* is held inside ``with`` and released afterwards.

        *locked* is a zero-argument callable reporting whether the lock is
        currently held.  Release is checked both after a normal exit and
        after an exit caused by an exception.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        with self.assertRaises(ZeroDivisionError):
            with lock:
                self.assertTrue(locked())
                1 / 0
        self.assertFalse(locked())

    @staticmethod
    def _semaphore_locked(sem):
        """Return a callable reporting whether *sem* has no free slot.

        Semaphores expose no ``locked()`` query, so probe with a
        non-blocking acquire and undo it immediately if it succeeds.
        """
        def locked():
            if sem.acquire(False):
                sem.release()
                return False
            return True
        return locked

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        def locked():
            return lock._is_owned()
        self.boilerPlate(lock, locked)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        self.boilerPlate(lock, self._semaphore_locked(lock))

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        self.boilerPlate(lock, self._semaphore_locked(lock))
436
437
class mycontext(ContextDecorator):
    """Example decoration-compatible context manager for testing"""
    # Class-level defaults; instances shadow them once they are used.
    started = False  # flipped to True by __enter__
    exc = None       # tuple of the arguments received by the last __exit__
    catch = False    # returned from __exit__, i.e. whether to suppress

    def __enter__(self):
        self.started = True
        return self

    def __exit__(self, *exc_info):
        self.exc = exc_info
        return self.catch
451
452
class TestContextDecorator(unittest.TestCase):
    """Tests for ContextDecorator in both ``with`` and decorator form."""

    @support.requires_docstrings
    def test_instance_docs(self):
        # Issue 19330: ensure context manager instances have good docstrings
        cm_docstring = mycontext.__doc__
        obj = mycontext()
        self.assertEqual(obj.__doc__, cm_docstring)

    def test_contextdecorator(self):
        context = mycontext()
        with context as result:
            self.assertIs(result, context)
            self.assertTrue(context.started)

        self.assertEqual(context.exc, (None, None, None))

    def test_contextdecorator_with_exception(self):
        context = mycontext()

        with self.assertRaisesRegex(NameError, 'foo'):
            with context:
                raise NameError('foo')
        self.assertIsNotNone(context.exc)
        self.assertIs(context.exc[0], NameError)

        # With catch=True, __exit__ returns True and swallows the error.
        context = mycontext()
        context.catch = True
        with context:
            raise NameError('foo')
        self.assertIsNotNone(context.exc)
        self.assertIs(context.exc[0], NameError)

    def test_decorator(self):
        context = mycontext()

        @context
        def test():
            self.assertIsNone(context.exc)
            self.assertTrue(context.started)
        test()
        self.assertEqual(context.exc, (None, None, None))

    def test_decorator_with_exception(self):
        context = mycontext()

        @context
        def test():
            self.assertIsNone(context.exc)
            self.assertTrue(context.started)
            raise NameError('foo')

        with self.assertRaisesRegex(NameError, 'foo'):
            test()
        self.assertIsNotNone(context.exc)
        self.assertIs(context.exc[0], NameError)

    def test_decorating_method(self):
        context = mycontext()

        class Test:

            @context
            def method(self, a, b, c=None):
                self.a = a
                self.b = b
                self.c = c

        # these tests are for argument passing when used as a decorator
        test = Test()
        test.method(1, 2)
        self.assertEqual(test.a, 1)
        self.assertEqual(test.b, 2)
        self.assertEqual(test.c, None)

        test = Test()
        test.method('a', 'b', 'c')
        self.assertEqual(test.a, 'a')
        self.assertEqual(test.b, 'b')
        self.assertEqual(test.c, 'c')

        # Keyword-only call: the omitted default must still come through.
        test = Test()
        test.method(a=1, b=2)
        self.assertEqual(test.a, 1)
        self.assertEqual(test.b, 2)
        self.assertEqual(test.c, None)

    def test_typo_enter(self):
        class mycontext(ContextDecorator):
            def __unter__(self):
                pass
            def __exit__(self, *exc):
                pass

        with self.assertRaisesRegex(TypeError, 'the context manager'):
            with mycontext():
                pass

    def test_typo_exit(self):
        class mycontext(ContextDecorator):
            def __enter__(self):
                pass
            def __uxit__(self, *exc):
                pass

        with self.assertRaisesRegex(TypeError, 'the context manager.*__exit__'):
            with mycontext():
                pass

    def test_contextdecorator_as_mixin(self):
        class somecontext:
            started = False
            exc = None

            def __enter__(self):
                self.started = True
                return self

            def __exit__(self, *exc):
                self.exc = exc

        class mycontext(somecontext, ContextDecorator):
            pass

        context = mycontext()
        @context
        def test():
            self.assertIsNone(context.exc)
            self.assertTrue(context.started)
        test()
        self.assertEqual(context.exc, (None, None, None))

    def test_contextmanager_as_decorator(self):
        @contextmanager
        def woohoo(y):
            state.append(y)
            yield
            state.append(999)

        state = []
        @woohoo(1)
        def test(x):
            self.assertEqual(state, [1])
            state.append(x)
        test('something')
        self.assertEqual(state, [1, 'something', 999])

        # Issue #11647: Ensure the decorated function is 'reusable'
        state = []
        test('something else')
        self.assertEqual(state, [1, 'something else', 999])
611
612
613class TestBaseExitStack:
    # The ExitStack-like class under test.  This mixin calls
    # ``self.exit_stack()`` everywhere, so concrete subclasses (not visible
    # here) presumably override this; left as None it cannot be instantiated.
    exit_stack = None

    @support.requires_docstrings
    def test_instance_docs(self):
        # Issue 19330: ensure context manager instances have good docstrings
        cm_docstring = self.exit_stack.__doc__
        obj = self.exit_stack()
        self.assertEqual(obj.__doc__, cm_docstring)
622
623    def test_no_resources(self):
624        with self.exit_stack():
625            pass
626
627    def test_callback(self):
628        expected = [
629            ((), {}),
630            ((1,), {}),
631            ((1,2), {}),
632            ((), dict(example=1)),
633            ((1,), dict(example=1)),
634            ((1,2), dict(example=1)),
635            ((1,2), dict(self=3, callback=4)),
636        ]
637        result = []
638        def _exit(*args, **kwds):
639            """Test metadata propagation"""
640            result.append((args, kwds))
641        with self.exit_stack() as stack:
642            for args, kwds in reversed(expected):
643                if args and kwds:
644                    f = stack.callback(_exit, *args, **kwds)
645                elif args:
646                    f = stack.callback(_exit, *args)
647                elif kwds:
648                    f = stack.callback(_exit, **kwds)
649                else:
650                    f = stack.callback(_exit)
651                self.assertIs(f, _exit)
652            for wrapper in stack._exit_callbacks:
653                self.assertIs(wrapper[1].__wrapped__, _exit)
654                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
655                self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)
656        self.assertEqual(result, expected)
657
658        result = []
659        with self.exit_stack() as stack:
660            with self.assertRaises(TypeError):
661                stack.callback(arg=1)
662            with self.assertRaises(TypeError):
663                self.exit_stack.callback(arg=2)
664            with self.assertRaises(TypeError):
665                stack.callback(callback=_exit, arg=3)
666        self.assertEqual(result, [])
667
668    def test_push(self):
669        exc_raised = ZeroDivisionError
670        def _expect_exc(exc_type, exc, exc_tb):
671            self.assertIs(exc_type, exc_raised)
672        def _suppress_exc(*exc_details):
673            return True
674        def _expect_ok(exc_type, exc, exc_tb):
675            self.assertIsNone(exc_type)
676            self.assertIsNone(exc)
677            self.assertIsNone(exc_tb)
678        class ExitCM(object):
679            def __init__(self, check_exc):
680                self.check_exc = check_exc
681            def __enter__(self):
682                self.fail("Should not be called!")
683            def __exit__(self, *exc_details):
684                self.check_exc(*exc_details)
685        with self.exit_stack() as stack:
686            stack.push(_expect_ok)
687            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
688            cm = ExitCM(_expect_ok)
689            stack.push(cm)
690            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
691            stack.push(_suppress_exc)
692            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
693            cm = ExitCM(_expect_exc)
694            stack.push(cm)
695            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
696            stack.push(_expect_exc)
697            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
698            stack.push(_expect_exc)
699            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
700            1/0
701
702    def test_enter_context(self):
703        class TestCM(object):
704            def __enter__(self):
705                result.append(1)
706            def __exit__(self, *exc_details):
707                result.append(3)
708
709        result = []
710        cm = TestCM()
711        with self.exit_stack() as stack:
712            @stack.callback  # Registered first => cleaned up last
713            def _exit():
714                result.append(4)
715            self.assertIsNotNone(_exit)
716            stack.enter_context(cm)
717            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
718            result.append(2)
719        self.assertEqual(result, [1, 2, 3, 4])
720
721    def test_enter_context_errors(self):
722        class LacksEnterAndExit:
723            pass
724        class LacksEnter:
725            def __exit__(self, *exc_info):
726                pass
727        class LacksExit:
728            def __enter__(self):
729                pass
730
731        with self.exit_stack() as stack:
732            with self.assertRaisesRegex(TypeError, 'the context manager'):
733                stack.enter_context(LacksEnterAndExit())
734            with self.assertRaisesRegex(TypeError, 'the context manager'):
735                stack.enter_context(LacksEnter())
736            with self.assertRaisesRegex(TypeError, 'the context manager'):
737                stack.enter_context(LacksExit())
738            self.assertFalse(stack._exit_callbacks)
739
740    def test_close(self):
741        result = []
742        with self.exit_stack() as stack:
743            @stack.callback
744            def _exit():
745                result.append(1)
746            self.assertIsNotNone(_exit)
747            stack.close()
748            result.append(2)
749        self.assertEqual(result, [1, 2])
750
751    def test_pop_all(self):
752        result = []
753        with self.exit_stack() as stack:
754            @stack.callback
755            def _exit():
756                result.append(3)
757            self.assertIsNotNone(_exit)
758            new_stack = stack.pop_all()
759            result.append(1)
760        result.append(2)
761        new_stack.close()
762        self.assertEqual(result, [1, 2, 3])
763
764    def test_exit_raise(self):
765        with self.assertRaises(ZeroDivisionError):
766            with self.exit_stack() as stack:
767                stack.push(lambda *exc: False)
768                1/0
769
770    def test_exit_suppress(self):
771        with self.exit_stack() as stack:
772            stack.push(lambda *exc: True)
773            1/0
774
    def test_exit_exception_traceback(self):
        # This test captures the current behavior of ExitStack so that we know
        # if we ever unintendedly change it. It is not a statement of what the
        # desired behavior is (for instance, we may want to remove some of the
        # internal contextlib frames).

        def raise_exc(exc):
            # Registered as a callback, so it raises while the stack unwinds.
            raise exc

        try:
            with self.exit_stack() as stack:
                stack.callback(raise_exc, ValueError)
                1/0
        except ValueError as e:
            # The callback's ValueError replaces the ZeroDivisionError.
            exc = e

        self.assertIsInstance(exc, ValueError)
        ve_frames = traceback.extract_tb(exc.__traceback__)
        # Expected (function name, source line) pairs, outermost first.  The
        # comparison uses literal source text, so the ``with`` line above must
        # not be reworded.  callback_error_internal_frames is not defined in
        # this mixin; presumably each concrete subclass supplies it — confirm.
        expected = \
            [('test_exit_exception_traceback', 'with self.exit_stack() as stack:')] + \
            self.callback_error_internal_frames + \
            [('_exit_wrapper', 'callback(*args, **kwds)'),
             ('raise_exc', 'raise exc')]

        self.assertEqual(
            [(f.name, f.line) for f in ve_frames], expected)

        # The original ZeroDivisionError survives as the implicit __context__,
        # and its traceback points only at the line that raised it.
        self.assertIsInstance(exc.__context__, ZeroDivisionError)
        zde_frames = traceback.extract_tb(exc.__context__.__traceback__)
        self.assertEqual([(f.name, f.line) for f in zde_frames],
                         [('test_exit_exception_traceback', '1/0')])
806
    def test_exit_exception_chaining_reference(self):
        # Sanity check to make sure that ExitStack chaining matches
        # actual nested with statements
        class RaiseExc:
            # Raises the given exception from __exit__.
            def __init__(self, exc):
                self.exc = exc
            def __enter__(self):
                return self
            def __exit__(self, *exc_details):
                raise self.exc

        class RaiseExcWithContext:
            # Raises ``inner``, then ``outer`` while handling it, so ``outer``
            # gets ``inner`` as its implicit __context__.
            def __init__(self, outer, inner):
                self.outer = outer
                self.inner = inner
            def __enter__(self):
                return self
            def __exit__(self, *exc_details):
                try:
                    raise self.inner
                except:
                    raise self.outer

        class SuppressExc:
            # Records what reached it on the class (so it can be inspected
            # after the with statements finish), then suppresses it.
            def __enter__(self):
                return self
            def __exit__(self, *exc_details):
                type(self).saved_details = exc_details
                return True

        # Unwinding order (innermost first): ValueError replaces the
        # ZeroDivisionError, SuppressExc swallows it, then
        # RaiseExcWithContext and RaiseExc raise onto a clean slate.
        try:
            with RaiseExc(IndexError):
                with RaiseExcWithContext(KeyError, AttributeError):
                    with SuppressExc():
                        with RaiseExc(ValueError):
                            1 / 0
        except IndexError as exc:
            self.assertIsInstance(exc.__context__, KeyError)
            self.assertIsInstance(exc.__context__.__context__, AttributeError)
            # Inner exceptions were suppressed
            self.assertIsNone(exc.__context__.__context__.__context__)
        else:
            self.fail("Expected IndexError, but no exception was raised")
        # Check the inner exceptions
        inner_exc = SuppressExc.saved_details[1]
        self.assertIsInstance(inner_exc, ValueError)
        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
854
855    def test_exit_exception_chaining(self):
856        # Ensure exception chaining matches the reference behaviour
857        def raise_exc(exc):
858            raise exc
859
860        saved_details = None
861        def suppress_exc(*exc_details):
862            nonlocal saved_details
863            saved_details = exc_details
864            return True
865
866        try:
867            with self.exit_stack() as stack:
868                stack.callback(raise_exc, IndexError)
869                stack.callback(raise_exc, KeyError)
870                stack.callback(raise_exc, AttributeError)
871                stack.push(suppress_exc)
872                stack.callback(raise_exc, ValueError)
873                1 / 0
874        except IndexError as exc:
875            self.assertIsInstance(exc.__context__, KeyError)
876            self.assertIsInstance(exc.__context__.__context__, AttributeError)
877            # Inner exceptions were suppressed
878            self.assertIsNone(exc.__context__.__context__.__context__)
879        else:
880            self.fail("Expected IndexError, but no exception was raised")
881        # Check the inner exceptions
882        inner_exc = saved_details[1]
883        self.assertIsInstance(inner_exc, ValueError)
884        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
885
886    def test_exit_exception_explicit_none_context(self):
887        # Ensure ExitStack chaining matches actual nested `with` statements
888        # regarding explicit __context__ = None.
889
890        class MyException(Exception):
891            pass
892
893        @contextmanager
894        def my_cm():
895            try:
896                yield
897            except BaseException:
898                exc = MyException()
899                try:
900                    raise exc
901                finally:
902                    exc.__context__ = None
903
904        @contextmanager
905        def my_cm_with_exit_stack():
906            with self.exit_stack() as stack:
907                stack.enter_context(my_cm())
908                yield stack
909
910        for cm in (my_cm, my_cm_with_exit_stack):
911            with self.subTest():
912                try:
913                    with cm():
914                        raise IndexError()
915                except MyException as exc:
916                    self.assertIsNone(exc.__context__)
917                else:
918                    self.fail("Expected IndexError, but no exception was raised")
919
920    def test_exit_exception_non_suppressing(self):
921        # http://bugs.python.org/issue19092
922        def raise_exc(exc):
923            raise exc
924
925        def suppress_exc(*exc_details):
926            return True
927
928        try:
929            with self.exit_stack() as stack:
930                stack.callback(lambda: None)
931                stack.callback(raise_exc, IndexError)
932        except Exception as exc:
933            self.assertIsInstance(exc, IndexError)
934        else:
935            self.fail("Expected IndexError, but no exception was raised")
936
937        try:
938            with self.exit_stack() as stack:
939                stack.callback(raise_exc, KeyError)
940                stack.push(suppress_exc)
941                stack.callback(raise_exc, IndexError)
942        except Exception as exc:
943            self.assertIsInstance(exc, KeyError)
944        else:
945            self.fail("Expected KeyError, but no exception was raised")
946
947    def test_exit_exception_with_correct_context(self):
948        # http://bugs.python.org/issue20317
949        @contextmanager
950        def gets_the_context_right(exc):
951            try:
952                yield
953            finally:
954                raise exc
955
956        exc1 = Exception(1)
957        exc2 = Exception(2)
958        exc3 = Exception(3)
959        exc4 = Exception(4)
960
961        # The contextmanager already fixes the context, so prior to the
962        # fix, ExitStack would try to fix it *again* and get into an
963        # infinite self-referential loop
964        try:
965            with self.exit_stack() as stack:
966                stack.enter_context(gets_the_context_right(exc4))
967                stack.enter_context(gets_the_context_right(exc3))
968                stack.enter_context(gets_the_context_right(exc2))
969                raise exc1
970        except Exception as exc:
971            self.assertIs(exc, exc4)
972            self.assertIs(exc.__context__, exc3)
973            self.assertIs(exc.__context__.__context__, exc2)
974            self.assertIs(exc.__context__.__context__.__context__, exc1)
975            self.assertIsNone(
976                       exc.__context__.__context__.__context__.__context__)
977
978    def test_exit_exception_with_existing_context(self):
979        # Addresses a lack of test coverage discovered after checking in a
980        # fix for issue 20317 that still contained debugging code.
981        def raise_nested(inner_exc, outer_exc):
982            try:
983                raise inner_exc
984            finally:
985                raise outer_exc
986        exc1 = Exception(1)
987        exc2 = Exception(2)
988        exc3 = Exception(3)
989        exc4 = Exception(4)
990        exc5 = Exception(5)
991        try:
992            with self.exit_stack() as stack:
993                stack.callback(raise_nested, exc4, exc5)
994                stack.callback(raise_nested, exc2, exc3)
995                raise exc1
996        except Exception as exc:
997            self.assertIs(exc, exc5)
998            self.assertIs(exc.__context__, exc4)
999            self.assertIs(exc.__context__.__context__, exc3)
1000            self.assertIs(exc.__context__.__context__.__context__, exc2)
1001            self.assertIs(
1002                 exc.__context__.__context__.__context__.__context__, exc1)
1003            self.assertIsNone(
1004                exc.__context__.__context__.__context__.__context__.__context__)
1005
1006    def test_body_exception_suppress(self):
1007        def suppress_exc(*exc_details):
1008            return True
1009        try:
1010            with self.exit_stack() as stack:
1011                stack.push(suppress_exc)
1012                1/0
1013        except IndexError as exc:
1014            self.fail("Expected no exception, got IndexError")
1015
1016    def test_exit_exception_chaining_suppress(self):
1017        with self.exit_stack() as stack:
1018            stack.push(lambda *exc: True)
1019            stack.push(lambda *exc: 1/0)
1020            stack.push(lambda *exc: {}[1])
1021
1022    def test_excessive_nesting(self):
1023        # The original implementation would die with RecursionError here
1024        with self.exit_stack() as stack:
1025            for i in range(10000):
1026                stack.callback(int)
1027
1028    def test_instance_bypass(self):
1029        class Example(object): pass
1030        cm = Example()
1031        cm.__enter__ = object()
1032        cm.__exit__ = object()
1033        stack = self.exit_stack()
1034        with self.assertRaisesRegex(TypeError, 'the context manager'):
1035            stack.enter_context(cm)
1036        stack.push(cm)
1037        self.assertIs(stack._exit_callbacks[-1][1], cm)
1038
    def test_dont_reraise_RuntimeError(self):
        # https://bugs.python.org/issue27122
        # Regression test: a RuntimeError subclass raised through two
        # generator-based context managers held by an exit stack must end up
        # chained exactly once. (The message raised below suggests the
        # original bug was an infinite loop.)
        class UniqueException(Exception): pass
        class UniqueRuntimeError(RuntimeError): pass

        @contextmanager
        def second():
            # Entered first, so unwound last: wraps whatever reaches it in a
            # new UniqueException with an explicit cause.
            try:
                yield 1
            except Exception as exc:
                raise UniqueException("new exception") from exc

        @contextmanager
        def first():
            # Entered last, so unwound first: re-raises the same exception
            # object it received.
            try:
                yield 1
            except Exception as exc:
                raise exc

        # The UniqueRuntimeError should be caught by second()'s exception
        # handler, which raises a new UniqueException chained to it.
        with self.assertRaises(UniqueException) as err_ctx:
            with self.exit_stack() as es_ctx:
                es_ctx.enter_context(second())
                es_ctx.enter_context(first())
                raise UniqueRuntimeError("please no infinite loop.")

        exc = err_ctx.exception
        self.assertIsInstance(exc, UniqueException)
        # Both the implicit context and the explicit cause must be the
        # original UniqueRuntimeError, which carries no further chain.
        self.assertIsInstance(exc.__context__, UniqueRuntimeError)
        self.assertIsNone(exc.__context__.__context__)
        self.assertIsNone(exc.__context__.__cause__)
        self.assertIs(exc.__cause__, exc.__context__)
1072
1073
class TestExitStack(TestBaseExitStack, unittest.TestCase):
    """Run the shared exit-stack tests against contextlib.ExitStack."""

    exit_stack = ExitStack

    # (function name, source line) pairs for frames inside ExitStack.__exit__;
    # presumably matched against tracebacks of exceptions raised by exit
    # callbacks in a shared test of the base class -- confirm there.
    callback_error_internal_frames = [
        ('__exit__', 'raise exc_details[1]'),
        ('__exit__', 'if cb(*exc_details):'),
    ]
1080
1081
1082class TestRedirectStream:
1083
1084    redirect_stream = None
1085    orig_stream = None
1086
1087    @support.requires_docstrings
1088    def test_instance_docs(self):
1089        # Issue 19330: ensure context manager instances have good docstrings
1090        cm_docstring = self.redirect_stream.__doc__
1091        obj = self.redirect_stream(None)
1092        self.assertEqual(obj.__doc__, cm_docstring)
1093
1094    def test_no_redirect_in_init(self):
1095        orig_stdout = getattr(sys, self.orig_stream)
1096        self.redirect_stream(None)
1097        self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1098
1099    def test_redirect_to_string_io(self):
1100        f = io.StringIO()
1101        msg = "Consider an API like help(), which prints directly to stdout"
1102        orig_stdout = getattr(sys, self.orig_stream)
1103        with self.redirect_stream(f):
1104            print(msg, file=getattr(sys, self.orig_stream))
1105        self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1106        s = f.getvalue().strip()
1107        self.assertEqual(s, msg)
1108
1109    def test_enter_result_is_target(self):
1110        f = io.StringIO()
1111        with self.redirect_stream(f) as enter_result:
1112            self.assertIs(enter_result, f)
1113
1114    def test_cm_is_reusable(self):
1115        f = io.StringIO()
1116        write_to_f = self.redirect_stream(f)
1117        orig_stdout = getattr(sys, self.orig_stream)
1118        with write_to_f:
1119            print("Hello", end=" ", file=getattr(sys, self.orig_stream))
1120        with write_to_f:
1121            print("World!", file=getattr(sys, self.orig_stream))
1122        self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1123        s = f.getvalue()
1124        self.assertEqual(s, "Hello World!\n")
1125
1126    def test_cm_is_reentrant(self):
1127        f = io.StringIO()
1128        write_to_f = self.redirect_stream(f)
1129        orig_stdout = getattr(sys, self.orig_stream)
1130        with write_to_f:
1131            print("Hello", end=" ", file=getattr(sys, self.orig_stream))
1132            with write_to_f:
1133                print("World!", file=getattr(sys, self.orig_stream))
1134        self.assertIs(getattr(sys, self.orig_stream), orig_stdout)
1135        s = f.getvalue()
1136        self.assertEqual(s, "Hello World!\n")
1137
1138
class TestRedirectStdout(TestRedirectStream, unittest.TestCase):
    """Run the shared redirect tests against contextlib.redirect_stdout."""

    orig_stream = "stdout"
    redirect_stream = redirect_stdout
1143
1144
class TestRedirectStderr(TestRedirectStream, unittest.TestCase):
    """Run the shared redirect tests against contextlib.redirect_stderr."""

    orig_stream = "stderr"
    redirect_stream = redirect_stderr
1149
1150
class TestSuppress(unittest.TestCase):
    """Tests for contextlib.suppress."""

    @support.requires_docstrings
    def test_instance_docs(self):
        # Issue 19330: ensure context manager instances have good docstrings
        self.assertEqual(suppress().__doc__, suppress.__doc__)

    def test_no_result_from_enter(self):
        with suppress(ValueError) as entered:
            self.assertIsNone(entered)

    def test_no_exception(self):
        # Nothing is raised, so the body simply runs to completion.
        with suppress(ValueError):
            self.assertEqual(pow(2, 5), 32)

    def test_exact_exception(self):
        with suppress(TypeError):
            len(5)

    def test_exception_hierarchy(self):
        # IndexError is a LookupError subclass, so it is suppressed too.
        with suppress(LookupError):
            'Hello'[50]

    def test_other_exception(self):
        # Exceptions outside the suppressed set propagate.
        with self.assertRaises(ZeroDivisionError), suppress(TypeError):
            1/0

    def test_no_args(self):
        # With no exception types given, nothing is suppressed.
        with self.assertRaises(ZeroDivisionError), suppress():
            1/0

    def test_multiple_exception_args(self):
        for trigger in (lambda: 1/0, lambda: len(5)):
            with suppress(ZeroDivisionError, TypeError):
                trigger()

    def test_cm_is_reentrant(self):
        swallow = suppress(Exception)
        with swallow:
            pass
        with swallow:
            len(5)
        with swallow:
            with swallow:  # Check nested usage
                len(5)
            # Reached only if the inner block swallowed its exception.
            outer_continued = True
            1/0
        self.assertTrue(outer_continued)
1204
1205
1206class TestChdir(unittest.TestCase):
1207    def make_relative_path(self, *parts):
1208        return os.path.join(
1209            os.path.dirname(os.path.realpath(__file__)),
1210            *parts,
1211        )
1212
1213    def test_simple(self):
1214        old_cwd = os.getcwd()
1215        target = self.make_relative_path('data')
1216        self.assertNotEqual(old_cwd, target)
1217
1218        with chdir(target):
1219            self.assertEqual(os.getcwd(), target)
1220        self.assertEqual(os.getcwd(), old_cwd)
1221
1222    def test_reentrant(self):
1223        old_cwd = os.getcwd()
1224        target1 = self.make_relative_path('data')
1225        target2 = self.make_relative_path('ziptestdata')
1226        self.assertNotIn(old_cwd, (target1, target2))
1227        chdir1, chdir2 = chdir(target1), chdir(target2)
1228
1229        with chdir1:
1230            self.assertEqual(os.getcwd(), target1)
1231            with chdir2:
1232                self.assertEqual(os.getcwd(), target2)
1233                with chdir1:
1234                    self.assertEqual(os.getcwd(), target1)
1235                self.assertEqual(os.getcwd(), target2)
1236            self.assertEqual(os.getcwd(), target1)
1237        self.assertEqual(os.getcwd(), old_cwd)
1238
1239    def test_exception(self):
1240        old_cwd = os.getcwd()
1241        target = self.make_relative_path('data')
1242        self.assertNotEqual(old_cwd, target)
1243
1244        try:
1245            with chdir(target):
1246                self.assertEqual(os.getcwd(), target)
1247                raise RuntimeError("boom")
1248        except RuntimeError as re:
1249            self.assertEqual(str(re), "boom")
1250        self.assertEqual(os.getcwd(), old_cwd)
1251
1252
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
1255