xref: /aosp_15_r20/prebuilts/build-tools/common/py3-stdlib/unittest/async_case.py (revision cda5da8d549138a6648c5ee6d7a49cf8f4a657be)
1import asyncio
2import contextvars
3import inspect
4import warnings
5
6from .case import TestCase
7
8
9class IsolatedAsyncioTestCase(TestCase):
10    # Names intentionally have a long prefix
11    # to reduce a chance of clashing with user-defined attributes
12    # from inherited test case
13    #
14    # The class doesn't call loop.run_until_complete(self.setUp()) and family
15    # but uses a different approach:
16    # 1. create a long-running task that reads self.setUp()
17    #    awaitable from queue along with a future
18    # 2. await the awaitable object passing in and set the result
19    #    into the future object
20    # 3. Outer code puts the awaitable and the future object into a queue
21    #    with waiting for the future
22    # The trick is necessary because every run_until_complete() call
23    # creates a new task with embedded ContextVar context.
24    # To share contextvars between setUp(), test and tearDown() we need to execute
25    # them inside the same task.
26
27    # Note: the test case modifies event loop policy if the policy was not instantiated
28    # yet.
29    # asyncio.get_event_loop_policy() creates a default policy on demand but never
30    # returns None
31    # I believe this is not an issue in user level tests but python itself for testing
32    # should reset a policy in every test module
33    # by calling asyncio.set_event_loop_policy(None) in tearDownModule()
34
35    def __init__(self, methodName='runTest'):
36        super().__init__(methodName)
37        self._asyncioRunner = None
38        self._asyncioTestContext = contextvars.copy_context()
39
40    async def asyncSetUp(self):
41        pass
42
43    async def asyncTearDown(self):
44        pass
45
46    def addAsyncCleanup(self, func, /, *args, **kwargs):
47        # A trivial trampoline to addCleanup()
48        # the function exists because it has a different semantics
49        # and signature:
50        # addCleanup() accepts regular functions
51        # but addAsyncCleanup() accepts coroutines
52        #
53        # We intentionally don't add inspect.iscoroutinefunction() check
54        # for func argument because there is no way
55        # to check for async function reliably:
56        # 1. It can be "async def func()" itself
57        # 2. Class can implement "async def __call__()" method
58        # 3. Regular "def func()" that returns awaitable object
59        self.addCleanup(*(func, *args), **kwargs)
60
61    async def enterAsyncContext(self, cm):
62        """Enters the supplied asynchronous context manager.
63
64        If successful, also adds its __aexit__ method as a cleanup
65        function and returns the result of the __aenter__ method.
66        """
67        # We look up the special methods on the type to match the with
68        # statement.
69        cls = type(cm)
70        try:
71            enter = cls.__aenter__
72            exit = cls.__aexit__
73        except AttributeError:
74            raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does "
75                            f"not support the asynchronous context manager protocol"
76                           ) from None
77        result = await enter(cm)
78        self.addAsyncCleanup(exit, cm, None, None, None)
79        return result
80
81    def _callSetUp(self):
82        # Force loop to be initialized and set as the current loop
83        # so that setUp functions can use get_event_loop() and get the
84        # correct loop instance.
85        self._asyncioRunner.get_loop()
86        self._asyncioTestContext.run(self.setUp)
87        self._callAsync(self.asyncSetUp)
88
89    def _callTestMethod(self, method):
90        if self._callMaybeAsync(method) is not None:
91            warnings.warn(f'It is deprecated to return a value that is not None from a '
92                          f'test case ({method})', DeprecationWarning, stacklevel=4)
93
94    def _callTearDown(self):
95        self._callAsync(self.asyncTearDown)
96        self._asyncioTestContext.run(self.tearDown)
97
98    def _callCleanup(self, function, *args, **kwargs):
99        self._callMaybeAsync(function, *args, **kwargs)
100
101    def _callAsync(self, func, /, *args, **kwargs):
102        assert self._asyncioRunner is not None, 'asyncio runner is not initialized'
103        assert inspect.iscoroutinefunction(func), f'{func!r} is not an async function'
104        return self._asyncioRunner.run(
105            func(*args, **kwargs),
106            context=self._asyncioTestContext
107        )
108
109    def _callMaybeAsync(self, func, /, *args, **kwargs):
110        assert self._asyncioRunner is not None, 'asyncio runner is not initialized'
111        if inspect.iscoroutinefunction(func):
112            return self._asyncioRunner.run(
113                func(*args, **kwargs),
114                context=self._asyncioTestContext,
115            )
116        else:
117            return self._asyncioTestContext.run(func, *args, **kwargs)
118
119    def _setupAsyncioRunner(self):
120        assert self._asyncioRunner is None, 'asyncio runner is already initialized'
121        runner = asyncio.Runner(debug=True)
122        self._asyncioRunner = runner
123
124    def _tearDownAsyncioRunner(self):
125        runner = self._asyncioRunner
126        runner.close()
127
128    def run(self, result=None):
129        self._setupAsyncioRunner()
130        try:
131            return super().run(result)
132        finally:
133            self._tearDownAsyncioRunner()
134
135    def debug(self):
136        self._setupAsyncioRunner()
137        super().debug()
138        self._tearDownAsyncioRunner()
139
140    def __del__(self):
141        if self._asyncioRunner is not None:
142            self._tearDownAsyncioRunner()
143