1import asyncio
2from contextlib import (
3    asynccontextmanager, AbstractAsyncContextManager,
4    AsyncExitStack, nullcontext, aclosing, contextmanager)
5import functools
6from test import support
7import unittest
8import traceback
9
10from test.test_contextlib import TestBaseExitStack
11
# Module-level guard run at import time. With module=True this presumably
# skips the whole module when no working socket is available (asyncio event
# loops rely on one). NOTE(review): exact mechanism lives in test.support.
support.requires_working_socket(module=True)
13
14def _async_test(func):
15    """Decorator to turn an async function into a test case."""
16    @functools.wraps(func)
17    def wrapper(*args, **kwargs):
18        coro = func(*args, **kwargs)
19        asyncio.run(coro)
20    return wrapper
21
def tearDownModule():
    """unittest hook run once after all tests in this module.

    Passing None to ``asyncio.set_event_loop_policy()`` restores asyncio's
    default policy, undoing any policy state left behind by the tests.
    """
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
24
25
class TestAbstractAsyncContextManager(unittest.TestCase):
    """Tests for the contextlib.AbstractAsyncContextManager ABC."""

    @_async_test
    async def test_enter(self):
        # Only __aexit__ is overridden: the inherited __aenter__ must hand
        # back the manager itself, whether awaited directly or entered via
        # ``async with``.
        class OnlyExit(AbstractAsyncContextManager):
            async def __aexit__(self, *exc_info):
                await super().__aexit__(*exc_info)

        cm = OnlyExit()
        entered = await cm.__aenter__()
        self.assertIs(entered, cm)

        async with cm as bound:
            self.assertIs(cm, bound)

    @_async_test
    async def test_async_gen_propagates_generator_exit(self):
        # Regression test for https://bugs.python.org/issue33786: an
        # exception raised while iterating an async generator that is
        # suspended inside an asynccontextmanager must propagate unchanged.

        @asynccontextmanager
        async def noop_cm():
            yield

        async def agen():
            async with noop_cm():
                yield 11

        seen = []
        boom = ValueError(22)
        with self.assertRaises(ValueError):
            async with noop_cm():
                async for item in agen():
                    seen.append(item)
                    raise boom

        self.assertEqual(seen, [11])

    def test_exit_is_abstract(self):
        # A subclass that does not implement __aexit__ cannot be instantiated.
        class NoExit(AbstractAsyncContextManager):
            pass

        self.assertRaises(TypeError, NoExit)

    def test_structural_subclassing(self):
        class FromScratch:
            async def __aenter__(self):
                return self
            async def __aexit__(self, exc_type, exc_value, traceback):
                return None

        class InheritsEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *args):
                await super().__aexit__(*args)

        # Providing both dunders (directly, or by inheriting from the ABC)
        # makes a class a subclass of the ABC ...
        for cls in (FromScratch, InheritsEnter):
            self.assertTrue(issubclass(cls, AbstractAsyncContextManager))

        # ... while setting either dunder to None opts a class out.
        class AenterIsNone(FromScratch):
            __aenter__ = None

        class AexitIsNone(FromScratch):
            __aexit__ = None

        for cls in (AenterIsNone, AexitIsNone):
            self.assertFalse(issubclass(cls, AbstractAsyncContextManager))
