1import _thread
2import asyncio
3import contextvars
4import gc
5import re
6import signal
7import threading
8import unittest
9from test.test_asyncio import utils as test_utils
10from unittest import mock
11from unittest.mock import patch
12
13
def tearDownModule():
    """Restore asyncio's default event loop policy after this module's tests.

    The tests install a TestPolicy; passing ``None`` to
    ``asyncio.set_event_loop_policy`` reverts to the default policy so it
    does not leak into other test modules.
    """
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
16
17
def interrupt_self():
    """Simulate a Ctrl+C delivered to the main thread.

    Delegates to ``_thread.interrupt_main()``; with Python's default SIGINT
    handler the interruption surfaces as KeyboardInterrupt in the main thread
    the next time it checks for pending signals.
    """
    interrupt = _thread.interrupt_main
    interrupt()
20
21
class TestPolicy(asyncio.AbstractEventLoopPolicy):
    """Minimal event loop policy used to observe what asyncio APIs do.

    New loops come from *loop_factory*. The most recent non-None loop passed
    to ``set_event_loop`` is remembered in ``self.loop`` so that
    ``BaseTest.tearDown`` can check it was closed.
    """

    def __init__(self, loop_factory):
        self.loop = None
        self.loop_factory = loop_factory

    def get_event_loop(self):
        # asyncio.run() is not expected to reach this; fail loudly if it does.
        raise RuntimeError

    def new_event_loop(self):
        factory = self.loop_factory
        return factory()

    def set_event_loop(self, loop):
        # Clearing (loop=None) is ignored on purpose: the last real loop must
        # stay reachable for the closed-ness check in BaseTest.tearDown.
        if loop is None:
            return
        self.loop = loop
40
41
class BaseTest(unittest.TestCase):
    """Base class installing a TestPolicy backed by a mocked event loop.

    ``tearDown`` verifies that any loop registered with the policy was closed
    and had ``shutdown_asyncgens`` run, then restores the default policy.
    """

    def new_loop(self):
        """Return a BaseEventLoop whose I/O hooks are replaced by mocks.

        The selector never reports events and self-pipe writes are no-ops.
        ``shutdown_asyncgens`` is replaced by a coroutine that only records
        that it ran, via the ``shutdown_ag_run`` attribute.
        """
        loop = asyncio.BaseEventLoop()
        loop._process_events = mock.Mock()
        # Mock waking event loop from select
        loop._write_to_self = mock.Mock()
        loop._write_to_self.return_value = None
        loop._selector = mock.Mock()
        loop._selector.select.return_value = ()
        loop.shutdown_ag_run = False

        async def shutdown_asyncgens():
            loop.shutdown_ag_run = True
        loop.shutdown_asyncgens = shutdown_asyncgens

        return loop

    def setUp(self):
        super().setUp()

        policy = TestPolicy(self.new_loop)
        asyncio.set_event_loop_policy(policy)

    def tearDown(self):
        """Check the registered loop was cleaned up, then reset the policy.

        The policy reset and ``super().tearDown()`` run in ``finally`` so a
        failed check cannot leak the TestPolicy into subsequent tests.
        """
        policy = asyncio.get_event_loop_policy()
        try:
            # Only TestPolicy carries ``loop``; tolerate a policy that a test
            # already replaced with a default one.
            loop = getattr(policy, 'loop', None)
            if loop is not None:
                self.assertTrue(loop.is_closed())
                self.assertTrue(loop.shutdown_ag_run)
        finally:
            asyncio.set_event_loop_policy(None)
            super().tearDown()
