1import inspect
2import types
3import unittest
4import contextlib
5
6from test.support.import_helper import import_module
7from test.support import gc_collect, requires_working_socket
8asyncio = import_module("asyncio")
9
10
# Skip this whole test module (module=True) when the platform lacks working
# sockets -- presumably because the asyncio event loops used below need them.
requires_working_socket(module=True)

# Sentinel for py_anext(): lets "no default given" be told apart from an
# explicit default such as None.
_no_default = object()
14
15
class AwaitException(Exception):
    """Marker exception that run_until_complete() throws into a coroutine."""
18
19
@types.coroutine
def awaitable(*, throw=False):
    """Suspend exactly once, handing a marker tuple to the driver.

    The marker is ``('throw',)`` when *throw* is true and ``('result',)``
    otherwise.  run_until_complete() reacts to ``('throw',)`` by throwing
    AwaitException back into the awaiting coroutine.
    """
    marker = ('throw',) if throw else ('result',)
    yield marker
26
27
def run_until_complete(coro):
    """Drive *coro* to completion without an event loop and return its result.

    After every suspension the coroutine is resumed with ``send(None)``.
    When it suspends with the marker ``('throw',)`` (see
    ``awaitable(throw=True)``), the next resumption instead throws
    AwaitException into it.  Exceptions other than StopIteration propagate
    to the caller.
    """
    throw_next = False
    while True:
        try:
            if throw_next:
                throw_next = False
                yielded = coro.throw(AwaitException)
            else:
                yielded = coro.send(None)
        except StopIteration as ex:
            # Use ``.value`` rather than ``.args[0]``: a coroutine that
            # returns None finishes with a bare StopIteration whose args are
            # empty, which made ``args[0]`` raise IndexError.
            return ex.value

        if yielded == ('throw',):
            throw_next = True
42
43
def to_list(gen):
    """Collect every item of the async iterable *gen* into a list.

    The collection coroutine is driven by run_until_complete(), so no event
    loop is involved.
    """
    async def collect():
        return [item async for item in gen]

    return run_until_complete(collect())
52
53
def py_anext(iterator, default=_no_default):
    """Pure-Python implementation of anext() for testing purposes.

    Closely matches the builtin anext() C implementation.  It can be used to
    cross-check the builtin coroutine machinery against the C implementation
    of __anext__() and of send()/throw() on the returned awaitable.

    Without *default*, return ``type(iterator).__anext__(iterator)``
    unchanged.  With *default*, return a coroutine that awaits that result
    and returns *default* instead of propagating StopAsyncIteration.

    Raises TypeError if the type of *iterator* has no ``__anext__``.
    """

    try:
        __anext__ = type(iterator).__anext__
    except AttributeError:
        # ``from None`` drops the irrelevant AttributeError context, matching
        # the bare TypeError raised by the builtin anext().
        raise TypeError(f'{iterator!r} is not an async iterator') from None

    if default is _no_default:
        return __anext__(iterator)

    async def anext_impl():
        try:
            # The C code is way more low-level than this, as it implements
            # all methods of the iterator protocol. In this implementation
            # we're relying on higher-level coroutine concepts, but that's
            # exactly what we want -- crosstest pure-Python high-level
            # implementation and low-level C anext() iterators.
            return await __anext__(iterator)
        except StopAsyncIteration:
            return default

    return anext_impl()
83
84
class AsyncGenSyntaxTest(unittest.TestCase):
    """Compile-time restrictions on ``async def`` functions containing yield."""

    def _assert_syntax_error(self, source, pattern):
        # Executing *source* must fail with a SyntaxError matching *pattern*.
        with self.assertRaisesRegex(SyntaxError, pattern):
            exec(source, {}, {})

    def test_async_gen_syntax_01(self):
        source = '''async def foo():
            await abc
            yield from 123
        '''
        self._assert_syntax_error(source, 'yield from.*inside async')

    def test_async_gen_syntax_02(self):
        source = '''async def foo():
            yield from 123
        '''
        self._assert_syntax_error(source, 'yield from.*inside async')

    def test_async_gen_syntax_03(self):
        source = '''async def foo():
            await abc
            yield
            return 123
        '''
        self._assert_syntax_error(source, 'return.*value.*async gen')

    def test_async_gen_syntax_04(self):
        source = '''async def foo():
            yield
            return 123
        '''
        self._assert_syntax_error(source, 'return.*value.*async gen')

    def test_async_gen_syntax_05(self):
        source = '''async def foo():
            if 0:
                yield
            return 12
        '''
        self._assert_syntax_error(source, 'return.*value.*async gen')