93
94
class AsyncContextManagerTestCase(unittest.TestCase):
    """Tests for contextlib.asynccontextmanager.

    Covers enter/exit sequencing, exception propagation and traceback shape,
    detection of misbehaving generators (RuntimeError), propagation of
    function attributes/docstrings, and use as a decorator.
    """
    # NOTE: test_contextmanager_traceback compares the source text of the
    # traceback frames it produces, so the lines it raises from must keep
    # their exact text (no trailing comments).

    @_async_test
    async def test_contextmanager_plain(self):
        # Code before ``yield`` runs on entry, code after it on a clean exit;
        # the yielded value is what ``as`` binds.
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_finally(self):
        # An exception raised in the body still runs the generator's
        # ``finally`` clause and then propagates to the caller.
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            async with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_traceback(self):
        # The traceback of an exception raised in the body must contain only
        # this test's own frame, pointing at the raising line (no
        # asynccontextmanager internals). Checked for an ordinary exception,
        # a RuntimeError subclass, and StopIteration/StopAsyncIteration (and
        # subclasses), which additionally must not be suppressed.
        # NOTE(review): if an exception were ever swallowed, ``frames`` would
        # be unbound and the test would error rather than fail cleanly.
        @asynccontextmanager
        async def f():
            yield

        try:
            async with f():
                1/0
        except ZeroDivisionError as e:
            frames = traceback.extract_tb(e.__traceback__)

        self.assertEqual(len(frames), 1)
        self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
        self.assertEqual(frames[0].line, '1/0')

        # Repeat with RuntimeError (which goes through a different code path)
        class RuntimeErrorSubclass(RuntimeError):
            pass

        try:
            async with f():
                raise RuntimeErrorSubclass(42)
        except RuntimeErrorSubclass as e:
            frames = traceback.extract_tb(e.__traceback__)

        self.assertEqual(len(frames), 1)
        self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
        self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')

        class StopIterationSubclass(StopIteration):
            pass

        class StopAsyncIterationSubclass(StopAsyncIteration):
            pass

        for stop_exc in (
            StopIteration('spam'),
            StopAsyncIteration('ham'),
            StopIterationSubclass('spam'),
            StopAsyncIterationSubclass('spam')
        ):
            with self.subTest(type=type(stop_exc)):
                try:
                    async with f():
                        raise stop_exc
                except type(stop_exc) as e:
                    self.assertIs(e, stop_exc)
                    frames = traceback.extract_tb(e.__traceback__)
                else:
                    self.fail(f'{stop_exc} was suppressed')

                self.assertEqual(len(frames), 1)
                self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
                self.assertEqual(frames[0].line, 'raise stop_exc')

    @_async_test
    async def test_contextmanager_no_reraise(self):
        # When the generator lets the thrown exception escape, __aexit__
        # reports it with a falsy return value instead of raising it itself.
        @asynccontextmanager
        async def whee():
            yield
        ctx = whee()
        await ctx.__aenter__()
        # Calling __aexit__ should not result in an exception
        self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))

    @_async_test
    async def test_contextmanager_trap_yield_after_throw(self):
        # Yielding again after an exception is thrown in is a misuse that
        # __aexit__ reports as RuntimeError.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except:
                yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(TypeError, TypeError('foo'), None)

    @_async_test
    async def test_contextmanager_trap_no_yield(self):
        # A generator that never yields cannot be entered: __aenter__ raises
        # RuntimeError.
        @asynccontextmanager
        async def whoo():
            if False:
                yield
        ctx = whoo()
        with self.assertRaises(RuntimeError):
            await ctx.__aenter__()

    @_async_test
    async def test_contextmanager_trap_second_yield(self):
        # A second yield reached on a clean exit is a misuse: __aexit__
        # raises RuntimeError.
        @asynccontextmanager
        async def whoo():
            yield
            yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(None, None, None)

    @_async_test
    async def test_contextmanager_non_normalised(self):
        # __aexit__ is given an exception type with a None value; the
        # generator must still see a RuntimeError, and the SyntaxError it
        # raises in response propagates.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except RuntimeError:
                raise SyntaxError

        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(SyntaxError):
            await ctx.__aexit__(RuntimeError, None, None)

    @_async_test
    async def test_contextmanager_except(self):
        # If the generator catches the body's exception and finishes
        # normally, the exception is suppressed and execution continues
        # after the ``async with``.
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_except_stopiter(self):
        # StopIteration/StopAsyncIteration (and subclasses) raised in the
        # body must propagate as the very same object, not be swallowed by
        # the generator machinery.
        @asynccontextmanager
        async def woohoo():
            yield

        class StopIterationSubclass(StopIteration):
            pass

        class StopAsyncIterationSubclass(StopAsyncIteration):
            pass

        for stop_exc in (
            StopIteration('spam'),
            StopAsyncIteration('ham'),
            StopIterationSubclass('spam'),
            StopAsyncIterationSubclass('spam')
        ):
            with self.subTest(type=type(stop_exc)):
                try:
                    async with woohoo():
                        raise stop_exc
                except Exception as ex:
                    self.assertIs(ex, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    @_async_test
    async def test_contextmanager_wrap_runtimeerror(self):
        # A generator that wraps the body's exception in a RuntimeError
        # propagates that RuntimeError ...
        @asynccontextmanager
        async def woohoo():
            try:
                yield
            except Exception as exc:
                raise RuntimeError(f'caught {exc}') from exc

        with self.assertRaises(RuntimeError):
            async with woohoo():
                1 / 0

        # If the context manager wrapped StopAsyncIteration in a RuntimeError,
        # we also unwrap it, because we can't tell whether the wrapping was
        # done by the generator machinery or by the generator itself.
        with self.assertRaises(StopAsyncIteration):
            async with woohoo():
                raise StopAsyncIteration

    def _create_contextmanager_attribs(self):
        # Helper: return an asynccontextmanager-decorated function whose
        # underlying generator function carries a custom attribute
        # (foo='bar') and the docstring "Whee!".
        def attribs(**kw):
            def decorate(func):
                for k,v in kw.items():
                    setattr(func,k,v)
                return func
            return decorate
        @asynccontextmanager
        @attribs(foo='bar')
        async def baz(spam):
            """Whee!"""
            yield
        return baz

    def test_contextmanager_attribs(self):
        # The decorator preserves __name__ and custom function attributes.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__,'baz')
        self.assertEqual(baz.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        # The decorator preserves the generator function's docstring.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    @support.requires_docstrings
    @_async_test
    async def test_instance_docstring_given_cm_docstring(self):
        # The context manager *instance* exposes that docstring too.
        baz = self._create_contextmanager_attribs()(None)
        self.assertEqual(baz.__doc__, "Whee!")
        async with baz:
            pass  # suppress warning

    @_async_test
    async def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        @asynccontextmanager
        async def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)
        async with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))

    @_async_test
    async def test_recursive(self):
        # Used as a decorator on a recursive coroutine, the context is
        # entered once per call (10 times here) and each exit restores the
        # depth seen on entry.
        depth = 0
        ncols = 0

        @asynccontextmanager
        async def woohoo():
            nonlocal ncols
            ncols += 1

            nonlocal depth
            before = depth
            depth += 1
            yield
            depth -= 1
            self.assertEqual(depth, before)

        @woohoo()
        async def recursive():
            if depth < 10:
                await recursive()

        await recursive()

        self.assertEqual(ncols, 10)
        self.assertEqual(depth, 0)

    @_async_test
    async def test_decorator(self):
        # As a decorator, the context is active only while the decorated
        # coroutine runs.
        entered = False

        @asynccontextmanager
        async def context():
            nonlocal entered
            entered = True
            yield
            entered = False

        @context()
        async def test():
            self.assertTrue(entered)

        self.assertFalse(entered)
        await test()
        self.assertFalse(entered)

    @_async_test
    async def test_decorator_with_exception(self):
        # An exception from the decorated coroutine propagates, and the
        # context's ``finally`` clause still runs.
        entered = False

        @asynccontextmanager
        async def context():
            nonlocal entered
            try:
                entered = True
                yield
            finally:
                entered = False

        @context()
        async def test():
            self.assertTrue(entered)
            raise NameError('foo')

        self.assertFalse(entered)
        with self.assertRaisesRegex(NameError, 'foo'):
            await test()
        self.assertFalse(entered)

    @_async_test
    async def test_decorating_method(self):
        # Decorating a method passes self, positional and keyword arguments
        # through unchanged.

        @asynccontextmanager
        async def context():
            yield


        class Test(object):

            @context()
            async def method(self, a, b, c=None):
                self.a = a
                self.b = b
                self.c = c

        # these tests are for argument passing when used as a decorator
        test = Test()
        await test.method(1, 2)
        self.assertEqual(test.a, 1)
        self.assertEqual(test.b, 2)
        self.assertEqual(test.c, None)

        test = Test()
        await test.method('a', 'b', 'c')
        self.assertEqual(test.a, 'a')
        self.assertEqual(test.b, 'b')
        self.assertEqual(test.c, 'c')

        test = Test()
        await test.method(a=1, b=2)
        self.assertEqual(test.a, 1)
        self.assertEqual(test.b, 2)