74
75
class RunTests(BaseTest):
    """Tests for asyncio.run() executed against BaseTest's mocked loop."""

    def test_asyncio_run_return(self):
        # The coroutine's return value is returned by asyncio.run().
        async def main():
            await asyncio.sleep(0)
            return 42

        self.assertEqual(asyncio.run(main()), 42)

    def test_asyncio_run_raises(self):
        # An exception raised inside the coroutine propagates out of run().
        async def main():
            await asyncio.sleep(0)
            raise ValueError('spam')

        with self.assertRaisesRegex(ValueError, 'spam'):
            asyncio.run(main())

    def test_asyncio_run_only_coro(self):
        # Non-coroutine arguments (an int, a plain callable) are rejected
        # with ValueError.
        for o in {1, lambda: None}:
            with self.subTest(obj=o), \
                    self.assertRaisesRegex(ValueError,
                                           'a coroutine was expected'):
                asyncio.run(o)

    def test_asyncio_run_debug(self):
        # The loop's debug flag follows an explicit ``debug`` argument and,
        # when none is given, asyncio.coroutines._is_debug_mode().
        async def main(expected):
            loop = asyncio.get_event_loop()
            self.assertIs(loop.get_debug(), expected)

        asyncio.run(main(False))
        asyncio.run(main(True), debug=True)
        with mock.patch('asyncio.coroutines._is_debug_mode', lambda: True):
            asyncio.run(main(True))
            asyncio.run(main(False), debug=False)

    def test_asyncio_run_from_running_loop(self):
        # Calling asyncio.run() while a loop is running raises RuntimeError.
        async def main():
            coro = main()
            try:
                asyncio.run(coro)
            finally:
                coro.close()  # Suppress RuntimeWarning (never awaited)

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot be called from a running'):
            asyncio.run(main())

    def test_asyncio_run_cancels_hanging_tasks(self):
        # Tasks still pending when main() returns are finished (cancelled)
        # by the time asyncio.run() returns.
        lo_task = None

        async def leftover():
            await asyncio.sleep(0.1)

        async def main():
            nonlocal lo_task
            lo_task = asyncio.create_task(leftover())
            return 123

        self.assertEqual(asyncio.run(main()), 123)
        self.assertTrue(lo_task.done())

    def test_asyncio_run_reports_hanging_tasks_errors(self):
        # An error raised by a leftover task while it is being cancelled is
        # reported through the loop's call_exception_handler().
        lo_task = None
        call_exc_handler_mock = mock.Mock()

        async def leftover():
            try:
                await asyncio.sleep(0.1)
            except asyncio.CancelledError:
                1 / 0

        async def main():
            loop = asyncio.get_running_loop()
            loop.call_exception_handler = call_exc_handler_mock

            nonlocal lo_task
            lo_task = asyncio.create_task(leftover())
            return 123

        self.assertEqual(asyncio.run(main()), 123)
        self.assertTrue(lo_task.done())

        call_exc_handler_mock.assert_called_with({
            'message': test_utils.MockPattern(r'asyncio.run.*shutdown'),
            'task': lo_task,
            'exception': test_utils.MockInstanceOf(ZeroDivisionError)
        })

    def test_asyncio_run_closes_gens_after_hanging_tasks_errors(self):
        # Even when a leftover task fails during cancellation, the async
        # generator it was iterating ends up finalized (no frame, not running).
        spinner = None
        lazyboy = None

        class FancyExit(Exception):
            pass

        async def fidget():
            while True:
                yield 1
                await asyncio.sleep(1)

        async def spin():
            nonlocal spinner
            spinner = fidget()
            try:
                async for the_meaning_of_life in spinner:  # NoQA
                    pass
            except asyncio.CancelledError:
                1 / 0

        async def main():
            loop = asyncio.get_running_loop()
            loop.call_exception_handler = mock.Mock()

            nonlocal lazyboy
            lazyboy = asyncio.create_task(spin())
            raise FancyExit

        with self.assertRaises(FancyExit):
            asyncio.run(main())

        self.assertTrue(lazyboy.done())

        self.assertIsNone(spinner.ag_frame)
        self.assertFalse(spinner.ag_running)

    def test_asyncio_run_set_event_loop(self):
        # See https://github.com/python/cpython/issues/93896
        # asyncio.run() must register its loop via policy.set_event_loop().

        async def main():
            await asyncio.sleep(0)
            return 42

        policy = asyncio.get_event_loop_policy()
        policy.set_event_loop = mock.Mock()
        asyncio.run(main())
        self.assertTrue(policy.set_event_loop.called)

    def test_asyncio_run_without_uncancel(self):
        # See https://github.com/python/cpython/issues/95097
        # The task factory below yields a Task-like proxy that forwards only
        # the methods listed (no uncancel()/cancelling()). After a simulated
        # Ctrl+C, asyncio.run() is expected to let CancelledError escape.
        class Task:
            def __init__(self, loop, coro, **kwargs):
                self._task = asyncio.Task(coro, loop=loop, **kwargs)

            def cancel(self, *args, **kwargs):
                return self._task.cancel(*args, **kwargs)

            def add_done_callback(self, *args, **kwargs):
                return self._task.add_done_callback(*args, **kwargs)

            def remove_done_callback(self, *args, **kwargs):
                return self._task.remove_done_callback(*args, **kwargs)

            @property
            def _asyncio_future_blocking(self):
                return self._task._asyncio_future_blocking

            def result(self, *args, **kwargs):
                return self._task.result(*args, **kwargs)

            def done(self, *args, **kwargs):
                return self._task.done(*args, **kwargs)

            def cancelled(self, *args, **kwargs):
                return self._task.cancelled(*args, **kwargs)

            def exception(self, *args, **kwargs):
                return self._task.exception(*args, **kwargs)

            def get_loop(self, *args, **kwargs):
                return self._task.get_loop(*args, **kwargs)


        async def main():
            # Request an interrupt, then block forever so only the interrupt
            # can end the main task.
            interrupt_self()
            await asyncio.Event().wait()

        def new_event_loop():
            loop = self.new_loop()
            loop.set_task_factory(Task)
            return loop

        asyncio.set_event_loop_policy(TestPolicy(new_event_loop))
        with self.assertRaises(asyncio.CancelledError):
            asyncio.run(main())
