# Tests for the asynchronous parts of contextlib: AbstractAsyncContextManager,
# asynccontextmanager, aclosing, AsyncExitStack and nullcontext used with
# ``async with``.
import asyncio
from contextlib import (
    asynccontextmanager, AbstractAsyncContextManager,
    AsyncExitStack, nullcontext, aclosing, contextmanager)
import functools
from test import support
import unittest
import traceback

from test.test_contextlib import TestBaseExitStack

# Skip this whole module unless the test support layer reports usable sockets
# (presumably because the asyncio event loops used here need them -- confirm).
support.requires_working_socket(module=True)

def _async_test(func):
    """Decorator to turn an async function into a test case.

    The wrapped coroutine function is called with the test's arguments and
    the resulting coroutine is run to completion with ``asyncio.run()``,
    i.e. in a fresh event loop per test.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        coro = func(*args, **kwargs)
        asyncio.run(coro)
    return wrapper

def tearDownModule():
    # Reset the global event loop policy so that tests run after this module
    # start from the default policy.
    asyncio.set_event_loop_policy(None)


# Behaviour of the AbstractAsyncContextManager ABC itself.
class TestAbstractAsyncContextManager(unittest.TestCase):

    @_async_test
    async def test_enter(self):
        # The default __aenter__ returns the manager itself.
        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *args):
                await super().__aexit__(*args)

        manager = DefaultEnter()
        self.assertIs(await manager.__aenter__(), manager)

        async with manager as context:
            self.assertIs(manager, context)

    @_async_test
    async def test_async_gen_propagates_generator_exit(self):
        # A regression test for https://bugs.python.org/issue33786.

        @asynccontextmanager
        async def ctx():
            yield

        async def gen():
            async with ctx():
                yield 11

        ret = []
        exc = ValueError(22)
        with self.assertRaises(ValueError):
            async with ctx():
                async for val in gen():
                    ret.append(val)
                    raise exc

        self.assertEqual(ret, [11])

    def test_exit_is_abstract(self):
        # A subclass that does not implement __aexit__ cannot be instantiated.
        class MissingAexit(AbstractAsyncContextManager):
            pass

        with self.assertRaises(TypeError):
            MissingAexit()

    def test_structural_subclassing(self):
        # issubclass() accepts any class defining both __aenter__ and
        # __aexit__, and rejects classes that set either of them to None.
        class ManagerFromScratch:
            async def __aenter__(self):
                return self
            async def __aexit__(self, exc_type, exc_value, traceback):
                return None

        self.assertTrue(issubclass(ManagerFromScratch, AbstractAsyncContextManager))

        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *args):
                await super().__aexit__(*args)

        self.assertTrue(issubclass(DefaultEnter, AbstractAsyncContextManager))

        class NoneAenter(ManagerFromScratch):
            __aenter__ = None

        self.assertFalse(issubclass(NoneAenter, AbstractAsyncContextManager))

        class NoneAexit(ManagerFromScratch):
            __aexit__ = None

        self.assertFalse(issubclass(NoneAexit, AbstractAsyncContextManager))


# Behaviour of context managers built with @asynccontextmanager.
class AsyncContextManagerTestCase(unittest.TestCase):

    @_async_test
    async def test_contextmanager_plain(self):
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_finally(self):
        # The generator's finally clause runs even when the body raises.
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            async with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_traceback(self):
        # The traceback of an exception propagating out of the ``async with``
        # body must contain only the test's own frame. The asserted source
        # lines ('1/0', 'raise ...') must match the statements below exactly,
        # so those lines must not carry trailing comments.
        @asynccontextmanager
        async def f():
            yield

        try:
            async with f():
                1/0
        except ZeroDivisionError as e:
            frames = traceback.extract_tb(e.__traceback__)
        else:
            self.fail('ZeroDivisionError was suppressed')

        self.assertEqual(len(frames), 1)
        self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
        self.assertEqual(frames[0].line, '1/0')

        # Repeat with RuntimeError (which goes through a different code path)
        class RuntimeErrorSubclass(RuntimeError):
            pass

        try:
            async with f():
                raise RuntimeErrorSubclass(42)
        except RuntimeErrorSubclass as e:
            frames = traceback.extract_tb(e.__traceback__)
        else:
            self.fail('RuntimeErrorSubclass was suppressed')

        self.assertEqual(len(frames), 1)
        self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
        self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')

        class StopIterationSubclass(StopIteration):
            pass

        class StopAsyncIterationSubclass(StopAsyncIteration):
            pass

        for stop_exc in (
            StopIteration('spam'),
            StopAsyncIteration('ham'),
            StopIterationSubclass('spam'),
            StopAsyncIterationSubclass('spam')
        ):
            with self.subTest(type=type(stop_exc)):
                try:
                    async with f():
                        raise stop_exc
                except type(stop_exc) as e:
                    self.assertIs(e, stop_exc)
                    frames = traceback.extract_tb(e.__traceback__)
                else:
                    self.fail(f'{stop_exc} was suppressed')

                self.assertEqual(len(frames), 1)
                self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
                self.assertEqual(frames[0].line, 'raise stop_exc')

    @_async_test
    async def test_contextmanager_no_reraise(self):
        @asynccontextmanager
        async def whee():
            yield
        ctx = whee()
        await ctx.__aenter__()
        # Calling __aexit__ should not result in an exception
        self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))

    @_async_test
    async def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields again after the exception is thrown into it
        # is reported as a RuntimeError.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except:
                yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(TypeError, TypeError('foo'), None)

    @_async_test
    async def test_contextmanager_trap_no_yield(self):
        # A generator that never yields makes __aenter__ raise RuntimeError.
        @asynccontextmanager
        async def whoo():
            if False:
                yield
        ctx = whoo()
        with self.assertRaises(RuntimeError):
            await ctx.__aenter__()

    @_async_test
    async def test_contextmanager_trap_second_yield(self):
        # A generator that yields a second time makes __aexit__ raise
        # RuntimeError.
        @asynccontextmanager
        async def whoo():
            yield
            yield
        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(None, None, None)

    @_async_test
    async def test_contextmanager_non_normalised(self):
        # __aexit__ is called with an exception type but no instance; the
        # generator's own replacement exception must propagate.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except RuntimeError:
                raise SyntaxError

        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(SyntaxError):
            await ctx.__aexit__(RuntimeError, None, None)

    @_async_test
    async def test_contextmanager_except(self):
        # An exception caught inside the generator is suppressed.
        state = []
        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_except_stopiter(self):
        # StopIteration / StopAsyncIteration (and subclasses) raised in the
        # body must propagate unchanged rather than be swallowed.
        @asynccontextmanager
        async def woohoo():
            yield

        class StopIterationSubclass(StopIteration):
            pass

        class StopAsyncIterationSubclass(StopAsyncIteration):
            pass

        for stop_exc in (
            StopIteration('spam'),
            StopAsyncIteration('ham'),
            StopIterationSubclass('spam'),
            StopAsyncIterationSubclass('spam')
        ):
            with self.subTest(type=type(stop_exc)):
                try:
                    async with woohoo():
                        raise stop_exc
                except Exception as ex:
                    self.assertIs(ex, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    @_async_test
    async def test_contextmanager_wrap_runtimeerror(self):
        @asynccontextmanager
        async def woohoo():
            try:
                yield
            except Exception as exc:
                raise RuntimeError(f'caught {exc}') from exc

        with self.assertRaises(RuntimeError):
            async with woohoo():
                1 / 0

        # If the context manager wrapped StopAsyncIteration in a RuntimeError,
        # we also unwrap it, because we can't tell whether the wrapping was
        # done by the generator machinery or by the generator itself.
        with self.assertRaises(StopAsyncIteration):
            async with woohoo():
                raise StopAsyncIteration

    def _create_contextmanager_attribs(self):
        # Build an @asynccontextmanager function whose underlying generator
        # function carries a docstring and an extra attribute (foo='bar').
        def attribs(**kw):
            def decorate(func):
                for k,v in kw.items():
                    setattr(func,k,v)
                return func
            return decorate
        @asynccontextmanager
        @attribs(foo='bar')
        async def baz(spam):
            """Whee!"""
            yield
        return baz

    def test_contextmanager_attribs(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__,'baz')
        self.assertEqual(baz.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    @support.requires_docstrings
    @_async_test
    async def test_instance_docstring_given_cm_docstring(self):
        baz = self._create_contextmanager_attribs()(None)
        self.assertEqual(baz.__doc__, "Whee!")
        async with baz:
            pass  # suppress warning

    @_async_test
    async def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        @asynccontextmanager
        async def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)
        async with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))

    @_async_test
    async def test_recursive(self):
        # A context manager instance used as a decorator is re-entered on
        # every (recursive) call; ncols counts the entries and depth checks
        # that each exit restores the state of the matching entry.
        depth = 0
        ncols = 0

        @asynccontextmanager
        async def woohoo():
            nonlocal ncols
            ncols += 1

            nonlocal depth
            before = depth
            depth += 1
            yield
            depth -= 1
            self.assertEqual(depth, before)

        @woohoo()
        async def recursive():
            if depth < 10:
                await recursive()

        await recursive()

        self.assertEqual(ncols, 10)
        self.assertEqual(depth, 0)

    @_async_test
    async def test_decorator(self):
        entered = False

        @asynccontextmanager
        async def context():
            nonlocal entered
            entered = True
            yield
            entered = False

        @context()
        async def test():
            self.assertTrue(entered)

        self.assertFalse(entered)
        await test()
        self.assertFalse(entered)

    @_async_test
    async def test_decorator_with_exception(self):
        entered = False

        @asynccontextmanager
        async def context():
            nonlocal entered
            try:
                entered = True
                yield
            finally:
                entered = False

        @context()
        async def test():
            self.assertTrue(entered)
            raise NameError('foo')

        self.assertFalse(entered)
        with self.assertRaisesRegex(NameError, 'foo'):
            await test()
        self.assertFalse(entered)

    @_async_test
    async def test_decorating_method(self):

        @asynccontextmanager
        async def context():
            yield

        class Test(object):

            @context()
            async def method(self, a, b, c=None):
                self.a = a
                self.b = b
                self.c = c

        # these tests are for argument passing when used as a decorator
        test = Test()
        await test.method(1, 2)
        self.assertEqual(test.a, 1)
        self.assertEqual(test.b, 2)
        self.assertEqual(test.c, None)

        test = Test()
        await test.method('a', 'b', 'c')
        self.assertEqual(test.a, 'a')
        self.assertEqual(test.b, 'b')
        self.assertEqual(test.c, 'c')

        test = Test()
        await test.method(a=1, b=2)
        self.assertEqual(test.a, 1)
        self.assertEqual(test.b, 2)


# Behaviour of contextlib.aclosing.
class AclosingTestCase(unittest.TestCase):

    @support.requires_docstrings
    def test_instance_docs(self):
        cm_docstring = aclosing.__doc__
        obj = aclosing(None)
        self.assertEqual(obj.__doc__, cm_docstring)

    @_async_test
    async def test_aclosing(self):
        # aclose() is awaited exactly once, on exit.
        state = []
        class C:
            async def aclose(self):
                state.append(1)
        x = C()
        self.assertEqual(state, [])
        async with aclosing(x) as y:
            self.assertEqual(x, y)
        self.assertEqual(state, [1])

    @_async_test
    async def test_aclosing_error(self):
        # aclose() is still awaited when the body raises.
        state = []
        class C:
            async def aclose(self):
                state.append(1)
        x = C()
        self.assertEqual(state, [])
        with self.assertRaises(ZeroDivisionError):
            async with aclosing(x) as y:
                self.assertEqual(x, y)
                1 / 0
        self.assertEqual(state, [1])

    @_async_test
    async def test_aclosing_bpo41229(self):
        # A partially consumed async generator is closed, so the synchronous
        # resource it holds is released.
        state = []

        @contextmanager
        def sync_resource():
            try:
                yield
            finally:
                state.append(1)

        async def agenfunc():
            with sync_resource():
                yield -1
                yield -2

        x = agenfunc()
        self.assertEqual(state, [])
        with self.assertRaises(ZeroDivisionError):
            async with aclosing(x) as y:
                self.assertEqual(x, y)
                self.assertEqual(-1, await x.__anext__())
                1 / 0
        self.assertEqual(state, [1])


# Runs the shared ExitStack test suite (TestBaseExitStack) against
# AsyncExitStack via a synchronous wrapper, plus async-only tests.
class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
    # AsyncExitStack exposed through the synchronous ExitStack protocol
    # (close / __enter__ / __exit__), each driving a coroutine to completion
    # on the current event loop.
    class SyncAsyncExitStack(AsyncExitStack):
        @staticmethod
        def run_coroutine(coro):
            loop = asyncio.get_event_loop_policy().get_event_loop()
            t = loop.create_task(coro)
            # Stop the loop as soon as the task finishes.
            t.add_done_callback(lambda f: loop.stop())
            loop.run_forever()

            exc = t.exception()
            if not exc:
                return t.result()
            else:
                context = exc.__context__

                # Re-raise from this frame. Raising while another exception
                # is being handled (e.g. when called from __exit__) can
                # overwrite exc.__context__, so the saved context is restored
                # before the final re-raise. callback_error_internal_frames
                # below expects both 'raise exc' frames; keep their text.
                try:
                    raise exc
                except:
                    exc.__context__ = context
                    raise exc

        def close(self):
            return self.run_coroutine(self.aclose())

        def __enter__(self):
            return self.run_coroutine(self.__aenter__())

        def __exit__(self, *exc_details):
            return self.run_coroutine(self.__aexit__(*exc_details))

    exit_stack = SyncAsyncExitStack
    # (function name, source line) pairs for the frames contributed by the
    # wrapper above and by AsyncExitStack.__aexit__; must match their text.
    callback_error_internal_frames = [
        ('__exit__', 'return self.run_coroutine(self.__aexit__(*exc_details))'),
        ('run_coroutine', 'raise exc'),
        ('run_coroutine', 'raise exc'),
        ('__aexit__', 'raise exc_details[1]'),
        ('__aexit__', 'cb_suppress = cb(*exc_details)'),
    ]

    def setUp(self):
        # The synchronous wrapper runs on the current event loop, so install
        # a fresh one per test and undo it afterwards.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.addCleanup(self.loop.close)
        self.addCleanup(asyncio.set_event_loop_policy, None)

    @_async_test
    async def test_async_callback(self):
        expected = [
            ((), {}),
            ((1,), {}),
            ((1,2), {}),
            ((), dict(example=1)),
            ((1,), dict(example=1)),
            ((1,2), dict(example=1)),
        ]
        result = []
        async def _exit(*args, **kwds):
            """Test metadata propagation"""
            result.append((args, kwds))

        async with AsyncExitStack() as stack:
            # Pushed in reverse so the LIFO unwinding yields `expected` order.
            for args, kwds in reversed(expected):
                if args and kwds:
                    f = stack.push_async_callback(_exit, *args, **kwds)
                elif args:
                    f = stack.push_async_callback(_exit, *args)
                elif kwds:
                    f = stack.push_async_callback(_exit, **kwds)
                else:
                    f = stack.push_async_callback(_exit)
                self.assertIs(f, _exit)
            for wrapper in stack._exit_callbacks:
                self.assertIs(wrapper[1].__wrapped__, _exit)
                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
                self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)

        self.assertEqual(result, expected)

        result = []
        async with AsyncExitStack() as stack:
            with self.assertRaises(TypeError):
                stack.push_async_callback(arg=1)
            with self.assertRaises(TypeError):
                self.exit_stack.push_async_callback(arg=2)
            with self.assertRaises(TypeError):
                stack.push_async_callback(callback=_exit, arg=3)
        self.assertEqual(result, [])

    @_async_test
    async def test_async_push(self):
        exc_raised = ZeroDivisionError
        async def _expect_exc(exc_type, exc, exc_tb):
            self.assertIs(exc_type, exc_raised)
        async def _suppress_exc(*exc_details):
            return True
        async def _expect_ok(exc_type, exc, exc_tb):
            self.assertIsNone(exc_type)
            self.assertIsNone(exc)
            self.assertIsNone(exc_tb)
        class ExitCM(object):
            def __init__(self, check_exc):
                self.check_exc = check_exc
            async def __aenter__(self):
                self.fail("Should not be called!")
            async def __aexit__(self, *exc_details):
                await self.check_exc(*exc_details)

        # Callbacks unwind LIFO: the ones pushed after _suppress_exc see the
        # ZeroDivisionError, the ones pushed before it see no exception.
        async with self.exit_stack() as stack:
            stack.push_async_exit(_expect_ok)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
            cm = ExitCM(_expect_ok)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_suppress_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
            cm = ExitCM(_expect_exc)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            1/0

    @_async_test
    async def test_enter_async_context(self):
        class TestCM(object):
            async def __aenter__(self):
                result.append(1)
            async def __aexit__(self, *exc_details):
                result.append(3)

        result = []
        cm = TestCM()

        async with AsyncExitStack() as stack:
            @stack.push_async_callback  # Registered first => cleaned up last
            async def _exit():
                result.append(4)
            self.assertIsNotNone(_exit)
            await stack.enter_async_context(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            result.append(2)

        self.assertEqual(result, [1, 2, 3, 4])

    @_async_test
    async def test_enter_async_context_errors(self):
        # Objects missing __aenter__ and/or __aexit__ are rejected with
        # TypeError and nothing is registered on the stack.
        class LacksEnterAndExit:
            pass
        class LacksEnter:
            async def __aexit__(self, *exc_info):
                pass
        class LacksExit:
            async def __aenter__(self):
                pass

        async with self.exit_stack() as stack:
            with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
                await stack.enter_async_context(LacksEnterAndExit())
            with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
                await stack.enter_async_context(LacksEnter())
            with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
                await stack.enter_async_context(LacksExit())
            self.assertFalse(stack._exit_callbacks)

    @_async_test
    async def test_async_exit_exception_chaining(self):
        # Ensure exception chaining matches the reference behaviour
        async def raise_exc(exc):
            raise exc

        saved_details = None
        async def suppress_exc(*exc_details):
            nonlocal saved_details
            saved_details = exc_details
            return True

        try:
            async with self.exit_stack() as stack:
                stack.push_async_callback(raise_exc, IndexError)
                stack.push_async_callback(raise_exc, KeyError)
                stack.push_async_callback(raise_exc, AttributeError)
                stack.push_async_exit(suppress_exc)
                stack.push_async_callback(raise_exc, ValueError)
                1 / 0
        except IndexError as exc:
            self.assertIsInstance(exc.__context__, KeyError)
            self.assertIsInstance(exc.__context__.__context__, AttributeError)
            # Inner exceptions were suppressed
            self.assertIsNone(exc.__context__.__context__.__context__)
        else:
            self.fail("Expected IndexError, but no exception was raised")
        # Check the inner exceptions
        inner_exc = saved_details[1]
        self.assertIsInstance(inner_exc, ValueError)
        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)

    @_async_test
    async def test_async_exit_exception_explicit_none_context(self):
        # Ensure AsyncExitStack chaining matches actual nested `with` statements
        # regarding explicit __context__ = None.

        class MyException(Exception):
            pass

        @asynccontextmanager
        async def my_cm():
            try:
                yield
            except BaseException:
                exc = MyException()
                try:
                    raise exc
                finally:
                    exc.__context__ = None

        @asynccontextmanager
        async def my_cm_with_exit_stack():
            async with self.exit_stack() as stack:
                await stack.enter_async_context(my_cm())
                yield stack

        for cm in (my_cm, my_cm_with_exit_stack):
            # Label the subtest so a failure identifies which variant broke.
            with self.subTest(cm=cm.__name__):
                try:
                    async with cm():
                        raise IndexError()
                except MyException as exc:
                    self.assertIsNone(exc.__context__)
                else:
                    self.fail("Expected MyException, but no exception was raised")

    @_async_test
    async def test_instance_bypass_async(self):
        # __aenter__/__aexit__ set on the instance (not the type) are not
        # honoured by enter_async_context, while push_async_exit still
        # accepts the object itself as a callback.
        class Example(object): pass
        cm = Example()
        cm.__aenter__ = object()
        cm.__aexit__ = object()
        stack = self.exit_stack()
        with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
            await stack.enter_async_context(cm)
        stack.push_async_exit(cm)
        self.assertIs(stack._exit_callbacks[-1][1], cm)


# nullcontext used as an asynchronous context manager.
class TestAsyncNullcontext(unittest.TestCase):
    @_async_test
    async def test_async_nullcontext(self):
        class C:
            pass
        c = C()
        async with nullcontext(c) as c_in:
            self.assertIs(c_in, c)


if __name__ == '__main__':
    unittest.main()