1# Owner(s): ["module: dynamo"] 2import unittest 3from unittest.mock import Mock 4 5import torch 6import torch._dynamo.test_case 7import torch._dynamo.testing 8from torch._dynamo.device_interface import CudaInterface, DeviceGuard 9from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU 10 11 12class TestDeviceGuard(torch._dynamo.test_case.TestCase): 13 """ 14 Unit tests for the DeviceGuard class using a mock DeviceInterface. 15 """ 16 17 def setUp(self): 18 super().setUp() 19 self.device_interface = Mock() 20 21 self.device_interface.exchange_device = Mock(return_value=0) 22 self.device_interface.maybe_exchange_device = Mock(return_value=1) 23 24 def test_device_guard(self): 25 device_guard = DeviceGuard(self.device_interface, 1) 26 27 with device_guard as _: 28 self.device_interface.exchange_device.assert_called_once_with(1) 29 self.assertEqual(device_guard.prev_idx, 0) 30 self.assertEqual(device_guard.idx, 1) 31 32 self.device_interface.maybe_exchange_device.assert_called_once_with(0) 33 self.assertEqual(device_guard.prev_idx, 0) 34 self.assertEqual(device_guard.idx, 1) 35 36 def test_device_guard_no_index(self): 37 device_guard = DeviceGuard(self.device_interface, None) 38 39 with device_guard as _: 40 self.device_interface.exchange_device.assert_not_called() 41 self.assertEqual(device_guard.prev_idx, -1) 42 self.assertEqual(device_guard.idx, None) 43 44 self.device_interface.maybe_exchange_device.assert_not_called() 45 self.assertEqual(device_guard.prev_idx, -1) 46 self.assertEqual(device_guard.idx, None) 47 48 49@unittest.skipIf(not TEST_CUDA, "No CUDA available.") 50class TestCUDADeviceGuard(torch._dynamo.test_case.TestCase): 51 """ 52 Unit tests for the DeviceGuard class using a CudaInterface. 53 """ 54 55 def setUp(self): 56 super().setUp() 57 self.device_interface = CudaInterface 58 59 @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU") 60 def test_device_guard(self): 61 current_device = torch.cuda.current_device() 62 63 device_guard = DeviceGuard(self.device_interface, 1) 64 65 with device_guard as _: 66 self.assertEqual(torch.cuda.current_device(), 1) 67 self.assertEqual(device_guard.prev_idx, 0) 68 self.assertEqual(device_guard.idx, 1) 69 70 self.assertEqual(torch.cuda.current_device(), current_device) 71 self.assertEqual(device_guard.prev_idx, 0) 72 self.assertEqual(device_guard.idx, 1) 73 74 def test_device_guard_no_index(self): 75 current_device = torch.cuda.current_device() 76 77 device_guard = DeviceGuard(self.device_interface, None) 78 79 with device_guard as _: 80 self.assertEqual(torch.cuda.current_device(), current_device) 81 self.assertEqual(device_guard.prev_idx, -1) 82 self.assertEqual(device_guard.idx, None) 83 84 self.assertEqual(device_guard.prev_idx, -1) 85 self.assertEqual(device_guard.idx, None) 86 87 88if __name__ == "__main__": 89 from torch._dynamo.test_case import run_tests 90 91 run_tests() 92