1"""Tests for asyncio/threads.py"""
2
3import asyncio
4import unittest
5
6from contextvars import ContextVar
7from unittest import mock
8
9
def tearDownModule():
    """Restore asyncio's default event loop policy after this module's tests."""
    # Passing None makes asyncio reinstate its default policy.
    reset_to_default = None
    asyncio.set_event_loop_policy(reset_to_default)
12
13
class ToThreadTests(unittest.IsolatedAsyncioTestCase):
    """Tests for asyncio.to_thread(), which runs a sync callable in a thread."""

    async def test_to_thread(self):
        # The callable's return value is delivered to the awaiting coroutine.
        result = await asyncio.to_thread(sum, [40, 2])
        self.assertEqual(result, 42)

    async def test_to_thread_exception(self):
        # An exception raised in the worker thread propagates to the await.
        def raise_runtime():
            raise RuntimeError("test")

        with self.assertRaisesRegex(RuntimeError, "test"):
            await asyncio.to_thread(raise_runtime)

    async def test_to_thread_once(self):
        func = mock.Mock()

        await asyncio.to_thread(func)
        func.assert_called_once()

    async def test_to_thread_concurrent(self):
        # Count calls with list.append rather than a mock.Mock: Mock is not
        # guaranteed to be thread-safe (its call bookkeeping such as
        # call_count is updated with plain read-modify-write operations), so
        # simultaneous calls from several worker threads could lose updates,
        # especially on free-threaded builds, making this test flaky.
        calls = []

        def func():
            calls.append(1)

        futs = [asyncio.to_thread(func) for _ in range(10)]
        await asyncio.gather(*futs)

        self.assertEqual(len(calls), 10)

    async def test_to_thread_args_kwargs(self):
        # Unlike run_in_executor(), to_thread() should directly accept kwargs.
        func = mock.Mock()

        await asyncio.to_thread(func, 'test', something=True)

        func.assert_called_once_with('test', something=True)

    async def test_to_thread_contextvars(self):
        # A context variable set in the calling task must be visible to the
        # function running in the worker thread.
        test_ctx = ContextVar('test_ctx')

        def get_ctx():
            return test_ctx.get()

        test_ctx.set('parrot')
        result = await asyncio.to_thread(get_ctx)

        self.assertEqual(result, 'parrot')
61
62
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
65