132
133
class AsyncGenTest(unittest.TestCase):
    """Async generator tests that run without an event loop.

    Generators here await the module-level ``awaitable()`` helper and are
    driven either by hand (``__anext__().__next__()``) or through the
    module-level ``to_list()`` / ``run_until_complete()`` helpers.
    """

    def compare_generators(self, sync_gen, async_gen):
        """Assert that *sync_gen* and *async_gen* produce the same trace.

        Each generator is exhausted and its behaviour recorded as a list of
        yielded values, the ``str`` of the type of any exception raised
        (iteration then continues), and a final ``'STOP'`` marker.  The async
        trace is built by pumping each ``__anext__()`` awaitable with
        ``__next__()`` until it finishes.  Returns the async trace.
        """
        def sync_iterate(g):
            res = []
            while True:
                try:
                    res.append(g.__next__())
                except StopIteration:
                    res.append('STOP')
                    break
                except Exception as ex:
                    # Record the exception type and keep going; a generator
                    # that raised is finished, so the next call stops.
                    res.append(str(type(ex)))
            return res

        def async_iterate(g):
            res = []
            while True:
                an = g.__anext__()
                try:
                    while True:
                        try:
                            # Values yielded here are awaitable() markers,
                            # not generator output, so they are discarded.
                            an.__next__()
                        except StopIteration as ex:
                            # The async generator's yielded value arrives as
                            # the StopIteration payload.
                            if ex.args:
                                res.append(ex.args[0])
                                break
                            else:
                                res.append('EMPTY StopIteration')
                                break
                        except StopAsyncIteration:
                            # Exhaustion: escape to the outer handler.
                            raise
                        except Exception as ex:
                            res.append(str(type(ex)))
                            break
                except StopAsyncIteration:
                    res.append('STOP')
                    break
            return res

        sync_gen_result = sync_iterate(sync_gen)
        async_gen_result = async_iterate(async_gen)
        self.assertEqual(sync_gen_result, async_gen_result)
        return async_gen_result

    def test_async_gen_iteration_01(self):
        # Plain iteration through to_list(); the value received by a yield
        # expression during async-for iteration is None.
        async def gen():
            await awaitable()
            a = yield 123
            self.assertIs(a, None)
            await awaitable()
            yield 456
            await awaitable()
            yield 789

        self.assertEqual(to_list(gen()), [123, 456, 789])

    def test_async_gen_iteration_02(self):
        # Step each __anext__() awaitable by hand: the awaitable() marker
        # comes out first, then StopIteration carries the yielded value.
        async def gen():
            await awaitable()
            yield 123
            await awaitable()

        g = gen()
        ai = g.__aiter__()

        an = ai.__anext__()
        self.assertEqual(an.__next__(), ('result',))

        try:
            an.__next__()
        except StopIteration as ex:
            self.assertEqual(ex.args[0], 123)
        else:
            self.fail('StopIteration was not raised')

        # After the last await the generator returns; exhaustion surfaces as
        # StopAsyncIteration with no args.
        an = ai.__anext__()
        self.assertEqual(an.__next__(), ('result',))

        try:
            an.__next__()
        except StopAsyncIteration as ex:
            self.assertFalse(ex.args)
        else:
            self.fail('StopAsyncIteration was not raised')

    def test_async_gen_exception_03(self):
        # run_until_complete() throws AwaitException at the
        # awaitable(throw=True) await; it must propagate out of to_list().
        async def gen():
            await awaitable()
            yield 123
            await awaitable(throw=True)
            yield 456

        with self.assertRaises(AwaitException):
            to_list(gen())

    def test_async_gen_exception_04(self):
        # An exception raised after the last yield surfaces from the next
        # __anext__() awaitable.
        async def gen():
            await awaitable()
            yield 123
            1 / 0

        g = gen()
        ai = g.__aiter__()
        an = ai.__anext__()
        self.assertEqual(an.__next__(), ('result',))

        try:
            an.__next__()
        except StopIteration as ex:
            self.assertEqual(ex.args[0], 123)
        else:
            self.fail('StopIteration was not raised')

        with self.assertRaises(ZeroDivisionError):
            ai.__anext__().__next__()

    def test_async_gen_exception_05(self):
        # StopAsyncIteration raised inside an async generator body is
        # replaced by RuntimeError.
        async def gen():
            yield 123
            raise StopAsyncIteration

        with self.assertRaisesRegex(RuntimeError,
                                    'async generator.*StopAsyncIteration'):
            to_list(gen())

    def test_async_gen_exception_06(self):
        # Likewise for StopIteration raised inside the body.
        async def gen():
            yield 123
            raise StopIteration

        with self.assertRaisesRegex(RuntimeError,
                                    'async generator.*StopIteration'):
            to_list(gen())

    def test_async_gen_exception_07(self):
        # Yields inside a finally block that runs while ZeroDivisionError is
        # propagating: the async trace must match the sync one.
        def sync_gen():
            try:
                yield 1
                1 / 0
            finally:
                yield 2
                yield 3

            yield 100

        async def async_gen():
            try:
                yield 1
                1 / 0
            finally:
                yield 2
                yield 3

            yield 100

        self.compare_generators(sync_gen(), async_gen())

    def test_async_gen_exception_08(self):
        # An exception raised inside a finally block, after a yield there,
        # with awaits interleaved in the async version.
        def sync_gen():
            try:
                yield 1
            finally:
                yield 2
                1 / 0
                yield 3

            yield 100

        async def async_gen():
            try:
                yield 1
                await awaitable()
            finally:
                await awaitable()
                yield 2
                1 / 0
                yield 3

            yield 100

        self.compare_generators(sync_gen(), async_gen())

    def test_async_gen_exception_09(self):
        # Same shape as _07, with awaits placed around the yields.
        def sync_gen():
            try:
                yield 1
                1 / 0
            finally:
                yield 2
                yield 3

            yield 100

        async def async_gen():
            try:
                await awaitable()
                yield 1
                1 / 0
            finally:
                yield 2
                await awaitable()
                yield 3

            yield 100

        self.compare_generators(sync_gen(), async_gen())

    def test_async_gen_exception_10(self):
        # The first send() into a fresh __anext__() awaitable must be None.
        async def gen():
            yield 123
        with self.assertRaisesRegex(TypeError,
                                    "non-None value .* async generator"):
            gen().__anext__().send(100)

    def test_async_gen_exception_11(self):
        # Catching GeneratorExit thrown into an inner generator and then
        # continuing to yield must behave the same in sync and async code.
        def sync_gen():
            yield 10
            yield 20

        def sync_gen_wrapper():
            yield 1
            sg = sync_gen()
            sg.send(None)
            try:
                sg.throw(GeneratorExit())
            except GeneratorExit:
                yield 2
            yield 3

        async def async_gen():
            yield 10
            yield 20

        async def async_gen_wrapper():
            yield 1
            asg = async_gen()
            await asg.asend(None)
            try:
                await asg.athrow(GeneratorExit())
            except GeneratorExit:
                yield 2
            yield 3

        self.compare_generators(sync_gen_wrapper(), async_gen_wrapper())

    def test_async_gen_api_01(self):
        # Introspection attributes of a fresh (never started) async
        # generator object.
        async def gen():
            yield 123

        g = gen()

        self.assertEqual(g.__name__, 'gen')
        g.__name__ = '123'
        self.assertEqual(g.__name__, '123')

        self.assertIn('.gen', g.__qualname__)
        g.__qualname__ = '123'
        self.assertEqual(g.__qualname__, '123')

        self.assertIsNone(g.ag_await)
        self.assertIsInstance(g.ag_frame, types.FrameType)
        self.assertFalse(g.ag_running)
        self.assertIsInstance(g.ag_code, types.CodeType)

        self.assertTrue(inspect.isawaitable(g.aclose()))
