xref: /aosp_15_r20/external/pytorch/test/dynamo/test_deviceguard.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Workerimport unittest
3*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import Mock
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
7*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing
8*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.device_interface import CudaInterface, DeviceGuard
9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workerclass TestDeviceGuard(torch._dynamo.test_case.TestCase):
13*da0073e9SAndroid Build Coastguard Worker    """
14*da0073e9SAndroid Build Coastguard Worker    Unit tests for the DeviceGuard class using a mock DeviceInterface.
15*da0073e9SAndroid Build Coastguard Worker    """
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
18*da0073e9SAndroid Build Coastguard Worker        super().setUp()
19*da0073e9SAndroid Build Coastguard Worker        self.device_interface = Mock()
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker        self.device_interface.exchange_device = Mock(return_value=0)
22*da0073e9SAndroid Build Coastguard Worker        self.device_interface.maybe_exchange_device = Mock(return_value=1)
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Worker    def test_device_guard(self):
25*da0073e9SAndroid Build Coastguard Worker        device_guard = DeviceGuard(self.device_interface, 1)
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker        with device_guard as _:
28*da0073e9SAndroid Build Coastguard Worker            self.device_interface.exchange_device.assert_called_once_with(1)
29*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(device_guard.prev_idx, 0)
30*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(device_guard.idx, 1)
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker        self.device_interface.maybe_exchange_device.assert_called_once_with(0)
33*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(device_guard.prev_idx, 0)
34*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(device_guard.idx, 1)
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker    def test_device_guard_no_index(self):
37*da0073e9SAndroid Build Coastguard Worker        device_guard = DeviceGuard(self.device_interface, None)
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker        with device_guard as _:
40*da0073e9SAndroid Build Coastguard Worker            self.device_interface.exchange_device.assert_not_called()
41*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(device_guard.prev_idx, -1)
42*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(device_guard.idx, None)
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Worker        self.device_interface.maybe_exchange_device.assert_not_called()
45*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(device_guard.prev_idx, -1)
46*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(device_guard.idx, None)
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(not TEST_CUDA, "No CUDA available.")
50*da0073e9SAndroid Build Coastguard Workerclass TestCUDADeviceGuard(torch._dynamo.test_case.TestCase):
51*da0073e9SAndroid Build Coastguard Worker    """
52*da0073e9SAndroid Build Coastguard Worker    Unit tests for the DeviceGuard class using a CudaInterface.
53*da0073e9SAndroid Build Coastguard Worker    """
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
56*da0073e9SAndroid Build Coastguard Worker        super().setUp()
57*da0073e9SAndroid Build Coastguard Worker        self.device_interface = CudaInterface
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU")
60*da0073e9SAndroid Build Coastguard Worker    def test_device_guard(self):
61*da0073e9SAndroid Build Coastguard Worker        current_device = torch.cuda.current_device()
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker        device_guard = DeviceGuard(self.device_interface, 1)
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker        with device_guard as _:
66*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.cuda.current_device(), 1)
67*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(device_guard.prev_idx, 0)
68*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(device_guard.idx, 1)
69*da0073e9SAndroid Build Coastguard Worker
70*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.cuda.current_device(), current_device)
71*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(device_guard.prev_idx, 0)
72*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(device_guard.idx, 1)
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker    def test_device_guard_no_index(self):
75*da0073e9SAndroid Build Coastguard Worker        current_device = torch.cuda.current_device()
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker        device_guard = DeviceGuard(self.device_interface, None)
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker        with device_guard as _:
80*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.cuda.current_device(), current_device)
81*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(device_guard.prev_idx, -1)
82*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(device_guard.idx, None)
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(device_guard.prev_idx, -1)
85*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(device_guard.idx, None)
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # When executed directly as a script, hand off to the run_tests entry
    # point of torch._dynamo.test_case (already imported at the top of this file).
    torch._dynamo.test_case.run_tests()
92