# Owner(s): ["module: dynamo"]
import unittest
from unittest.mock import Mock

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.device_interface import CudaInterface, DeviceGuard
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU


class TestDeviceGuard(torch._dynamo.test_case.TestCase):
    """
    Unit tests for the DeviceGuard class using a mock DeviceInterface.
    """

    def setUp(self):
        super().setUp()
        self.device_interface = Mock()

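        # exchange_device reports 0 as the previously active device index; the guard
        # is expected to record it and hand it back when restoring the device on exit.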
        self.device_interface.exchange_device = Mock(return_value=0)
        self.device_interface.maybe_exchange_device = Mock(return_value=1)

    def test_device_guard(self):
        device_guard = DeviceGuard(self.device_interface, 1)

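        # Entering the guard should switch to device 1 via exchange_device and
        # record the previously active device in prev_idx.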
        with device_guard as _:
            self.device_interface.exchange_device.assert_called_once_with(1)
            self.assertEqual(device_guard.prev_idx, 0)
            self.assertEqual(device_guard.idx, 1)

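        # Exiting the guard should restore the previous device via maybe_exchange_device,
        # while prev_idx and idx remain available for inspection.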
        self.device_interface.maybe_exchange_device.assert_called_once_with(0)
        self.assertEqual(device_guard.prev_idx, 0)
        self.assertEqual(device_guard.idx, 1)

    def test_device_guard_no_index(self):
        device_guard = DeviceGuard(self.device_interface, None)

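        # With no device index, entering the guard should not switch devices and
        # prev_idx should keep its default value of -1.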
        with device_guard as _:
            self.device_interface.exchange_device.assert_not_called()
            self.assertEqual(device_guard.prev_idx, -1)
            self.assertEqual(device_guard.idx, None)

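        # Likewise, exiting the guard should not attempt to restore any device.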
        self.device_interface.maybe_exchange_device.assert_not_called()
        self.assertEqual(device_guard.prev_idx, -1)
        self.assertEqual(device_guard.idx, None)


@unittest.skipIf(not TEST_CUDA, "No CUDA available.")
class TestCUDADeviceGuard(torch._dynamo.test_case.TestCase):
    """
    Unit tests for the DeviceGuard class using a CudaInterface.
    """

    def setUp(self):
        super().setUp()
        self.device_interface = CudaInterface

    @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPUs")
    def test_device_guard(self):
        current_device = torch.cuda.current_device()

        device_guard = DeviceGuard(self.device_interface, 1)

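        # With the real CudaInterface, entering the guard should make device 1 current.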
        with device_guard as _:
            self.assertEqual(torch.cuda.current_device(), 1)
            self.assertEqual(device_guard.prev_idx, 0)
            self.assertEqual(device_guard.idx, 1)

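        # Exiting the guard should restore the originally current device.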
        self.assertEqual(torch.cuda.current_device(), current_device)
        self.assertEqual(device_guard.prev_idx, 0)
        self.assertEqual(device_guard.idx, 1)

    def test_device_guard_no_index(self):
        current_device = torch.cuda.current_device()

        device_guard = DeviceGuard(self.device_interface, None)

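        # With no device index, the current device should be left unchanged inside the guard.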
        with device_guard as _:
            self.assertEqual(torch.cuda.current_device(), current_device)
            self.assertEqual(device_guard.prev_idx, -1)
            self.assertEqual(device_guard.idx, None)

        self.assertEqual(device_guard.prev_idx, -1)
        self.assertEqual(device_guard.idx, None)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()