400
401
402class AsyncGenAsyncioTest(unittest.TestCase):
403
404    def setUp(self):
405        self.loop = asyncio.new_event_loop()
406        asyncio.set_event_loop(None)
407
408    def tearDown(self):
409        self.loop.close()
410        self.loop = None
411        asyncio.set_event_loop_policy(None)
412
413    def check_async_iterator_anext(self, ait_class):
414        with self.subTest(anext="pure-Python"):
415            self._check_async_iterator_anext(ait_class, py_anext)
416        with self.subTest(anext="builtin"):
417            self._check_async_iterator_anext(ait_class, anext)
418
    def _check_async_iterator_anext(self, ait_class, anext):
        # Shared body of check_async_iterator_anext().  The ``anext``
        # parameter deliberately shadows the builtin, so the same checks run
        # against py_anext() and the builtin anext().  ``ait_class()`` must
        # build a fresh async iterator that produces 1, then 2.
        g = ait_class()
        async def consume():
            results = []
            results.append(await anext(g))
            results.append(await anext(g))
            # The iterator is exhausted here, so the default is returned.
            results.append(await anext(g, 'buckle my shoe'))
            return results
        res = self.loop.run_until_complete(consume())
        self.assertEqual(res, [1, 2, 'buckle my shoe'])
        # A second pass over the same, now exhausted, iterator fails on its
        # first anext() call, which has no default.
        with self.assertRaises(StopAsyncIteration):
            self.loop.run_until_complete(consume())

        async def test_2():
            # Exhaustion is sticky: repeated calls keep raising, or keep
            # returning the default when one is given.
            g1 = ait_class()
            self.assertEqual(await anext(g1), 1)
            self.assertEqual(await anext(g1), 2)
            with self.assertRaises(StopAsyncIteration):
                await anext(g1)
            with self.assertRaises(StopAsyncIteration):
                await anext(g1)

            g2 = ait_class()
            self.assertEqual(await anext(g2, "default"), 1)
            self.assertEqual(await anext(g2, "default"), 2)
            self.assertEqual(await anext(g2, "default"), "default")
            self.assertEqual(await anext(g2, "default"), "default")

            return "completed"

        result = self.loop.run_until_complete(test_2())
        self.assertEqual(result, "completed")

        def test_send():
            # Drive the anext() awaitable by hand.  The iterators used in this
            # class produce a value without suspending, so the first
            # send(None) already finishes it with StopIteration.
            p = ait_class()
            obj = anext(p, "completed")
            with self.assertRaises(StopIteration):
                with contextlib.closing(obj.__await__()) as g:
                    g.send(None)

        test_send()

        async def test_throw():
            # throw() on a not-yet-started anext() awaitable re-raises the
            # thrown exception.
            p = ait_class()
            obj = anext(p, "completed")
            self.assertRaises(SyntaxError, obj.throw, SyntaxError)
            return "completed"

        result = self.loop.run_until_complete(test_throw())
        self.assertEqual(result, "completed")
469
470    def test_async_generator_anext(self):
471        async def agen():
472            yield 1
473            yield 2
474        self.check_async_iterator_anext(agen)
475
476    def test_python_async_iterator_anext(self):
477        class MyAsyncIter:
478            """Asynchronously yield 1, then 2."""
479            def __init__(self):
480                self.yielded = 0
481            def __aiter__(self):
482                return self
483            async def __anext__(self):
484                if self.yielded >= 2:
485                    raise StopAsyncIteration()
486                else:
487                    self.yielded += 1
488                    return self.yielded
489        self.check_async_iterator_anext(MyAsyncIter)
490
491    def test_python_async_iterator_types_coroutine_anext(self):
492        import types
493        class MyAsyncIterWithTypesCoro:
494            """Asynchronously yield 1, then 2."""
495            def __init__(self):
496                self.yielded = 0
497            def __aiter__(self):
498                return self
499            @types.coroutine
500            def __anext__(self):
501                if False:
502                    yield "this is a generator-based coroutine"
503                if self.yielded >= 2:
504                    raise StopAsyncIteration()
505                else:
506                    self.yielded += 1
507                    return self.yielded
508        self.check_async_iterator_anext(MyAsyncIterWithTypesCoro)
509
510    def test_async_gen_aiter(self):
511        async def gen():
512            yield 1
513            yield 2
514        g = gen()
515        async def consume():
516            return [i async for i in aiter(g)]
517        res = self.loop.run_until_complete(consume())
518        self.assertEqual(res, [1, 2])
519
520    def test_async_gen_aiter_class(self):
521        results = []
522        class Gen:
523            async def __aiter__(self):
524                yield 1
525                yield 2
526        g = Gen()
527        async def consume():
528            ait = aiter(g)
529            while True:
530                try:
531                    results.append(await anext(ait))
532                except StopAsyncIteration:
533                    break
534        self.loop.run_until_complete(consume())
535        self.assertEqual(results, [1, 2])
536
537    def test_aiter_idempotent(self):
538        async def gen():
539            yield 1
540        applied_once = aiter(gen())
541        applied_twice = aiter(applied_once)
542        self.assertIs(applied_once, applied_twice)
543
544    def test_anext_bad_args(self):
545        async def gen():
546            yield 1
547        async def call_with_too_few_args():
548            await anext()
549        async def call_with_too_many_args():
550            await anext(gen(), 1, 3)
551        async def call_with_wrong_type_args():
552            await anext(1, gen())
553        async def call_with_kwarg():
554            await anext(aiterator=gen())
555        with self.assertRaises(TypeError):
556            self.loop.run_until_complete(call_with_too_few_args())
557        with self.assertRaises(TypeError):
558            self.loop.run_until_complete(call_with_too_many_args())
559        with self.assertRaises(TypeError):
560            self.loop.run_until_complete(call_with_wrong_type_args())
561        with self.assertRaises(TypeError):
562            self.loop.run_until_complete(call_with_kwarg())
563
    def test_anext_bad_await(self):
        # anext() must reject an __anext__() result whose __await__() returns
        # a non-iterator (42), both with and without a default.
        async def bad_awaitable():
            class BadAwaitable:
                def __await__(self):
                    return 42
            class MyAsyncIter:
                def __aiter__(self):
                    return self
                def __anext__(self):
                    return BadAwaitable()
            regex = r"__await__.*iterator"
            awaitable = anext(MyAsyncIter(), "default")
            with self.assertRaisesRegex(TypeError, regex):
                await awaitable
            awaitable = anext(MyAsyncIter())
            with self.assertRaisesRegex(TypeError, regex):
                await awaitable
            return "completed"
        result = self.loop.run_until_complete(bad_awaitable())
        self.assertEqual(result, "completed")
584
585    async def check_anext_returning_iterator(self, aiter_class):
586        awaitable = anext(aiter_class(), "default")
587        with self.assertRaises(TypeError):
588            await awaitable
589        awaitable = anext(aiter_class())
590        with self.assertRaises(TypeError):
591            await awaitable
592        return "completed"
593
594    def test_anext_return_iterator(self):
595        class WithIterAnext:
596            def __aiter__(self):
597                return self
598            def __anext__(self):
599                return iter("abc")
600        result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithIterAnext))
601        self.assertEqual(result, "completed")
602
603    def test_anext_return_generator(self):
604        class WithGenAnext:
605            def __aiter__(self):
606                return self
607            def __anext__(self):
608                yield
609        result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithGenAnext))
610        self.assertEqual(result, "completed")
611
    def test_anext_await_raises(self):
        # An exception raised while awaiting the __anext__() result must
        # propagate out of anext(), with and without a default.
        class RaisingAwaitable:
            def __await__(self):
                # The trailing yield makes this a generator function, so the
                # raise happens when the await first resumes it.
                raise ZeroDivisionError()
                yield
        class WithRaisingAwaitableAnext:
            def __aiter__(self):
                return self
            def __anext__(self):
                return RaisingAwaitable()
        async def do_test():
            awaitable = anext(WithRaisingAwaitableAnext())
            with self.assertRaises(ZeroDivisionError):
                await awaitable
            awaitable = anext(WithRaisingAwaitableAnext(), "default")
            with self.assertRaises(ZeroDivisionError):
                await awaitable
            return "completed"
        result = self.loop.run_until_complete(do_test())
        self.assertEqual(result, "completed")
