# Owner(s): ["module: cuda"]

import sys
import unittest
import unittest.mock

import torch
import torch.cuda._gpu_trace as gpu_trace
from torch.testing._internal.common_utils import NoTest, run_tests, TEST_CUDA, TestCase


# NOTE: Each test needs to be run in a brand new process, to reset the registered hooks
# and make sure the CUDA streams are initialized for each test that uses them.

if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaTrace(TestCase):
    """Checks that callbacks registered through ``torch.cuda._gpu_trace`` fire.

    Every test registers ``self.mock`` (a fresh ``MagicMock``) for one kind of
    GPU trace event, performs the CUDA operation that should trigger it, and
    then asserts on how the mock was called.
    """

    def setUp(self):
        """Activate GPU tracing and create the mock used as the callback."""
        torch._C._activate_gpu_trace()
        self.mock = unittest.mock.MagicMock()

    def test_event_creation_callback(self):
        """Callback is called once with the handle of a created, recorded event."""
        gpu_trace.register_callback_for_event_creation(self.mock)

        event = torch.cuda.Event()
        event.record()
        self.mock.assert_called_once_with(event._as_parameter_.value)

    def test_event_deletion_callback(self):
        """Callback is called once with the event handle after the event is deleted."""
        gpu_trace.register_callback_for_event_deletion(self.mock)

        event = torch.cuda.Event()
        event.record()
        # Capture the handle first: the event object is gone after `del`.
        event_id = event._as_parameter_.value
        del event
        self.mock.assert_called_once_with(event_id)

    def test_event_record_callback(self):
        """Callback is called once with (event handle, default stream handle)."""
        gpu_trace.register_callback_for_event_record(self.mock)

        event = torch.cuda.Event()
        event.record()
        self.mock.assert_called_once_with(
            event._as_parameter_.value, torch.cuda.default_stream().cuda_stream
        )

    def test_event_wait_callback(self):
        """Callback is called once with (event handle, default stream handle)."""
        gpu_trace.register_callback_for_event_wait(self.mock)

        event = torch.cuda.Event()
        event.record()
        event.wait()
        self.mock.assert_called_once_with(
            event._as_parameter_.value, torch.cuda.default_stream().cuda_stream
        )

    def test_memory_allocation_callback(self):
        """Callback is called once with the data pointer of a new CUDA tensor."""
        gpu_trace.register_callback_for_memory_allocation(self.mock)

        tensor = torch.empty(10, 4, device="cuda")
        self.mock.assert_called_once_with(tensor.data_ptr())

    def test_memory_deallocation_callback(self):
        """Callback is called once with the data pointer of a deleted CUDA tensor."""
        gpu_trace.register_callback_for_memory_deallocation(self.mock)

        tensor = torch.empty(3, 8, device="cuda")
        # Capture the pointer first: the tensor is gone after `del`.
        data_ptr = tensor.data_ptr()
        del tensor
        self.mock.assert_called_once_with(data_ptr)

    def test_stream_creation_callback(self):
        """Callback is called at least once when a user stream is created."""
        gpu_trace.register_callback_for_stream_creation(self.mock)

        # see Note [HIP Lazy Streams]
        if torch.version.hip:
            # On HIP, also run some work on the new stream before asserting.
            user_stream = torch.cuda.Stream()
            with torch.cuda.stream(user_stream):
                tensor = torch.ones(5, device="cuda")
        else:
            torch.cuda.Stream()

        self.mock.assert_called()

    def test_device_synchronization_callback(self):
        """Callback is called at least once by ``torch.cuda.synchronize()``."""
        gpu_trace.register_callback_for_device_synchronization(self.mock)

        torch.cuda.synchronize()
        self.mock.assert_called()

    def test_stream_synchronization_callback(self):
        """Callback is called once with the handle of the synchronized stream."""
        gpu_trace.register_callback_for_stream_synchronization(self.mock)

        stream = torch.cuda.Stream()
        stream.synchronize()
        self.mock.assert_called_once_with(stream.cuda_stream)

    def test_event_synchronization_callback(self):
        """Callback is called once with the handle of the synchronized event."""
        gpu_trace.register_callback_for_event_synchronization(self.mock)

        event = torch.cuda.Event()
        event.record()
        event.synchronize()
        self.mock.assert_called_once_with(event._as_parameter_.value)

    def test_memcpy_synchronization(self):
        """``nonzero()`` on a CUDA tensor reports one default-stream synchronization."""
        gpu_trace.register_callback_for_stream_synchronization(self.mock)

        tensor = torch.rand(5, device="cuda")
        # NOTE(review): presumably nonzero() copies its result size back to the
        # host, which is what synchronizes the stream -- confirm.
        tensor.nonzero()
        self.mock.assert_called_once_with(torch.cuda.default_stream().cuda_stream)

    def test_all_trace_callbacks_called(self):
        """Two callbacks registered for the same event are each called once."""
        other = unittest.mock.MagicMock()
        gpu_trace.register_callback_for_memory_allocation(self.mock)
        gpu_trace.register_callback_for_memory_allocation(other)

        tensor = torch.empty(10, 4, device="cuda")
        self.mock.assert_called_once_with(tensor.data_ptr())
        other.assert_called_once_with(tensor.data_ptr())


if __name__ == "__main__":
    run_tests()