1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Workerimport unittest 3*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import Mock 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerimport torch 6*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case 7*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing 8*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.device_interface import CudaInterface, DeviceGuard 9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Workerclass TestDeviceGuard(torch._dynamo.test_case.TestCase): 13*da0073e9SAndroid Build Coastguard Worker """ 14*da0073e9SAndroid Build Coastguard Worker Unit tests for the DeviceGuard class using a mock DeviceInterface. 15*da0073e9SAndroid Build Coastguard Worker """ 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Worker def setUp(self): 18*da0073e9SAndroid Build Coastguard Worker super().setUp() 19*da0073e9SAndroid Build Coastguard Worker self.device_interface = Mock() 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker self.device_interface.exchange_device = Mock(return_value=0) 22*da0073e9SAndroid Build Coastguard Worker self.device_interface.maybe_exchange_device = Mock(return_value=1) 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker def test_device_guard(self): 25*da0073e9SAndroid Build Coastguard Worker device_guard = DeviceGuard(self.device_interface, 1) 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker with device_guard as _: 28*da0073e9SAndroid Build Coastguard Worker self.device_interface.exchange_device.assert_called_once_with(1) 29*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_guard.prev_idx, 0) 30*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_guard.idx, 1) 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker self.device_interface.maybe_exchange_device.assert_called_once_with(0) 33*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_guard.prev_idx, 0) 34*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_guard.idx, 1) 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Worker def test_device_guard_no_index(self): 37*da0073e9SAndroid Build Coastguard Worker device_guard = DeviceGuard(self.device_interface, None) 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker with device_guard as _: 40*da0073e9SAndroid Build Coastguard Worker self.device_interface.exchange_device.assert_not_called() 41*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_guard.prev_idx, -1) 42*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_guard.idx, None) 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker self.device_interface.maybe_exchange_device.assert_not_called() 45*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_guard.prev_idx, -1) 46*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_guard.idx, None) 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(not TEST_CUDA, "No CUDA available.") 50*da0073e9SAndroid Build Coastguard Workerclass TestCUDADeviceGuard(torch._dynamo.test_case.TestCase): 51*da0073e9SAndroid Build Coastguard Worker """ 52*da0073e9SAndroid Build Coastguard Worker Unit tests for the DeviceGuard class using a CudaInterface. 53*da0073e9SAndroid Build Coastguard Worker """ 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker def setUp(self): 56*da0073e9SAndroid Build Coastguard Worker super().setUp() 57*da0073e9SAndroid Build Coastguard Worker self.device_interface = CudaInterface 58*da0073e9SAndroid Build Coastguard Worker 59*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU") 60*da0073e9SAndroid Build Coastguard Worker def test_device_guard(self): 61*da0073e9SAndroid Build Coastguard Worker current_device = torch.cuda.current_device() 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker device_guard = DeviceGuard(self.device_interface, 1) 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker with device_guard as _: 66*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.cuda.current_device(), 1) 67*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_guard.prev_idx, 0) 68*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_guard.idx, 1) 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.cuda.current_device(), current_device) 71*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_guard.prev_idx, 0) 72*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_guard.idx, 1) 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Worker def test_device_guard_no_index(self): 75*da0073e9SAndroid Build Coastguard Worker current_device = torch.cuda.current_device() 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker device_guard = DeviceGuard(self.device_interface, None) 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker with device_guard as _: 80*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.cuda.current_device(), current_device) 81*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_guard.prev_idx, -1) 82*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_guard.idx, None) 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_guard.prev_idx, -1) 85*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_guard.idx, None) 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 89*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.test_case import run_tests 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker run_tests() 92