260
261
class RunnerTests(BaseTest):
    """Tests for asyncio.Runner: loop creation, run(), closing and Ctrl+C."""

    def test_non_debug(self):
        with asyncio.Runner(debug=False) as runner:
            self.assertFalse(runner.get_loop().get_debug())

    def test_debug(self):
        with asyncio.Runner(debug=True) as runner:
            self.assertTrue(runner.get_loop().get_debug())

    def test_custom_factory(self):
        # get_loop() returns exactly the object built by loop_factory.
        loop = mock.Mock()
        with asyncio.Runner(loop_factory=lambda: loop) as runner:
            self.assertIs(runner.get_loop(), loop)

    def test_run(self):
        # run() returns the coroutine result; leaving the ``with`` block
        # closes the runner and its loop.
        async def f():
            await asyncio.sleep(0)
            return 'done'

        with asyncio.Runner() as runner:
            self.assertEqual('done', runner.run(f()))
            loop = runner.get_loop()

        with self.assertRaisesRegex(
            RuntimeError,
            "Runner is closed"
        ):
            runner.get_loop()

        self.assertTrue(loop.is_closed())

    def test_run_non_coro(self):
        with asyncio.Runner() as runner:
            with self.assertRaisesRegex(
                ValueError,
                "a coroutine was expected"
            ):
                runner.run(123)

    def test_run_future(self):
        # A Future is not accepted by run(); only coroutines are.
        with asyncio.Runner() as runner:
            with self.assertRaisesRegex(
                ValueError,
                "a coroutine was expected"
            ):
                fut = runner.get_loop().create_future()
                runner.run(fut)

    def test_explicit_close(self):
        runner = asyncio.Runner()
        loop = runner.get_loop()
        runner.close()
        with self.assertRaisesRegex(
                RuntimeError,
                "Runner is closed"
        ):
            runner.get_loop()

        self.assertTrue(loop.is_closed())

    def test_double_close(self):
        runner = asyncio.Runner()
        loop = runner.get_loop()

        runner.close()
        self.assertTrue(loop.is_closed())

        # the second call is no-op
        runner.close()
        self.assertTrue(loop.is_closed())

    def test_second_with_block_raises(self):
        # A Runner cannot be re-entered once its first ``with`` block closed it.
        ret = []

        async def f(arg):
            ret.append(arg)

        runner = asyncio.Runner()
        with runner:
            runner.run(f(1))

        with self.assertRaisesRegex(
            RuntimeError,
            "Runner is closed"
        ):
            with runner:
                runner.run(f(2))

        self.assertEqual([1], ret)

    def test_run_keeps_context(self):
        # Successive run() calls share one contextvars.Context, so values set
        # in one call are visible in the next.
        cvar = contextvars.ContextVar("cvar", default=-1)

        async def f(val):
            old = cvar.get()
            await asyncio.sleep(0)
            cvar.set(val)
            return old

        async def get_context():
            return contextvars.copy_context()

        with asyncio.Runner() as runner:
            self.assertEqual(-1, runner.run(f(1)))
            self.assertEqual(1, runner.run(f(2)))

            self.assertEqual(2, runner.run(get_context()).get(cvar))

    def test_recursive_run(self):
        # Calling run() from inside a coroutine that run() is executing
        # raises RuntimeError and leaves the inner coroutine un-awaited.
        async def g():
            pass

        async def f():
            runner.run(g())

        with asyncio.Runner() as runner:
            with self.assertWarnsRegex(
                RuntimeWarning,
                "coroutine .+ was never awaited",
            ):
                with self.assertRaisesRegex(
                    RuntimeError,
                    re.escape(
                        "Runner.run() cannot be called from a running event loop"
                    ),
                ):
                    runner.run(f())

    def test_interrupt_call_soon(self):
        # The only case when task is not suspended by waiting a future
        # or another task
        assert threading.current_thread() is threading.main_thread()

        async def coro():
            with self.assertRaises(asyncio.CancelledError):
                while True:
                    await asyncio.sleep(0)
            raise asyncio.CancelledError()

        with asyncio.Runner() as runner:
            runner.get_loop().call_later(0.1, interrupt_self)
            with self.assertRaises(KeyboardInterrupt):
                runner.run(coro())

    def test_interrupt_wait(self):
        # interrupting when waiting a future cancels both future and main task
        assert threading.current_thread() is threading.main_thread()

        async def coro(fut):
            with self.assertRaises(asyncio.CancelledError):
                await fut
            raise asyncio.CancelledError()

        with asyncio.Runner() as runner:
            fut = runner.get_loop().create_future()
            runner.get_loop().call_later(0.1, interrupt_self)

            with self.assertRaises(KeyboardInterrupt):
                runner.run(coro(fut))

            self.assertTrue(fut.cancelled())

    def test_interrupt_cancelled_task(self):
        # interrupting cancelled main task doesn't raise KeyboardInterrupt
        assert threading.current_thread() is threading.main_thread()

        async def subtask(task):
            await asyncio.sleep(0)
            task.cancel()
            interrupt_self()

        async def coro():
            asyncio.create_task(subtask(asyncio.current_task()))
            await asyncio.sleep(10)

        with asyncio.Runner() as runner:
            with self.assertRaises(asyncio.CancelledError):
                runner.run(coro())

    def test_signal_install_not_supported_ok(self):
        # signal.signal() can throw if the "main thread" doesn't have signals enabled
        assert threading.current_thread() is threading.main_thread()

        async def coro():
            pass

        with asyncio.Runner() as runner:
            with patch.object(
                signal,
                "signal",
                side_effect=ValueError(
                    "signal only works in main thread of the main interpreter"
                )
            ):
                runner.run(coro())

    def test_set_event_loop_called_once(self):
        # See https://github.com/python/cpython/issues/95736
        # Repeated run() calls on one Runner install the loop only once.
        async def coro():
            pass

        policy = asyncio.get_event_loop_policy()
        policy.set_event_loop = mock.Mock()
        runner = asyncio.Runner()
        try:
            runner.run(coro())
            runner.run(coro())

            self.assertEqual(1, policy.set_event_loop.call_count)
        finally:
            # Close even when the assertion fails so the runner's loop is not
            # left open; closing stays after the count check, as before.
            runner.close()
472
473
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
476