632
    def test_anext_iter(self):
        # Drive the awaitable returned by anext(agen, "default") by hand with
        # send()/throw()/close(), for both py_anext() and the builtin.
        # NOTE(review): the three-argument throw(type, value, tb) form used
        # below is deprecated since Python 3.12 -- confirm the resulting
        # DeprecationWarning is expected here.

        # Generator-based coroutine: suspends once, handing v to the driver,
        # and returns whatever value is sent back in.
        @types.coroutine
        def _async_yield(v):
            return (yield v)

        class MyError(Exception):
            pass

        # Never yields a value (the trailing ``yield`` only makes it an async
        # generator): it suspends at one or two awaits and then returns, so
        # anext() falls back to its default.
        async def agenfn():
            try:
                await _async_yield(1)
            except MyError:
                await _async_yield(2)
            return
            yield

        # throw() is caught by agenfn, which suspends again (2); the next
        # send() finishes it, so "default" becomes the StopIteration value.
        def test1(anext):
            agen = agenfn()
            with contextlib.closing(anext(agen, "default").__await__()) as g:
                self.assertEqual(g.send(None), 1)
                self.assertEqual(g.throw(MyError, MyError(), None), 2)
                try:
                    g.send(None)
                except StopIteration as e:
                    err = e
                else:
                    self.fail('StopIteration was not raised')
                self.assertEqual(err.value, "default")

        # A second throw(), delivered inside agenfn's except handler, is not
        # caught and propagates out of the awaitable.
        def test2(anext):
            agen = agenfn()
            with contextlib.closing(anext(agen, "default").__await__()) as g:
                self.assertEqual(g.send(None), 1)
                self.assertEqual(g.throw(MyError, MyError(), None), 2)
                with self.assertRaises(MyError):
                    g.throw(MyError, MyError(), None)

        # Once closed, the awaitable cannot be resumed.
        def test3(anext):
            agen = agenfn()
            with contextlib.closing(anext(agen, "default").__await__()) as g:
                self.assertEqual(g.send(None), 1)
                g.close()
                # send() raises, so the assertEqual wrapper is never reached.
                with self.assertRaisesRegex(RuntimeError, 'cannot reuse'):
                    self.assertEqual(g.send(None), 1)

        # Like test2, but each await can suspend twice (v*10, then v*10+1);
        # the propagated exception must be the instance that was thrown.
        def test4(anext):
            @types.coroutine
            def _async_yield(v):
                yield v * 10
                return (yield (v * 10 + 1))

            async def agenfn():
                try:
                    await _async_yield(1)
                except MyError:
                    await _async_yield(2)
                return
                yield

            agen = agenfn()
            with contextlib.closing(anext(agen, "default").__await__()) as g:
                self.assertEqual(g.send(None), 10)
                self.assertEqual(g.throw(MyError, MyError(), None), 20)
                with self.assertRaisesRegex(MyError, 'val'):
                    g.throw(MyError, MyError('val'), None)

        # The async generator handles the thrown error by returning without
        # yielding, so throw() itself finishes with StopIteration('default').
        def test5(anext):
            @types.coroutine
            def _async_yield(v):
                yield v * 10
                return (yield (v * 10 + 1))

            async def agenfn():
                try:
                    await _async_yield(1)
                except MyError:
                    return
                yield 'aaa'

            agen = agenfn()
            with contextlib.closing(anext(agen, "default").__await__()) as g:
                self.assertEqual(g.send(None), 10)
                with self.assertRaisesRegex(StopIteration, 'default'):
                    g.throw(MyError, MyError(), None)

        # throw() before the first send(): the error propagates immediately.
        def test6(anext):
            @types.coroutine
            def _async_yield(v):
                yield v * 10
                return (yield (v * 10 + 1))

            async def agenfn():
                await _async_yield(1)
                yield 'aaa'

            agen = agenfn()
            with contextlib.closing(anext(agen, "default").__await__()) as g:
                with self.assertRaises(MyError):
                    g.throw(MyError, MyError(), None)

        # Run one scenario against both anext() implementations.
        def run_test(test):
            with self.subTest('pure-Python anext()'):
                test(py_anext)
            with self.subTest('builtin anext()'):
                test(anext)

        run_test(test1)
        run_test(test2)
        run_test(test3)
        run_test(test4)
        run_test(test5)
        run_test(test6)