451
452
class AclosingTestCase(unittest.TestCase):
    """Tests for contextlib.aclosing, which awaits ``aclose()`` on exit."""

    @support.requires_docstrings
    def test_instance_docs(self):
        # An aclosing instance exposes the class docstring.
        instance = aclosing(None)
        self.assertEqual(instance.__doc__, aclosing.__doc__)

    @_async_test
    async def test_aclosing(self):
        # ``as`` binds the wrapped object; aclose() is awaited on clean exit.
        closed = []

        class Closable:
            async def aclose(self):
                closed.append(1)

        thing = Closable()
        self.assertEqual(closed, [])
        async with aclosing(thing) as target:
            self.assertEqual(thing, target)
        self.assertEqual(closed, [1])

    @_async_test
    async def test_aclosing_error(self):
        # aclose() is still awaited when the body raises, and the exception
        # propagates.
        closed = []

        class Closable:
            async def aclose(self):
                closed.append(1)

        thing = Closable()
        self.assertEqual(closed, [])
        with self.assertRaises(ZeroDivisionError):
            async with aclosing(thing) as target:
                self.assertEqual(thing, target)
                1 / 0
        self.assertEqual(closed, [1])

    @_async_test
    async def test_aclosing_bpo41229(self):
        # bpo-41229: closing a partially consumed async generator must run
        # the cleanup of the synchronous context manager it is suspended in.
        closed = []

        @contextmanager
        def sync_cleanup():
            try:
                yield
            finally:
                closed.append(1)

        async def numbers():
            with sync_cleanup():
                yield -1
                yield -2

        agen = numbers()
        self.assertEqual(closed, [])
        with self.assertRaises(ZeroDivisionError):
            async with aclosing(agen) as target:
                self.assertEqual(agen, target)
                self.assertEqual(-1, await agen.__anext__())
                1 / 0
        self.assertEqual(closed, [1])
