1# Adapted with permission from the EdgeDB project; 2# license: PSFL. 3 4 5import asyncio 6import contextvars 7import contextlib 8from asyncio import taskgroups 9import unittest 10 11 12# To prevent a warning "test altered the execution environment" 13def tearDownModule(): 14 asyncio.set_event_loop_policy(None) 15 16 17class MyExc(Exception): 18 pass 19 20 21class MyBaseExc(BaseException): 22 pass 23 24 25def get_error_types(eg): 26 return {type(exc) for exc in eg.exceptions} 27 28 29class TestTaskGroup(unittest.IsolatedAsyncioTestCase): 30 31 async def test_taskgroup_01(self): 32 33 async def foo1(): 34 await asyncio.sleep(0.1) 35 return 42 36 37 async def foo2(): 38 await asyncio.sleep(0.2) 39 return 11 40 41 async with taskgroups.TaskGroup() as g: 42 t1 = g.create_task(foo1()) 43 t2 = g.create_task(foo2()) 44 45 self.assertEqual(t1.result(), 42) 46 self.assertEqual(t2.result(), 11) 47 48 async def test_taskgroup_02(self): 49 50 async def foo1(): 51 await asyncio.sleep(0.1) 52 return 42 53 54 async def foo2(): 55 await asyncio.sleep(0.2) 56 return 11 57 58 async with taskgroups.TaskGroup() as g: 59 t1 = g.create_task(foo1()) 60 await asyncio.sleep(0.15) 61 t2 = g.create_task(foo2()) 62 63 self.assertEqual(t1.result(), 42) 64 self.assertEqual(t2.result(), 11) 65 66 async def test_taskgroup_03(self): 67 68 async def foo1(): 69 await asyncio.sleep(1) 70 return 42 71 72 async def foo2(): 73 await asyncio.sleep(0.2) 74 return 11 75 76 async with taskgroups.TaskGroup() as g: 77 t1 = g.create_task(foo1()) 78 await asyncio.sleep(0.15) 79 # cancel t1 explicitly, i.e. everything should continue 80 # working as expected. 81 t1.cancel() 82 83 t2 = g.create_task(foo2()) 84 85 self.assertTrue(t1.cancelled()) 86 self.assertEqual(t2.result(), 11) 87 88 async def test_taskgroup_04(self): 89 90 NUM = 0 91 t2_cancel = False 92 t2 = None 93 94 async def foo1(): 95 await asyncio.sleep(0.1) 96 1 / 0 97 98 async def foo2(): 99 nonlocal NUM, t2_cancel 100 try: 101 await asyncio.sleep(1) 102 except asyncio.CancelledError: 103 t2_cancel = True 104 raise 105 NUM += 1 106 107 async def runner(): 108 nonlocal NUM, t2 109 110 async with taskgroups.TaskGroup() as g: 111 g.create_task(foo1()) 112 t2 = g.create_task(foo2()) 113 114 NUM += 10 115 116 with self.assertRaises(ExceptionGroup) as cm: 117 await asyncio.create_task(runner()) 118 119 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) 120 121 self.assertEqual(NUM, 0) 122 self.assertTrue(t2_cancel) 123 self.assertTrue(t2.cancelled()) 124 125 async def test_cancel_children_on_child_error(self): 126 # When a child task raises an error, the rest of the children 127 # are cancelled and the errors are gathered into an EG. 128 129 NUM = 0 130 t2_cancel = False 131 runner_cancel = False 132 133 async def foo1(): 134 await asyncio.sleep(0.1) 135 1 / 0 136 137 async def foo2(): 138 nonlocal NUM, t2_cancel 139 try: 140 await asyncio.sleep(5) 141 except asyncio.CancelledError: 142 t2_cancel = True 143 raise 144 NUM += 1 145 146 async def runner(): 147 nonlocal NUM, runner_cancel 148 149 async with taskgroups.TaskGroup() as g: 150 g.create_task(foo1()) 151 g.create_task(foo1()) 152 g.create_task(foo1()) 153 g.create_task(foo2()) 154 try: 155 await asyncio.sleep(10) 156 except asyncio.CancelledError: 157 runner_cancel = True 158 raise 159 160 NUM += 10 161 162 # The 3 foo1 sub tasks can be racy when the host is busy - if the 163 # cancellation happens in the middle, we'll see partial sub errors here 164 with self.assertRaises(ExceptionGroup) as cm: 165 await asyncio.create_task(runner()) 166 167 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) 168 self.assertEqual(NUM, 0) 169 self.assertTrue(t2_cancel) 170 self.assertTrue(runner_cancel) 171 172 async def test_cancellation(self): 173 174 NUM = 0 175 176 async def foo(): 177 nonlocal NUM 178 try: 179 await asyncio.sleep(5) 180 except asyncio.CancelledError: 181 NUM += 1 182 raise 183 184 async def runner(): 185 async with taskgroups.TaskGroup() as g: 186 for _ in range(5): 187 g.create_task(foo()) 188 189 r = asyncio.create_task(runner()) 190 await asyncio.sleep(0.1) 191 192 self.assertFalse(r.done()) 193 r.cancel() 194 with self.assertRaises(asyncio.CancelledError) as cm: 195 await r 196 197 self.assertEqual(NUM, 5) 198 199 async def test_taskgroup_07(self): 200 201 NUM = 0 202 203 async def foo(): 204 nonlocal NUM 205 try: 206 await asyncio.sleep(5) 207 except asyncio.CancelledError: 208 NUM += 1 209 raise 210 211 async def runner(): 212 nonlocal NUM 213 async with taskgroups.TaskGroup() as g: 214 for _ in range(5): 215 g.create_task(foo()) 216 217 try: 218 await asyncio.sleep(10) 219 except asyncio.CancelledError: 220 NUM += 10 221 raise 222 223 r = asyncio.create_task(runner()) 224 await asyncio.sleep(0.1) 225 226 self.assertFalse(r.done()) 227 r.cancel() 228 with self.assertRaises(asyncio.CancelledError): 229 await r 230 231 self.assertEqual(NUM, 15) 232 233 async def test_taskgroup_08(self): 234 235 async def foo(): 236 try: 237 await asyncio.sleep(10) 238 finally: 239 1 / 0 240 241 async def runner(): 242 async with taskgroups.TaskGroup() as g: 243 for _ in range(5): 244 g.create_task(foo()) 245 246 await asyncio.sleep(10) 247 248 r = asyncio.create_task(runner()) 249 await asyncio.sleep(0.1) 250 251 self.assertFalse(r.done()) 252 r.cancel() 253 with self.assertRaises(ExceptionGroup) as cm: 254 await r 255 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) 256 257 async def test_taskgroup_09(self): 258 259 t1 = t2 = None 260 261 async def foo1(): 262 await asyncio.sleep(1) 263 return 42 264 265 async def foo2(): 266 await asyncio.sleep(2) 267 return 11 268 269 async def runner(): 270 nonlocal t1, t2 271 async with taskgroups.TaskGroup() as g: 272 t1 = g.create_task(foo1()) 273 t2 = g.create_task(foo2()) 274 await asyncio.sleep(0.1) 275 1 / 0 276 277 try: 278 await runner() 279 except ExceptionGroup as t: 280 self.assertEqual(get_error_types(t), {ZeroDivisionError}) 281 else: 282 self.fail('ExceptionGroup was not raised') 283 284 self.assertTrue(t1.cancelled()) 285 self.assertTrue(t2.cancelled()) 286 287 async def test_taskgroup_10(self): 288 289 t1 = t2 = None 290 291 async def foo1(): 292 await asyncio.sleep(1) 293 return 42 294 295 async def foo2(): 296 await asyncio.sleep(2) 297 return 11 298 299 async def runner(): 300 nonlocal t1, t2 301 async with taskgroups.TaskGroup() as g: 302 t1 = g.create_task(foo1()) 303 t2 = g.create_task(foo2()) 304 1 / 0 305 306 try: 307 await runner() 308 except ExceptionGroup as t: 309 self.assertEqual(get_error_types(t), {ZeroDivisionError}) 310 else: 311 self.fail('ExceptionGroup was not raised') 312 313 self.assertTrue(t1.cancelled()) 314 self.assertTrue(t2.cancelled()) 315 316 async def test_taskgroup_11(self): 317 318 async def foo(): 319 try: 320 await asyncio.sleep(10) 321 finally: 322 1 / 0 323 324 async def runner(): 325 async with taskgroups.TaskGroup(): 326 async with taskgroups.TaskGroup() as g2: 327 for _ in range(5): 328 g2.create_task(foo()) 329 330 await asyncio.sleep(10) 331 332 r = asyncio.create_task(runner()) 333 await asyncio.sleep(0.1) 334 335 self.assertFalse(r.done()) 336 r.cancel() 337 with self.assertRaises(ExceptionGroup) as cm: 338 await r 339 340 self.assertEqual(get_error_types(cm.exception), {ExceptionGroup}) 341 self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError}) 342 343 async def test_taskgroup_12(self): 344 345 async def foo(): 346 try: 347 await asyncio.sleep(10) 348 finally: 349 1 / 0 350 351 async def runner(): 352 async with taskgroups.TaskGroup() as g1: 353 g1.create_task(asyncio.sleep(10)) 354 355 async with taskgroups.TaskGroup() as g2: 356 for _ in range(5): 357 g2.create_task(foo()) 358 359 await asyncio.sleep(10) 360 361 r = asyncio.create_task(runner()) 362 await asyncio.sleep(0.1) 363 364 self.assertFalse(r.done()) 365 r.cancel() 366 with self.assertRaises(ExceptionGroup) as cm: 367 await r 368 369 self.assertEqual(get_error_types(cm.exception), {ExceptionGroup}) 370 self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError}) 371 372 async def test_taskgroup_13(self): 373 374 async def crash_after(t): 375 await asyncio.sleep(t) 376 raise ValueError(t) 377 378 async def runner(): 379 async with taskgroups.TaskGroup() as g1: 380 g1.create_task(crash_after(0.1)) 381 382 async with taskgroups.TaskGroup() as g2: 383 g2.create_task(crash_after(10)) 384 385 r = asyncio.create_task(runner()) 386 with self.assertRaises(ExceptionGroup) as cm: 387 await r 388 389 self.assertEqual(get_error_types(cm.exception), {ValueError}) 390 391 async def test_taskgroup_14(self): 392 393 async def crash_after(t): 394 await asyncio.sleep(t) 395 raise ValueError(t) 396 397 async def runner(): 398 async with taskgroups.TaskGroup() as g1: 399 g1.create_task(crash_after(10)) 400 401 async with taskgroups.TaskGroup() as g2: 402 g2.create_task(crash_after(0.1)) 403 404 r = asyncio.create_task(runner()) 405 with self.assertRaises(ExceptionGroup) as cm: 406 await r 407 408 self.assertEqual(get_error_types(cm.exception), {ExceptionGroup}) 409 self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ValueError}) 410 411 async def test_taskgroup_15(self): 412 413 async def crash_soon(): 414 await asyncio.sleep(0.3) 415 1 / 0 416 417 async def runner(): 418 async with taskgroups.TaskGroup() as g1: 419 g1.create_task(crash_soon()) 420 try: 421 await asyncio.sleep(10) 422 except asyncio.CancelledError: 423 await asyncio.sleep(0.5) 424 raise 425 426 r = asyncio.create_task(runner()) 427 await asyncio.sleep(0.1) 428 429 self.assertFalse(r.done()) 430 r.cancel() 431 with self.assertRaises(ExceptionGroup) as cm: 432 await r 433 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) 434 435 async def test_taskgroup_16(self): 436 437 async def crash_soon(): 438 await asyncio.sleep(0.3) 439 1 / 0 440 441 async def nested_runner(): 442 async with taskgroups.TaskGroup() as g1: 443 g1.create_task(crash_soon()) 444 try: 445 await asyncio.sleep(10) 446 except asyncio.CancelledError: 447 await asyncio.sleep(0.5) 448 raise 449 450 async def runner(): 451 t = asyncio.create_task(nested_runner()) 452 await t 453 454 r = asyncio.create_task(runner()) 455 await asyncio.sleep(0.1) 456 457 self.assertFalse(r.done()) 458 r.cancel() 459 with self.assertRaises(ExceptionGroup) as cm: 460 await r 461 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) 462 463 async def test_taskgroup_17(self): 464 NUM = 0 465 466 async def runner(): 467 nonlocal NUM 468 async with taskgroups.TaskGroup(): 469 try: 470 await asyncio.sleep(10) 471 except asyncio.CancelledError: 472 NUM += 10 473 raise 474 475 r = asyncio.create_task(runner()) 476 await asyncio.sleep(0.1) 477 478 self.assertFalse(r.done()) 479 r.cancel() 480 with self.assertRaises(asyncio.CancelledError): 481 await r 482 483 self.assertEqual(NUM, 10) 484 485 async def test_taskgroup_18(self): 486 NUM = 0 487 488 async def runner(): 489 nonlocal NUM 490 async with taskgroups.TaskGroup(): 491 try: 492 await asyncio.sleep(10) 493 except asyncio.CancelledError: 494 NUM += 10 495 # This isn't a good idea, but we have to support 496 # this weird case. 497 raise MyExc 498 499 r = asyncio.create_task(runner()) 500 await asyncio.sleep(0.1) 501 502 self.assertFalse(r.done()) 503 r.cancel() 504 505 try: 506 await r 507 except ExceptionGroup as t: 508 self.assertEqual(get_error_types(t),{MyExc}) 509 else: 510 self.fail('ExceptionGroup was not raised') 511 512 self.assertEqual(NUM, 10) 513 514 async def test_taskgroup_19(self): 515 async def crash_soon(): 516 await asyncio.sleep(0.1) 517 1 / 0 518 519 async def nested(): 520 try: 521 await asyncio.sleep(10) 522 finally: 523 raise MyExc 524 525 async def runner(): 526 async with taskgroups.TaskGroup() as g: 527 g.create_task(crash_soon()) 528 await nested() 529 530 r = asyncio.create_task(runner()) 531 try: 532 await r 533 except ExceptionGroup as t: 534 self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError}) 535 else: 536 self.fail('TasgGroupError was not raised') 537 538 async def test_taskgroup_20(self): 539 async def crash_soon(): 540 await asyncio.sleep(0.1) 541 1 / 0 542 543 async def nested(): 544 try: 545 await asyncio.sleep(10) 546 finally: 547 raise KeyboardInterrupt 548 549 async def runner(): 550 async with taskgroups.TaskGroup() as g: 551 g.create_task(crash_soon()) 552 await nested() 553 554 with self.assertRaises(KeyboardInterrupt): 555 await runner() 556 557 async def test_taskgroup_20a(self): 558 async def crash_soon(): 559 await asyncio.sleep(0.1) 560 1 / 0 561 562 async def nested(): 563 try: 564 await asyncio.sleep(10) 565 finally: 566 raise MyBaseExc 567 568 async def runner(): 569 async with taskgroups.TaskGroup() as g: 570 g.create_task(crash_soon()) 571 await nested() 572 573 with self.assertRaises(BaseExceptionGroup) as cm: 574 await runner() 575 576 self.assertEqual( 577 get_error_types(cm.exception), {MyBaseExc, ZeroDivisionError} 578 ) 579 580 async def _test_taskgroup_21(self): 581 # This test doesn't work as asyncio, currently, doesn't 582 # correctly propagate KeyboardInterrupt (or SystemExit) -- 583 # those cause the event loop itself to crash. 584 # (Compare to the previous (passing) test -- that one raises 585 # a plain exception but raises KeyboardInterrupt in nested(); 586 # this test does it the other way around.) 587 588 async def crash_soon(): 589 await asyncio.sleep(0.1) 590 raise KeyboardInterrupt 591 592 async def nested(): 593 try: 594 await asyncio.sleep(10) 595 finally: 596 raise TypeError 597 598 async def runner(): 599 async with taskgroups.TaskGroup() as g: 600 g.create_task(crash_soon()) 601 await nested() 602 603 with self.assertRaises(KeyboardInterrupt): 604 await runner() 605 606 async def test_taskgroup_21a(self): 607 608 async def crash_soon(): 609 await asyncio.sleep(0.1) 610 raise MyBaseExc 611 612 async def nested(): 613 try: 614 await asyncio.sleep(10) 615 finally: 616 raise TypeError 617 618 async def runner(): 619 async with taskgroups.TaskGroup() as g: 620 g.create_task(crash_soon()) 621 await nested() 622 623 with self.assertRaises(BaseExceptionGroup) as cm: 624 await runner() 625 626 self.assertEqual(get_error_types(cm.exception), {MyBaseExc, TypeError}) 627 628 async def test_taskgroup_22(self): 629 630 async def foo1(): 631 await asyncio.sleep(1) 632 return 42 633 634 async def foo2(): 635 await asyncio.sleep(2) 636 return 11 637 638 async def runner(): 639 async with taskgroups.TaskGroup() as g: 640 g.create_task(foo1()) 641 g.create_task(foo2()) 642 643 r = asyncio.create_task(runner()) 644 await asyncio.sleep(0.05) 645 r.cancel() 646 647 with self.assertRaises(asyncio.CancelledError): 648 await r 649 650 async def test_taskgroup_23(self): 651 652 async def do_job(delay): 653 await asyncio.sleep(delay) 654 655 async with taskgroups.TaskGroup() as g: 656 for count in range(10): 657 await asyncio.sleep(0.1) 658 g.create_task(do_job(0.3)) 659 if count == 5: 660 self.assertLess(len(g._tasks), 5) 661 await asyncio.sleep(1.35) 662 self.assertEqual(len(g._tasks), 0) 663 664 async def test_taskgroup_24(self): 665 666 async def root(g): 667 await asyncio.sleep(0.1) 668 g.create_task(coro1(0.1)) 669 g.create_task(coro1(0.2)) 670 671 async def coro1(delay): 672 await asyncio.sleep(delay) 673 674 async def runner(): 675 async with taskgroups.TaskGroup() as g: 676 g.create_task(root(g)) 677 678 await runner() 679 680 async def test_taskgroup_25(self): 681 nhydras = 0 682 683 async def hydra(g): 684 nonlocal nhydras 685 nhydras += 1 686 await asyncio.sleep(0.01) 687 g.create_task(hydra(g)) 688 g.create_task(hydra(g)) 689 690 async def hercules(): 691 while nhydras < 10: 692 await asyncio.sleep(0.015) 693 1 / 0 694 695 async def runner(): 696 async with taskgroups.TaskGroup() as g: 697 g.create_task(hydra(g)) 698 g.create_task(hercules()) 699 700 with self.assertRaises(ExceptionGroup) as cm: 701 await runner() 702 703 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) 704 self.assertGreaterEqual(nhydras, 10) 705 706 async def test_taskgroup_task_name(self): 707 async def coro(): 708 await asyncio.sleep(0) 709 async with taskgroups.TaskGroup() as g: 710 t = g.create_task(coro(), name="yolo") 711 self.assertEqual(t.get_name(), "yolo") 712 713 async def test_taskgroup_task_context(self): 714 cvar = contextvars.ContextVar('cvar') 715 716 async def coro(val): 717 await asyncio.sleep(0) 718 cvar.set(val) 719 720 async with taskgroups.TaskGroup() as g: 721 ctx = contextvars.copy_context() 722 self.assertIsNone(ctx.get(cvar)) 723 t1 = g.create_task(coro(1), context=ctx) 724 await t1 725 self.assertEqual(1, ctx.get(cvar)) 726 t2 = g.create_task(coro(2), context=ctx) 727 await t2 728 self.assertEqual(2, ctx.get(cvar)) 729 730 async def test_taskgroup_no_create_task_after_failure(self): 731 async def coro1(): 732 await asyncio.sleep(0.001) 733 1 / 0 734 async def coro2(g): 735 try: 736 await asyncio.sleep(1) 737 except asyncio.CancelledError: 738 with self.assertRaises(RuntimeError): 739 g.create_task(c1 := coro1()) 740 # We still have to await c1 to avoid a warning 741 with self.assertRaises(ZeroDivisionError): 742 await c1 743 744 with self.assertRaises(ExceptionGroup) as cm: 745 async with taskgroups.TaskGroup() as g: 746 g.create_task(coro1()) 747 g.create_task(coro2(g)) 748 749 self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) 750 751 async def test_taskgroup_context_manager_exit_raises(self): 752 # See https://github.com/python/cpython/issues/95289 753 class CustomException(Exception): 754 pass 755 756 async def raise_exc(): 757 raise CustomException 758 759 @contextlib.asynccontextmanager 760 async def database(): 761 try: 762 yield 763 finally: 764 raise CustomException 765 766 async def main(): 767 task = asyncio.current_task() 768 try: 769 async with taskgroups.TaskGroup() as tg: 770 async with database(): 771 tg.create_task(raise_exc()) 772 await asyncio.sleep(1) 773 except* CustomException as err: 774 self.assertEqual(task.cancelling(), 0) 775 self.assertEqual(len(err.exceptions), 2) 776 777 else: 778 self.fail('CustomException not raised') 779 780 await asyncio.create_task(main()) 781 782 783if __name__ == "__main__": 784 unittest.main() 785