745
746    def test_aiter_bad_args(self):
747        async def gen():
748            yield 1
749        async def call_with_too_few_args():
750            await aiter()
751        async def call_with_too_many_args():
752            await aiter(gen(), 1)
753        async def call_with_wrong_type_arg():
754            await aiter(1)
755        with self.assertRaises(TypeError):
756            self.loop.run_until_complete(call_with_too_few_args())
757        with self.assertRaises(TypeError):
758            self.loop.run_until_complete(call_with_too_many_args())
759        with self.assertRaises(TypeError):
760            self.loop.run_until_complete(call_with_wrong_type_arg())
761
762    async def to_list(self, gen):
763        res = []
764        async for i in gen:
765            res.append(i)
766        return res
767
768    def test_async_gen_asyncio_01(self):
769        async def gen():
770            yield 1
771            await asyncio.sleep(0.01)
772            yield 2
773            await asyncio.sleep(0.01)
774            return
775            yield 3
776
777        res = self.loop.run_until_complete(self.to_list(gen()))
778        self.assertEqual(res, [1, 2])
779
780    def test_async_gen_asyncio_02(self):
781        async def gen():
782            yield 1
783            await asyncio.sleep(0.01)
784            yield 2
785            1 / 0
786            yield 3
787
788        with self.assertRaises(ZeroDivisionError):
789            self.loop.run_until_complete(self.to_list(gen()))
790
791    def test_async_gen_asyncio_03(self):
792        loop = self.loop
793
794        class Gen:
795            async def __aiter__(self):
796                yield 1
797                await asyncio.sleep(0.01)
798                yield 2
799
800        res = loop.run_until_complete(self.to_list(Gen()))
801        self.assertEqual(res, [1, 2])
802
    def test_async_gen_asyncio_anext_04(self):
        # Under an event loop: plain __anext__() iteration (run1), and an
        # exception thrown through __anext__().throw() that the generator
        # catches and answers with another yield (run2).
        async def foo():
            yield 1
            await asyncio.sleep(0.01)
            try:
                yield 2
                yield 3
            except ZeroDivisionError:
                yield 1000
            await asyncio.sleep(0.01)
            yield 4

        async def run1():
            it = foo().__aiter__()

            self.assertEqual(await it.__anext__(), 1)
            self.assertEqual(await it.__anext__(), 2)
            self.assertEqual(await it.__anext__(), 3)
            self.assertEqual(await it.__anext__(), 4)
            with self.assertRaises(StopAsyncIteration):
                await it.__anext__()
            with self.assertRaises(StopAsyncIteration):
                await it.__anext__()

        async def run2():
            it = foo().__aiter__()

            self.assertEqual(await it.__anext__(), 1)
            self.assertEqual(await it.__anext__(), 2)
            # The generator is suspended at ``yield 2``; the thrown error is
            # caught there and the value yielded next (1000) comes back as
            # the StopIteration payload of throw().
            try:
                it.__anext__().throw(ZeroDivisionError)
            except StopIteration as ex:
                self.assertEqual(ex.args[0], 1000)
            else:
                self.fail('StopIteration was not raised')
            self.assertEqual(await it.__anext__(), 4)
            with self.assertRaises(StopAsyncIteration):
                await it.__anext__()

        self.loop.run_until_complete(run1())
        self.loop.run_until_complete(run2())
844
    def test_async_gen_asyncio_anext_05(self):
        # A value passed to __anext__().send() becomes the result of the
        # suspended yield expression.  The first send must be None, because
        # the generator has not reached a yield yet.
        async def foo():
            v = yield 1
            v = yield v
            yield v * 100

        async def run():
            it = foo().__aiter__()

            try:
                it.__anext__().send(None)
            except StopIteration as ex:
                self.assertEqual(ex.args[0], 1)
            else:
                self.fail('StopIteration was not raised')

            # 10 is delivered to ``v = yield 1`` and echoed back.
            try:
                it.__anext__().send(10)
            except StopIteration as ex:
                self.assertEqual(ex.args[0], 10)
            else:
                self.fail('StopIteration was not raised')

            # 12 is delivered to ``v = yield v``; the generator yields 1200.
            try:
                it.__anext__().send(12)
            except StopIteration as ex:
                self.assertEqual(ex.args[0], 1200)
            else:
                self.fail('StopIteration was not raised')

            with self.assertRaises(StopAsyncIteration):
                await it.__anext__()

        self.loop.run_until_complete(run())
879
    def test_async_gen_asyncio_anext_06(self):
        # A bare ``except`` around a yield must not disturb normal resumption:
        # the async generator runs to completion (DONE = 1) and asend() then
        # reports StopAsyncIteration, mirroring the sync generator below.
        DONE = 0

        # test synchronous generators
        def foo():
            try:
                yield
            except:
                pass
        g = foo()
        g.send(None)
        with self.assertRaises(StopIteration):
            g.send(None)

        # now with asynchronous generators

        async def gen():
            nonlocal DONE
            try:
                yield
            except:
                pass
            DONE = 1

        async def run():
            nonlocal DONE
            g = gen()
            await g.asend(None)
            with self.assertRaises(StopAsyncIteration):
                await g.asend(None)
            DONE += 10

        self.loop.run_until_complete(run())
        # 1 from the generator body plus 10 from run().
        self.assertEqual(DONE, 11)
914
    def test_async_gen_asyncio_anext_tuple(self):
        # A yielded tuple must come back intact (not unpacked) as the single
        # StopIteration payload when the awaitable is driven with throw().
        async def foo():
            try:
                yield (1,)
            except ZeroDivisionError:
                yield (2,)

        async def run():
            it = foo().__aiter__()

            self.assertEqual(await it.__anext__(), (1,))
            with self.assertRaises(StopIteration) as cm:
                it.__anext__().throw(ZeroDivisionError)
            self.assertEqual(cm.exception.args[0], (2,))
            with self.assertRaises(StopAsyncIteration):
                await it.__anext__()

        self.loop.run_until_complete(run())
933
    def test_async_gen_asyncio_anext_stopiteration(self):
        # Yielded StopIteration instances are ordinary values: they must be
        # delivered as values, not treated as the end of iteration.
        async def foo():
            try:
                yield StopIteration(1)
            except ZeroDivisionError:
                yield StopIteration(3)

        async def run():
            it = foo().__aiter__()

            v = await it.__anext__()
            self.assertIsInstance(v, StopIteration)
            self.assertEqual(v.value, 1)
            # The yielded StopIteration(3) is wrapped as the payload of the
            # StopIteration raised by throw().
            with self.assertRaises(StopIteration) as cm:
                it.__anext__().throw(ZeroDivisionError)
            v = cm.exception.args[0]
            self.assertIsInstance(v, StopIteration)
            self.assertEqual(v.value, 3)
            with self.assertRaises(StopAsyncIteration):
                await it.__anext__()

        self.loop.run_until_complete(run())
956
    def test_async_gen_asyncio_aclose_06(self):
        # Yielding from a finally block while aclose() is closing the
        # generator must raise RuntimeError ("ignored GeneratorExit").
        async def foo():
            try:
                yield 1
                1 / 0
            finally:
                await asyncio.sleep(0.01)
                yield 12

        async def run():
            gen = foo()
            it = gen.__aiter__()
            await it.__anext__()
            await gen.aclose()

        with self.assertRaisesRegex(
                RuntimeError,
                "async generator ignored GeneratorExit"):
            self.loop.run_until_complete(run())
976
    def test_async_gen_asyncio_aclose_07(self):
        # aclose() on a generator suspended at its first yield must run the
        # finally block (including its awaits) exactly once and must not
        # resume any code outside the try statement.
        DONE = 0

        async def foo():
            nonlocal DONE
            try:
                yield 1
                1 / 0  # not reached: aclose() interrupts the yield above
            finally:
                await asyncio.sleep(0.01)
                await asyncio.sleep(0.01)
                DONE += 1
            DONE += 1000  # must not run: the generator is being closed

        async def run():
            gen = foo()
            it = gen.__aiter__()
            await it.__anext__()
            await gen.aclose()

        self.loop.run_until_complete(run())
        self.assertEqual(DONE, 1)
