1# Adapted with permission from the EdgeDB project;
2# license: PSFL.
3
4
5import asyncio
6import contextvars
7import contextlib
8from asyncio import taskgroups
9import unittest
10
11
def tearDownModule():
    """Reset the asyncio event loop policy once the module's tests are done.

    This prevents the test runner from warning that the
    "test altered the execution environment".
    """
    asyncio.set_event_loop_policy(None)
15
16
class MyExc(Exception):
    """Ordinary ``Exception``-derived error raised by the tests."""
19
20
class MyBaseExc(BaseException):
    """``BaseException``-derived error (not an ``Exception``) raised by the tests."""
23
24
def get_error_types(eg):
    """Return the set of exception classes found directly in *eg*.

    Only the immediate members of ``eg.exceptions`` are inspected; a nested
    group is reported as its own group type rather than being flattened.
    """
    return set(map(type, eg.exceptions))
27
28
29class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
30
31    async def test_taskgroup_01(self):
32
33        async def foo1():
34            await asyncio.sleep(0.1)
35            return 42
36
37        async def foo2():
38            await asyncio.sleep(0.2)
39            return 11
40
41        async with taskgroups.TaskGroup() as g:
42            t1 = g.create_task(foo1())
43            t2 = g.create_task(foo2())
44
45        self.assertEqual(t1.result(), 42)
46        self.assertEqual(t2.result(), 11)
47
48    async def test_taskgroup_02(self):
49
50        async def foo1():
51            await asyncio.sleep(0.1)
52            return 42
53
54        async def foo2():
55            await asyncio.sleep(0.2)
56            return 11
57
58        async with taskgroups.TaskGroup() as g:
59            t1 = g.create_task(foo1())
60            await asyncio.sleep(0.15)
61            t2 = g.create_task(foo2())
62
63        self.assertEqual(t1.result(), 42)
64        self.assertEqual(t2.result(), 11)
65
66    async def test_taskgroup_03(self):
67
68        async def foo1():
69            await asyncio.sleep(1)
70            return 42
71
72        async def foo2():
73            await asyncio.sleep(0.2)
74            return 11
75
76        async with taskgroups.TaskGroup() as g:
77            t1 = g.create_task(foo1())
78            await asyncio.sleep(0.15)
79            # cancel t1 explicitly, i.e. everything should continue
80            # working as expected.
81            t1.cancel()
82
83            t2 = g.create_task(foo2())
84
85        self.assertTrue(t1.cancelled())
86        self.assertEqual(t2.result(), 11)
87
88    async def test_taskgroup_04(self):
89
90        NUM = 0
91        t2_cancel = False
92        t2 = None
93
94        async def foo1():
95            await asyncio.sleep(0.1)
96            1 / 0
97
98        async def foo2():
99            nonlocal NUM, t2_cancel
100            try:
101                await asyncio.sleep(1)
102            except asyncio.CancelledError:
103                t2_cancel = True
104                raise
105            NUM += 1
106
107        async def runner():
108            nonlocal NUM, t2
109
110            async with taskgroups.TaskGroup() as g:
111                g.create_task(foo1())
112                t2 = g.create_task(foo2())
113
114            NUM += 10
115
116        with self.assertRaises(ExceptionGroup) as cm:
117            await asyncio.create_task(runner())
118
119        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
120
121        self.assertEqual(NUM, 0)
122        self.assertTrue(t2_cancel)
123        self.assertTrue(t2.cancelled())
124
125    async def test_cancel_children_on_child_error(self):
126        # When a child task raises an error, the rest of the children
127        # are cancelled and the errors are gathered into an EG.
128
129        NUM = 0
130        t2_cancel = False
131        runner_cancel = False
132
133        async def foo1():
134            await asyncio.sleep(0.1)
135            1 / 0
136
137        async def foo2():
138            nonlocal NUM, t2_cancel
139            try:
140                await asyncio.sleep(5)
141            except asyncio.CancelledError:
142                t2_cancel = True
143                raise
144            NUM += 1
145
146        async def runner():
147            nonlocal NUM, runner_cancel
148
149            async with taskgroups.TaskGroup() as g:
150                g.create_task(foo1())
151                g.create_task(foo1())
152                g.create_task(foo1())
153                g.create_task(foo2())
154                try:
155                    await asyncio.sleep(10)
156                except asyncio.CancelledError:
157                    runner_cancel = True
158                    raise
159
160            NUM += 10
161
162        # The 3 foo1 sub tasks can be racy when the host is busy - if the
163        # cancellation happens in the middle, we'll see partial sub errors here
164        with self.assertRaises(ExceptionGroup) as cm:
165            await asyncio.create_task(runner())
166
167        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
168        self.assertEqual(NUM, 0)
169        self.assertTrue(t2_cancel)
170        self.assertTrue(runner_cancel)
171
172    async def test_cancellation(self):
173
174        NUM = 0
175
176        async def foo():
177            nonlocal NUM
178            try:
179                await asyncio.sleep(5)
180            except asyncio.CancelledError:
181                NUM += 1
182                raise
183
184        async def runner():
185            async with taskgroups.TaskGroup() as g:
186                for _ in range(5):
187                    g.create_task(foo())
188
189        r = asyncio.create_task(runner())
190        await asyncio.sleep(0.1)
191
192        self.assertFalse(r.done())
193        r.cancel()
194        with self.assertRaises(asyncio.CancelledError) as cm:
195            await r
196
197        self.assertEqual(NUM, 5)
198
199    async def test_taskgroup_07(self):
200
201        NUM = 0
202
203        async def foo():
204            nonlocal NUM
205            try:
206                await asyncio.sleep(5)
207            except asyncio.CancelledError:
208                NUM += 1
209                raise
210
211        async def runner():
212            nonlocal NUM
213            async with taskgroups.TaskGroup() as g:
214                for _ in range(5):
215                    g.create_task(foo())
216
217                try:
218                    await asyncio.sleep(10)
219                except asyncio.CancelledError:
220                    NUM += 10
221                    raise
222
223        r = asyncio.create_task(runner())
224        await asyncio.sleep(0.1)
225
226        self.assertFalse(r.done())
227        r.cancel()
228        with self.assertRaises(asyncio.CancelledError):
229            await r
230
231        self.assertEqual(NUM, 15)
232
233    async def test_taskgroup_08(self):
234
235        async def foo():
236            try:
237                await asyncio.sleep(10)
238            finally:
239                1 / 0
240
241        async def runner():
242            async with taskgroups.TaskGroup() as g:
243                for _ in range(5):
244                    g.create_task(foo())
245
246                await asyncio.sleep(10)
247
248        r = asyncio.create_task(runner())
249        await asyncio.sleep(0.1)
250
251        self.assertFalse(r.done())
252        r.cancel()
253        with self.assertRaises(ExceptionGroup) as cm:
254            await r
255        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
256
257    async def test_taskgroup_09(self):
258
259        t1 = t2 = None
260
261        async def foo1():
262            await asyncio.sleep(1)
263            return 42
264
265        async def foo2():
266            await asyncio.sleep(2)
267            return 11
268
269        async def runner():
270            nonlocal t1, t2
271            async with taskgroups.TaskGroup() as g:
272                t1 = g.create_task(foo1())
273                t2 = g.create_task(foo2())
274                await asyncio.sleep(0.1)
275                1 / 0
276
277        try:
278            await runner()
279        except ExceptionGroup as t:
280            self.assertEqual(get_error_types(t), {ZeroDivisionError})
281        else:
282            self.fail('ExceptionGroup was not raised')
283
284        self.assertTrue(t1.cancelled())
285        self.assertTrue(t2.cancelled())
286
287    async def test_taskgroup_10(self):
288
289        t1 = t2 = None
290
291        async def foo1():
292            await asyncio.sleep(1)
293            return 42
294
295        async def foo2():
296            await asyncio.sleep(2)
297            return 11
298
299        async def runner():
300            nonlocal t1, t2
301            async with taskgroups.TaskGroup() as g:
302                t1 = g.create_task(foo1())
303                t2 = g.create_task(foo2())
304                1 / 0
305
306        try:
307            await runner()
308        except ExceptionGroup as t:
309            self.assertEqual(get_error_types(t), {ZeroDivisionError})
310        else:
311            self.fail('ExceptionGroup was not raised')
312
313        self.assertTrue(t1.cancelled())
314        self.assertTrue(t2.cancelled())
315
316    async def test_taskgroup_11(self):
317
318        async def foo():
319            try:
320                await asyncio.sleep(10)
321            finally:
322                1 / 0
323
324        async def runner():
325            async with taskgroups.TaskGroup():
326                async with taskgroups.TaskGroup() as g2:
327                    for _ in range(5):
328                        g2.create_task(foo())
329
330                    await asyncio.sleep(10)
331
332        r = asyncio.create_task(runner())
333        await asyncio.sleep(0.1)
334
335        self.assertFalse(r.done())
336        r.cancel()
337        with self.assertRaises(ExceptionGroup) as cm:
338            await r
339
340        self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
341        self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
342
343    async def test_taskgroup_12(self):
344
345        async def foo():
346            try:
347                await asyncio.sleep(10)
348            finally:
349                1 / 0
350
351        async def runner():
352            async with taskgroups.TaskGroup() as g1:
353                g1.create_task(asyncio.sleep(10))
354
355                async with taskgroups.TaskGroup() as g2:
356                    for _ in range(5):
357                        g2.create_task(foo())
358
359                    await asyncio.sleep(10)
360
361        r = asyncio.create_task(runner())
362        await asyncio.sleep(0.1)
363
364        self.assertFalse(r.done())
365        r.cancel()
366        with self.assertRaises(ExceptionGroup) as cm:
367            await r
368
369        self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
370        self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError})
371
372    async def test_taskgroup_13(self):
373
374        async def crash_after(t):
375            await asyncio.sleep(t)
376            raise ValueError(t)
377
378        async def runner():
379            async with taskgroups.TaskGroup() as g1:
380                g1.create_task(crash_after(0.1))
381
382                async with taskgroups.TaskGroup() as g2:
383                    g2.create_task(crash_after(10))
384
385        r = asyncio.create_task(runner())
386        with self.assertRaises(ExceptionGroup) as cm:
387            await r
388
389        self.assertEqual(get_error_types(cm.exception), {ValueError})
390
391    async def test_taskgroup_14(self):
392
393        async def crash_after(t):
394            await asyncio.sleep(t)
395            raise ValueError(t)
396
397        async def runner():
398            async with taskgroups.TaskGroup() as g1:
399                g1.create_task(crash_after(10))
400
401                async with taskgroups.TaskGroup() as g2:
402                    g2.create_task(crash_after(0.1))
403
404        r = asyncio.create_task(runner())
405        with self.assertRaises(ExceptionGroup) as cm:
406            await r
407
408        self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
409        self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ValueError})
410
411    async def test_taskgroup_15(self):
412
413        async def crash_soon():
414            await asyncio.sleep(0.3)
415            1 / 0
416
417        async def runner():
418            async with taskgroups.TaskGroup() as g1:
419                g1.create_task(crash_soon())
420                try:
421                    await asyncio.sleep(10)
422                except asyncio.CancelledError:
423                    await asyncio.sleep(0.5)
424                    raise
425
426        r = asyncio.create_task(runner())
427        await asyncio.sleep(0.1)
428
429        self.assertFalse(r.done())
430        r.cancel()
431        with self.assertRaises(ExceptionGroup) as cm:
432            await r
433        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
434
435    async def test_taskgroup_16(self):
436
437        async def crash_soon():
438            await asyncio.sleep(0.3)
439            1 / 0
440
441        async def nested_runner():
442            async with taskgroups.TaskGroup() as g1:
443                g1.create_task(crash_soon())
444                try:
445                    await asyncio.sleep(10)
446                except asyncio.CancelledError:
447                    await asyncio.sleep(0.5)
448                    raise
449
450        async def runner():
451            t = asyncio.create_task(nested_runner())
452            await t
453
454        r = asyncio.create_task(runner())
455        await asyncio.sleep(0.1)
456
457        self.assertFalse(r.done())
458        r.cancel()
459        with self.assertRaises(ExceptionGroup) as cm:
460            await r
461        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
462
463    async def test_taskgroup_17(self):
464        NUM = 0
465
466        async def runner():
467            nonlocal NUM
468            async with taskgroups.TaskGroup():
469                try:
470                    await asyncio.sleep(10)
471                except asyncio.CancelledError:
472                    NUM += 10
473                    raise
474
475        r = asyncio.create_task(runner())
476        await asyncio.sleep(0.1)
477
478        self.assertFalse(r.done())
479        r.cancel()
480        with self.assertRaises(asyncio.CancelledError):
481            await r
482
483        self.assertEqual(NUM, 10)
484
485    async def test_taskgroup_18(self):
486        NUM = 0
487
488        async def runner():
489            nonlocal NUM
490            async with taskgroups.TaskGroup():
491                try:
492                    await asyncio.sleep(10)
493                except asyncio.CancelledError:
494                    NUM += 10
495                    # This isn't a good idea, but we have to support
496                    # this weird case.
497                    raise MyExc
498
499        r = asyncio.create_task(runner())
500        await asyncio.sleep(0.1)
501
502        self.assertFalse(r.done())
503        r.cancel()
504
505        try:
506            await r
507        except ExceptionGroup as t:
508            self.assertEqual(get_error_types(t),{MyExc})
509        else:
510            self.fail('ExceptionGroup was not raised')
511
512        self.assertEqual(NUM, 10)
513
514    async def test_taskgroup_19(self):
515        async def crash_soon():
516            await asyncio.sleep(0.1)
517            1 / 0
518
519        async def nested():
520            try:
521                await asyncio.sleep(10)
522            finally:
523                raise MyExc
524
525        async def runner():
526            async with taskgroups.TaskGroup() as g:
527                g.create_task(crash_soon())
528                await nested()
529
530        r = asyncio.create_task(runner())
531        try:
532            await r
533        except ExceptionGroup as t:
534            self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError})
535        else:
536            self.fail('TasgGroupError was not raised')
537
538    async def test_taskgroup_20(self):
539        async def crash_soon():
540            await asyncio.sleep(0.1)
541            1 / 0
542
543        async def nested():
544            try:
545                await asyncio.sleep(10)
546            finally:
547                raise KeyboardInterrupt
548
549        async def runner():
550            async with taskgroups.TaskGroup() as g:
551                g.create_task(crash_soon())
552                await nested()
553
554        with self.assertRaises(KeyboardInterrupt):
555            await runner()
556
557    async def test_taskgroup_20a(self):
558        async def crash_soon():
559            await asyncio.sleep(0.1)
560            1 / 0
561
562        async def nested():
563            try:
564                await asyncio.sleep(10)
565            finally:
566                raise MyBaseExc
567
568        async def runner():
569            async with taskgroups.TaskGroup() as g:
570                g.create_task(crash_soon())
571                await nested()
572
573        with self.assertRaises(BaseExceptionGroup) as cm:
574            await runner()
575
576        self.assertEqual(
577            get_error_types(cm.exception), {MyBaseExc, ZeroDivisionError}
578        )
579
580    async def _test_taskgroup_21(self):
581        # This test doesn't work as asyncio, currently, doesn't
582        # correctly propagate KeyboardInterrupt (or SystemExit) --
583        # those cause the event loop itself to crash.
584        # (Compare to the previous (passing) test -- that one raises
585        # a plain exception but raises KeyboardInterrupt in nested();
586        # this test does it the other way around.)
587
588        async def crash_soon():
589            await asyncio.sleep(0.1)
590            raise KeyboardInterrupt
591
592        async def nested():
593            try:
594                await asyncio.sleep(10)
595            finally:
596                raise TypeError
597
598        async def runner():
599            async with taskgroups.TaskGroup() as g:
600                g.create_task(crash_soon())
601                await nested()
602
603        with self.assertRaises(KeyboardInterrupt):
604            await runner()
605
606    async def test_taskgroup_21a(self):
607
608        async def crash_soon():
609            await asyncio.sleep(0.1)
610            raise MyBaseExc
611
612        async def nested():
613            try:
614                await asyncio.sleep(10)
615            finally:
616                raise TypeError
617
618        async def runner():
619            async with taskgroups.TaskGroup() as g:
620                g.create_task(crash_soon())
621                await nested()
622
623        with self.assertRaises(BaseExceptionGroup) as cm:
624            await runner()
625
626        self.assertEqual(get_error_types(cm.exception), {MyBaseExc, TypeError})
627
628    async def test_taskgroup_22(self):
629
630        async def foo1():
631            await asyncio.sleep(1)
632            return 42
633
634        async def foo2():
635            await asyncio.sleep(2)
636            return 11
637
638        async def runner():
639            async with taskgroups.TaskGroup() as g:
640                g.create_task(foo1())
641                g.create_task(foo2())
642
643        r = asyncio.create_task(runner())
644        await asyncio.sleep(0.05)
645        r.cancel()
646
647        with self.assertRaises(asyncio.CancelledError):
648            await r
649
650    async def test_taskgroup_23(self):
651
652        async def do_job(delay):
653            await asyncio.sleep(delay)
654
655        async with taskgroups.TaskGroup() as g:
656            for count in range(10):
657                await asyncio.sleep(0.1)
658                g.create_task(do_job(0.3))
659                if count == 5:
660                    self.assertLess(len(g._tasks), 5)
661            await asyncio.sleep(1.35)
662            self.assertEqual(len(g._tasks), 0)
663
664    async def test_taskgroup_24(self):
665
666        async def root(g):
667            await asyncio.sleep(0.1)
668            g.create_task(coro1(0.1))
669            g.create_task(coro1(0.2))
670
671        async def coro1(delay):
672            await asyncio.sleep(delay)
673
674        async def runner():
675            async with taskgroups.TaskGroup() as g:
676                g.create_task(root(g))
677
678        await runner()
679
680    async def test_taskgroup_25(self):
681        nhydras = 0
682
683        async def hydra(g):
684            nonlocal nhydras
685            nhydras += 1
686            await asyncio.sleep(0.01)
687            g.create_task(hydra(g))
688            g.create_task(hydra(g))
689
690        async def hercules():
691            while nhydras < 10:
692                await asyncio.sleep(0.015)
693            1 / 0
694
695        async def runner():
696            async with taskgroups.TaskGroup() as g:
697                g.create_task(hydra(g))
698                g.create_task(hercules())
699
700        with self.assertRaises(ExceptionGroup) as cm:
701            await runner()
702
703        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
704        self.assertGreaterEqual(nhydras, 10)
705
706    async def test_taskgroup_task_name(self):
707        async def coro():
708            await asyncio.sleep(0)
709        async with taskgroups.TaskGroup() as g:
710            t = g.create_task(coro(), name="yolo")
711            self.assertEqual(t.get_name(), "yolo")
712
713    async def test_taskgroup_task_context(self):
714        cvar = contextvars.ContextVar('cvar')
715
716        async def coro(val):
717            await asyncio.sleep(0)
718            cvar.set(val)
719
720        async with taskgroups.TaskGroup() as g:
721            ctx = contextvars.copy_context()
722            self.assertIsNone(ctx.get(cvar))
723            t1 = g.create_task(coro(1), context=ctx)
724            await t1
725            self.assertEqual(1, ctx.get(cvar))
726            t2 = g.create_task(coro(2), context=ctx)
727            await t2
728            self.assertEqual(2, ctx.get(cvar))
729
730    async def test_taskgroup_no_create_task_after_failure(self):
731        async def coro1():
732            await asyncio.sleep(0.001)
733            1 / 0
734        async def coro2(g):
735            try:
736                await asyncio.sleep(1)
737            except asyncio.CancelledError:
738                with self.assertRaises(RuntimeError):
739                    g.create_task(c1 := coro1())
740                # We still have to await c1 to avoid a warning
741                with self.assertRaises(ZeroDivisionError):
742                    await c1
743
744        with self.assertRaises(ExceptionGroup) as cm:
745            async with taskgroups.TaskGroup() as g:
746                g.create_task(coro1())
747                g.create_task(coro2(g))
748
749        self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
750
751    async def test_taskgroup_context_manager_exit_raises(self):
752        # See https://github.com/python/cpython/issues/95289
753        class CustomException(Exception):
754            pass
755
756        async def raise_exc():
757            raise CustomException
758
759        @contextlib.asynccontextmanager
760        async def database():
761            try:
762                yield
763            finally:
764                raise CustomException
765
766        async def main():
767            task = asyncio.current_task()
768            try:
769                async with taskgroups.TaskGroup() as tg:
770                    async with database():
771                        tg.create_task(raise_exc())
772                        await asyncio.sleep(1)
773            except* CustomException as err:
774                self.assertEqual(task.cancelling(), 0)
775                self.assertEqual(len(err.exceptions), 2)
776
777            else:
778                self.fail('CustomException not raised')
779
780        await asyncio.create_task(main())
781
782
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
785