511
512
class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
    """AsyncExitStack tests.

    Reuses the shared TestBaseExitStack suite through a synchronous adapter
    (SyncAsyncExitStack) and adds tests for the async-only APIs
    (push_async_callback, push_async_exit, enter_async_context).
    """

    class SyncAsyncExitStack(AsyncExitStack):
        # Synchronous facade over AsyncExitStack so the inherited,
        # synchronous TestBaseExitStack tests can drive it.
        # NOTE: callback_error_internal_frames (below) pins the exact source
        # text of lines in run_coroutine() and __exit__(); do not edit those
        # lines, not even to add trailing comments.
        @staticmethod
        def run_coroutine(coro):
            # Drive *coro* to completion on the current event loop, then
            # return its result or re-raise its exception.
            loop = asyncio.get_event_loop_policy().get_event_loop()
            t = loop.create_task(coro)
            t.add_done_callback(lambda f: loop.stop())
            loop.run_forever()

            exc = t.exception()
            if not exc:
                return t.result()
            else:
                # Raising here may implicitly chain *exc* to an exception
                # that is currently being handled; restore the __context__
                # it had when the task finished before re-raising.
                context = exc.__context__

                try:
                    raise exc
                except:
                    exc.__context__ = context
                    raise exc

        def close(self):
            # Synchronous counterpart of aclose().
            return self.run_coroutine(self.aclose())

        def __enter__(self):
            return self.run_coroutine(self.__aenter__())

        def __exit__(self, *exc_details):
            return self.run_coroutine(self.__aexit__(*exc_details))

    exit_stack = SyncAsyncExitStack
    # Expected (function name, source line) pairs for the internal frames
    # between the test and a failing exit callback. The '__aexit__' entries
    # presumably come from contextlib.AsyncExitStack; this list is consumed by
    # TestBaseExitStack (not visible here) — confirm there.
    callback_error_internal_frames = [
        ('__exit__', 'return self.run_coroutine(self.__aexit__(*exc_details))'),
        ('run_coroutine', 'raise exc'),
        ('run_coroutine', 'raise exc'),
        ('__aexit__', 'raise exc_details[1]'),
        ('__aexit__', 'cb_suppress = cb(*exc_details)'),
    ]

    def setUp(self):
        # The synchronous adapter runs coroutines on the current event loop,
        # so give each test a fresh one. Cleanups run LIFO: reset the policy,
        # then close the loop.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.addCleanup(self.loop.close)
        self.addCleanup(asyncio.set_event_loop_policy, None)

    @_async_test
    async def test_async_callback(self):
        # push_async_callback() forwards positional/keyword arguments,
        # returns the original callback, and callbacks run in LIFO order.
        expected = [
            ((), {}),
            ((1,), {}),
            ((1,2), {}),
            ((), dict(example=1)),
            ((1,), dict(example=1)),
            ((1,2), dict(example=1)),
        ]
        result = []
        async def _exit(*args, **kwds):
            """Test metadata propagation"""
            result.append((args, kwds))

        async with AsyncExitStack() as stack:
            for args, kwds in reversed(expected):
                if args and kwds:
                    f = stack.push_async_callback(_exit, *args, **kwds)
                elif args:
                    f = stack.push_async_callback(_exit, *args)
                elif kwds:
                    f = stack.push_async_callback(_exit, **kwds)
                else:
                    f = stack.push_async_callback(_exit)
                self.assertIs(f, _exit)
            # The registered wrapper exposes the callback via __wrapped__ but
            # does not copy its name or docstring.
            for wrapper in stack._exit_callbacks:
                self.assertIs(wrapper[1].__wrapped__, _exit)
                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
                self.assertIsNone(wrapper[1].__doc__)

        self.assertEqual(result, expected)

        # A missing callback (including calling the method unbound on the
        # class) or passing it as callback= is a TypeError, and nothing runs.
        result = []
        async with AsyncExitStack() as stack:
            with self.assertRaises(TypeError):
                stack.push_async_callback(arg=1)
            with self.assertRaises(TypeError):
                self.exit_stack.push_async_callback(arg=2)
            with self.assertRaises(TypeError):
                stack.push_async_callback(callback=_exit, arg=3)
        self.assertEqual(result, [])

    @_async_test
    async def test_async_push(self):
        # push_async_exit() accepts coroutine functions and objects with
        # __aexit__ (without entering them). Exit callbacks run LIFO and see
        # the exception until _suppress_exc swallows it; callbacks pushed
        # before it then see no exception.
        exc_raised = ZeroDivisionError
        async def _expect_exc(exc_type, exc, exc_tb):
            self.assertIs(exc_type, exc_raised)
        async def _suppress_exc(*exc_details):
            return True
        async def _expect_ok(exc_type, exc, exc_tb):
            self.assertIsNone(exc_type)
            self.assertIsNone(exc)
            self.assertIsNone(exc_tb)
        class ExitCM(object):
            def __init__(self, check_exc):
                self.check_exc = check_exc
            async def __aenter__(self):
                # ExitCM is not a TestCase, so self.fail() is unavailable.
                raise AssertionError("Should not be called!")
            async def __aexit__(self, *exc_details):
                await self.check_exc(*exc_details)

        async with self.exit_stack() as stack:
            stack.push_async_exit(_expect_ok)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
            cm = ExitCM(_expect_ok)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_suppress_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
            cm = ExitCM(_expect_exc)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            1/0

    @_async_test
    async def test_enter_async_context(self):
        # enter_async_context() awaits __aenter__, registers the bound
        # __aexit__, and cleanups run in reverse registration order.
        class TestCM(object):
            async def __aenter__(self):
                result.append(1)
            async def __aexit__(self, *exc_details):
                result.append(3)

        result = []
        cm = TestCM()

        async with AsyncExitStack() as stack:
            @stack.push_async_callback  # Registered first => cleaned up last
            async def _exit():
                result.append(4)
            self.assertIsNotNone(_exit)
            await stack.enter_async_context(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            result.append(2)

        self.assertEqual(result, [1, 2, 3, 4])

    @_async_test
    async def test_enter_async_context_errors(self):
        # Objects lacking __aenter__ and/or __aexit__ are rejected with
        # TypeError and nothing is registered.
        class LacksEnterAndExit:
            pass
        class LacksEnter:
            async def __aexit__(self, *exc_info):
                pass
        class LacksExit:
            async def __aenter__(self):
                pass

        async with self.exit_stack() as stack:
            with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
                await stack.enter_async_context(LacksEnterAndExit())
            with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
                await stack.enter_async_context(LacksEnter())
            with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
                await stack.enter_async_context(LacksExit())
            self.assertFalse(stack._exit_callbacks)

    @_async_test
    async def test_async_exit_exception_chaining(self):
        # Ensure exception chaining matches the reference behaviour
        async def raise_exc(exc):
            raise exc

        saved_details = None
        async def suppress_exc(*exc_details):
            nonlocal saved_details
            saved_details = exc_details
            return True

        try:
            async with self.exit_stack() as stack:
                stack.push_async_callback(raise_exc, IndexError)
                stack.push_async_callback(raise_exc, KeyError)
                stack.push_async_callback(raise_exc, AttributeError)
                stack.push_async_exit(suppress_exc)
                stack.push_async_callback(raise_exc, ValueError)
                1 / 0
        except IndexError as exc:
            self.assertIsInstance(exc.__context__, KeyError)
            self.assertIsInstance(exc.__context__.__context__, AttributeError)
            # Inner exceptions were suppressed
            self.assertIsNone(exc.__context__.__context__.__context__)
        else:
            self.fail("Expected IndexError, but no exception was raised")
        # Check the inner exceptions
        inner_exc = saved_details[1]
        self.assertIsInstance(inner_exc, ValueError)
        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)

    @_async_test
    async def test_async_exit_exception_explicit_none_context(self):
        # Ensure AsyncExitStack chaining matches actual nested `with` statements
        # regarding explicit __context__ = None.

        class MyException(Exception):
            pass

        @asynccontextmanager
        async def my_cm():
            try:
                yield
            except BaseException:
                exc = MyException()
                try:
                    raise exc
                finally:
                    exc.__context__ = None

        @asynccontextmanager
        async def my_cm_with_exit_stack():
            async with self.exit_stack() as stack:
                await stack.enter_async_context(my_cm())
                yield stack

        for cm in (my_cm, my_cm_with_exit_stack):
            # Label the subtest so a failure says which variant broke.
            with self.subTest(cm=cm.__name__):
                try:
                    async with cm():
                        raise IndexError()
                except MyException as exc:
                    self.assertIsNone(exc.__context__)
                else:
                    self.fail("Expected MyException, but no exception was raised")

    @_async_test
    async def test_instance_bypass_async(self):
        # Instance attributes named __aenter__/__aexit__ do not make an
        # object an async context manager for enter_async_context();
        # push_async_exit() then stores the object itself as the callback.
        class Example(object): pass
        cm = Example()
        cm.__aenter__ = object()
        cm.__aexit__ = object()
        stack = self.exit_stack()
        with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
            await stack.enter_async_context(cm)
        stack.push_async_exit(cm)
        self.assertIs(stack._exit_callbacks[-1][1], cm)
757
758
class TestAsyncNullcontext(unittest.TestCase):
    """nullcontext also works as an asynchronous context manager."""

    @_async_test
    async def test_async_nullcontext(self):
        # ``async with nullcontext(obj)`` binds obj itself, unchanged.
        class Dummy:
            pass

        sentinel = Dummy()
        async with nullcontext(sentinel) as bound:
            self.assertIs(bound, sentinel)
767
768
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
771