999
    def test_async_gen_asyncio_aclose_08(self):
        # aclose() on a generator suspended at its first yield must run only
        # the finally block, even though the body after that yield would
        # await a future that is never resolved.
        DONE = 0

        fut = asyncio.Future(loop=self.loop)

        async def foo():
            nonlocal DONE
            try:
                yield 1
                await fut  # not reached: aclose() interrupts the yield above
                DONE += 1000
                yield 2
            finally:
                await asyncio.sleep(0.01)
                await asyncio.sleep(0.01)
                DONE += 1
            DONE += 1000  # must not run: the generator is being closed

        async def run():
            gen = foo()
            it = gen.__aiter__()
            self.assertEqual(await it.__anext__(), 1)
            await gen.aclose()

        self.loop.run_until_complete(run())
        self.assertEqual(DONE, 1)

        # Silence ResourceWarnings
        fut.cancel()
        self.loop.run_until_complete(asyncio.sleep(0.01))
1030
    def test_async_gen_asyncio_gc_aclose_09(self):
        # Dropping the last reference to a suspended async generator, without
        # an explicit aclose(), must still get its finally block (including
        # its awaits) run while the event loop keeps running.
        DONE = 0

        async def gen():
            nonlocal DONE
            try:
                while True:
                    yield 1
            finally:
                await asyncio.sleep(0.01)
                await asyncio.sleep(0.01)
                DONE = 1

        async def run():
            g = gen()
            await g.__anext__()
            await g.__anext__()
            del g
            gc_collect()  # For PyPy or other GCs.

            # Give the loop time to finalize the dropped generator; its
            # finally block itself sleeps twice for 0.01s.
            await asyncio.sleep(0.1)

        self.loop.run_until_complete(run())
        self.assertEqual(DONE, 1)
1055
    def test_async_gen_asyncio_aclose_10(self):
        # A generator that swallows GeneratorExit with a bare except and then
        # finishes without yielding again closes cleanly; the async variant
        # (aclose()) must behave like the synchronous one (close()).
        DONE = 0

        # test synchronous generators
        def foo():
            try:
                yield
            except:
                pass
        g = foo()
        g.send(None)
        g.close()  # no exception expected

        # now with asynchronous generators

        async def gen():
            nonlocal DONE
            try:
                yield
            except:  # deliberately bare: catches GeneratorExit from aclose()
                pass
            DONE = 1

        async def run():
            nonlocal DONE
            g = gen()
            await g.asend(None)
            await g.aclose()
            DONE += 10

        self.loop.run_until_complete(run())
        # 1 set by gen() after swallowing GeneratorExit, plus 10 from run().
        self.assertEqual(DONE, 11)
1088
    def test_async_gen_asyncio_aclose_11(self):
        # A generator that swallows GeneratorExit and then yields again must
        # make close()/aclose() raise RuntimeError; the async variant must
        # behave like the synchronous one.
        DONE = 0

        # test synchronous generators
        def foo():
            try:
                yield
            except:
                pass
            yield
        g = foo()
        g.send(None)
        with self.assertRaisesRegex(RuntimeError, 'ignored GeneratorExit'):
            g.close()

        # now with asynchronous generators

        async def gen():
            nonlocal DONE
            try:
                yield
            except:  # deliberately bare: catches GeneratorExit from aclose()
                pass
            yield
            DONE += 1  # must not run: the generator is abandoned at the yield

        async def run():
            nonlocal DONE
            g = gen()
            await g.asend(None)
            with self.assertRaisesRegex(RuntimeError, 'ignored GeneratorExit'):
                await g.aclose()
            DONE += 10

        self.loop.run_until_complete(run())
        # Only run() incremented DONE.
        self.assertEqual(DONE, 10)
1125
    def test_async_gen_asyncio_aclose_12(self):
        # The finally block run by aclose() may await a separately created
        # task and handle the exception that task raised.
        DONE = 0

        async def target():
            await asyncio.sleep(0.01)
            1 / 0

        async def foo():
            nonlocal DONE
            # Task is started when the generator is first advanced.
            task = asyncio.create_task(target())
            try:
                yield 1
            finally:
                try:
                    await task
                except ZeroDivisionError:
                    DONE = 1

        async def run():
            gen = foo()
            it = gen.__aiter__()
            await it.__anext__()
            await gen.aclose()

        self.loop.run_until_complete(run())
        self.assertEqual(DONE, 1)
1152
    def test_async_gen_asyncio_asend_01(self):
        # asend(value) resumes the generator with `value` as the result of the
        # pending yield, mirroring send() on a synchronous generator. Once the
        # generator returns, asend() raises StopAsyncIteration and the
        # finally block has run.
        DONE = 0

        # Sanity check:
        def sgen():
            v = yield 1
            yield v * 2
        sg = sgen()
        v = sg.send(None)
        self.assertEqual(v, 1)
        v = sg.send(100)
        self.assertEqual(v, 200)

        async def gen():
            nonlocal DONE
            try:
                await asyncio.sleep(0.01)
                v = yield 1
                await asyncio.sleep(0.01)
                yield v * 2
                await asyncio.sleep(0.01)
                return
            finally:
                await asyncio.sleep(0.01)
                await asyncio.sleep(0.01)
                DONE = 1

        async def run():
            g = gen()

            v = await g.asend(None)  # starts the generator
            self.assertEqual(v, 1)

            v = await g.asend(100)  # 100 becomes `v` inside gen()
            self.assertEqual(v, 200)

            with self.assertRaises(StopAsyncIteration):
                await g.asend(None)

        self.loop.run_until_complete(run())
        self.assertEqual(DONE, 1)
1194
    def test_async_gen_asyncio_asend_02(self):
        # An exception raised in the generator body while serving asend() must
        # propagate out of the awaited asend(), and the generator's finally
        # block must still run.
        DONE = 0

        async def sleep_n_crash(delay):
            await asyncio.sleep(delay)
            1 / 0

        async def gen():
            nonlocal DONE
            try:
                await asyncio.sleep(0.01)
                v = yield 1
                await sleep_n_crash(0.01)
                DONE += 1000  # not reached: sleep_n_crash() raises
                yield v * 2
            finally:
                await asyncio.sleep(0.01)
                await asyncio.sleep(0.01)
                DONE = 1

        async def run():
            g = gen()

            v = await g.asend(None)
            self.assertEqual(v, 1)

            await g.asend(100)  # raises ZeroDivisionError from sleep_n_crash()

        with self.assertRaises(ZeroDivisionError):
            self.loop.run_until_complete(run())
        self.assertEqual(DONE, 1)
1226
    def test_async_gen_asyncio_asend_03(self):
        # A CancelledError raised inside the generator body while serving
        # asend() must propagate out of the awaited asend(), and the
        # generator's finally block must still run.
        DONE = 0

        async def sleep_n_crash(delay):
            # Cancel the sleep future halfway through its delay, so awaiting
            # it raises CancelledError.
            fut = asyncio.ensure_future(asyncio.sleep(delay),
                                        loop=self.loop)
            self.loop.call_later(delay / 2, lambda: fut.cancel())
            return await fut

        async def gen():
            nonlocal DONE
            try:
                await asyncio.sleep(0.01)
                v = yield 1
                await sleep_n_crash(0.01)
                DONE += 1000  # not reached: sleep_n_crash() is cancelled
                yield v * 2
            finally:
                await asyncio.sleep(0.01)
                await asyncio.sleep(0.01)
                DONE = 1

        async def run():
            g = gen()

            v = await g.asend(None)
            self.assertEqual(v, 1)

            await g.asend(100)  # raises CancelledError from sleep_n_crash()

        with self.assertRaises(asyncio.CancelledError):
            self.loop.run_until_complete(run())
        self.assertEqual(DONE, 1)
1260
    def test_async_gen_asyncio_athrow_01(self):
        # athrow(exc) raises `exc` at the pending yield, mirroring throw() on a
        # synchronous generator. If the generator handles it and yields again,
        # athrow() returns that value. After the generator finishes, asend()
        # raises StopAsyncIteration and the finally block has run.
        DONE = 0

        class FooEr(Exception):
            pass

        # Sanity check:
        def sgen():
            try:
                v = yield 1
            except FooEr:
                v = 1000
            yield v * 2
        sg = sgen()
        v = sg.send(None)
        self.assertEqual(v, 1)
        v = sg.throw(FooEr)
        self.assertEqual(v, 2000)
        with self.assertRaises(StopIteration):
            sg.send(None)

        async def gen():
            nonlocal DONE
            try:
                await asyncio.sleep(0.01)
                try:
                    v = yield 1
                except FooEr:
                    v = 1000
                    await asyncio.sleep(0.01)
                yield v * 2
                await asyncio.sleep(0.01)
                # return
            finally:
                await asyncio.sleep(0.01)
                await asyncio.sleep(0.01)
                DONE = 1

        async def run():
            g = gen()

            v = await g.asend(None)
            self.assertEqual(v, 1)

            v = await g.athrow(FooEr)  # handled in gen(); yields 1000 * 2
            self.assertEqual(v, 2000)

            with self.assertRaises(StopAsyncIteration):
                await g.asend(None)

        self.loop.run_until_complete(run())
        self.assertEqual(DONE, 1)
1313
    def test_async_gen_asyncio_athrow_02(self):
        # If the handler for an athrow()-delivered exception itself gets
        # cancelled, the CancelledError must propagate out of the awaited
        # athrow(), and the finally block must already have run by then.
        DONE = 0

        class FooEr(Exception):
            pass

        async def sleep_n_crash(delay):
            # Cancel the sleep future halfway through its delay, so awaiting
            # it raises CancelledError.
            fut = asyncio.ensure_future(asyncio.sleep(delay),
                                        loop=self.loop)
            self.loop.call_later(delay / 2, lambda: fut.cancel())
            return await fut

        async def gen():
            nonlocal DONE
            try:
                await asyncio.sleep(0.01)
                try:
                    v = yield 1
                except FooEr:
                    await sleep_n_crash(0.01)
                yield v * 2
                await asyncio.sleep(0.01)
                # return
            finally:
                await asyncio.sleep(0.01)
                await asyncio.sleep(0.01)
                DONE = 1

        async def run():
            g = gen()

            v = await g.asend(None)
            self.assertEqual(v, 1)

            try:
                await g.athrow(FooEr)
            except asyncio.CancelledError:
                # The finally block must have completed before the error
                # reaches the caller of athrow().
                self.assertEqual(DONE, 1)
                raise
            else:
                self.fail('CancelledError was not raised')

        with self.assertRaises(asyncio.CancelledError):
            self.loop.run_until_complete(run())
        self.assertEqual(DONE, 1)
1359
    def test_async_gen_asyncio_athrow_03(self):
        # If the generator swallows the thrown exception and then finishes,
        # athrow() raises StopAsyncIteration, just as throw() raises
        # StopIteration for a synchronous generator.
        DONE = 0

        # test synchronous generators
        def foo():
            try:
                yield
            except:
                pass
        g = foo()
        g.send(None)
        with self.assertRaises(StopIteration):
            g.throw(ValueError)

        # now with asynchronous generators

        async def gen():
            nonlocal DONE
            try:
                yield
            except:  # deliberately bare: swallows the ValueError from athrow()
                pass
            DONE = 1

        async def run():
            nonlocal DONE
            g = gen()
            await g.asend(None)
            with self.assertRaises(StopAsyncIteration):
                await g.athrow(ValueError)
            DONE += 10

        self.loop.run_until_complete(run())
        # 1 set by gen() after swallowing the exception, plus 10 from run().
        self.assertEqual(DONE, 11)
1394
1395    def test_async_gen_asyncio_athrow_tuple(self):
1396        async def gen():
1397            try:
1398                yield 1
1399            except ZeroDivisionError:
1400                yield (2,)
1401
1402        async def run():
1403            g = gen()
1404            v = await g.asend(None)
1405            self.assertEqual(v, 1)
1406            v = await g.athrow(ZeroDivisionError)
1407            self.assertEqual(v, (2,))
1408            with self.assertRaises(StopAsyncIteration):
1409                await g.asend(None)
1410
1411        self.loop.run_until_complete(run())
1412
1413    def test_async_gen_asyncio_athrow_stopiteration(self):
1414        async def gen():
1415            try:
1416                yield 1
1417            except ZeroDivisionError:
1418                yield StopIteration(2)
1419
1420        async def run():
1421            g = gen()
1422            v = await g.asend(None)
1423            self.assertEqual(v, 1)
1424            v = await g.athrow(ZeroDivisionError)
1425            self.assertIsInstance(v, StopIteration)
1426            self.assertEqual(v.value, 2)
1427            with self.assertRaises(StopAsyncIteration):
1428                await g.asend(None)
1429
1430        self.loop.run_until_complete(run())
1431
    def test_async_gen_asyncio_shutdown_01(self):
        # Two tasks each iterate a generator that is suspended in a long
        # sleep. Once the tasks are cancelled and loop.shutdown_asyncgens()
        # has run, both generators' finally blocks must have run.
        finalized = 0  # number of times a waiter() finally block completed

        async def waiter(timeout):
            nonlocal finalized
            try:
                await asyncio.sleep(timeout)
                yield 1
            finally:
                await asyncio.sleep(0)
                finalized += 1

        async def wait():
            async for _ in waiter(1):
                pass

        t1 = self.loop.create_task(wait())
        t2 = self.loop.create_task(wait())

        # Let both tasks start and suspend inside waiter()'s sleep.
        self.loop.run_until_complete(asyncio.sleep(0.1))

        # Silence warnings
        t1.cancel()
        t2.cancel()

        with self.assertRaises(asyncio.CancelledError):
            self.loop.run_until_complete(t1)
        with self.assertRaises(asyncio.CancelledError):
            self.loop.run_until_complete(t2)

        self.loop.run_until_complete(self.loop.shutdown_asyncgens())

        self.assertEqual(finalized, 2)
1465
1466    def test_async_gen_asyncio_shutdown_02(self):
1467        messages = []
1468
1469        def exception_handler(loop, context):
1470            messages.append(context)
1471
1472        async def async_iterate():
1473            yield 1
1474            yield 2
1475
1476        it = async_iterate()
1477        async def main():
1478            loop = asyncio.get_running_loop()
1479            loop.set_exception_handler(exception_handler)
1480
1481            async for i in it:
1482                break
1483
1484        asyncio.run(main())
1485
1486        self.assertEqual(messages, [])
1487
1488    def test_async_gen_asyncio_shutdown_exception_01(self):
1489        messages = []
1490
1491        def exception_handler(loop, context):
1492            messages.append(context)
1493
1494        async def async_iterate():
1495            try:
1496                yield 1
1497                yield 2
1498            finally:
1499                1/0
1500
1501        it = async_iterate()
1502        async def main():
1503            loop = asyncio.get_running_loop()
1504            loop.set_exception_handler(exception_handler)
1505
1506            async for i in it:
1507                break
1508
1509        asyncio.run(main())
1510
1511        message, = messages
1512        self.assertEqual(message['asyncgen'], it)
1513        self.assertIsInstance(message['exception'], ZeroDivisionError)
1514        self.assertIn('an error occurred during closing of asynchronous generator',
1515                      message['message'])
1516
1517    def test_async_gen_asyncio_shutdown_exception_02(self):
1518        messages = []
1519
1520        def exception_handler(loop, context):
1521            messages.append(context)
1522
1523        async def async_iterate():
1524            try:
1525                yield 1
1526                yield 2
1527            finally:
1528                1/0
1529
1530        async def main():
1531            loop = asyncio.get_running_loop()
1532            loop.set_exception_handler(exception_handler)
1533
1534            async for i in async_iterate():
1535                break
1536            gc_collect()
1537
1538        asyncio.run(main())
1539
1540        message, = messages
1541        self.assertIsInstance(message['exception'], ZeroDivisionError)
1542        self.assertIn('unhandled exception during asyncio.run() shutdown',
1543                      message['message'])
1544
1545    def test_async_gen_expression_01(self):
1546        async def arange(n):
1547            for i in range(n):
1548                await asyncio.sleep(0.01)
1549                yield i
1550
1551        def make_arange(n):
1552            # This syntax is legal starting with Python 3.7
1553            return (i * 2 async for i in arange(n))
1554
1555        async def run():
1556            return [i async for i in make_arange(10)]
1557
1558        res = self.loop.run_until_complete(run())
1559        self.assertEqual(res, [i * 2 for i in range(10)])
1560
1561    def test_async_gen_expression_02(self):
1562        async def wrap(n):
1563            await asyncio.sleep(0.01)
1564            return n
1565
1566        def make_arange(n):
1567            # This syntax is legal starting with Python 3.7
1568            return (i * 2 for i in range(n) if await wrap(i))
1569
1570        async def run():
1571            return [i async for i in make_arange(10)]
1572
1573        res = self.loop.run_until_complete(run())
1574        self.assertEqual(res, [i * 2 for i in range(1, 10)])
1575
1576    def test_asyncgen_nonstarted_hooks_are_cancellable(self):
1577        # See https://bugs.python.org/issue38013
1578        messages = []
1579
1580        def exception_handler(loop, context):
1581            messages.append(context)
1582
1583        async def async_iterate():
1584            yield 1
1585            yield 2
1586
1587        async def main():
1588            loop = asyncio.get_running_loop()
1589            loop.set_exception_handler(exception_handler)
1590
1591            async for i in async_iterate():
1592                break
1593
1594        asyncio.run(main())
1595
1596        self.assertEqual([], messages)
1597
1598    def test_async_gen_await_same_anext_coro_twice(self):
1599        async def async_iterate():
1600            yield 1
1601            yield 2
1602
1603        async def run():
1604            it = async_iterate()
1605            nxt = it.__anext__()
1606            await nxt
1607            with self.assertRaisesRegex(
1608                    RuntimeError,
1609                    r"cannot reuse already awaited __anext__\(\)/asend\(\)"
1610            ):
1611                await nxt
1612
1613            await it.aclose()  # prevent unfinished iterator warning
1614
1615        self.loop.run_until_complete(run())
1616
1617    def test_async_gen_await_same_aclose_coro_twice(self):
1618        async def async_iterate():
1619            yield 1
1620            yield 2
1621
1622        async def run():
1623            it = async_iterate()
1624            nxt = it.aclose()
1625            await nxt
1626            with self.assertRaisesRegex(
1627                    RuntimeError,
1628                    r"cannot reuse already awaited aclose\(\)/athrow\(\)"
1629            ):
1630                await nxt
1631
1632        self.loop.run_until_complete(run())
1633
1634    def test_async_gen_aclose_twice_with_different_coros(self):
1635        # Regression test for https://bugs.python.org/issue39606
1636        async def async_iterate():
1637            yield 1
1638            yield 2
1639
1640        async def run():
1641            it = async_iterate()
1642            await it.aclose()
1643            await it.aclose()
1644
1645        self.loop.run_until_complete(run())
1646
1647    def test_async_gen_aclose_after_exhaustion(self):
1648        # Regression test for https://bugs.python.org/issue39606
1649        async def async_iterate():
1650            yield 1
1651            yield 2
1652
1653        async def run():
1654            it = async_iterate()
1655            async for _ in it:
1656                pass
1657            await it.aclose()
1658
1659        self.loop.run_until_complete(run())
1660
1661    def test_async_gen_aclose_compatible_with_get_stack(self):
1662        async def async_generator():
1663            yield object()
1664
1665        async def run():
1666            ag = async_generator()
1667            asyncio.create_task(ag.aclose())
1668            tasks = asyncio.all_tasks()
1669            for task in tasks:
1670                # No AttributeError raised
1671                task.get_stack()
1672
1673        self.loop.run_until_complete(run())
1674
1675
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
1678