1# Owner(s): ["oncall: profiler"] 2 3# if tqdm is not shutdown properly, it will leave the monitor thread alive. 4# This causes an issue in the multithreading test because we check all events 5# in that test with their tids. The events that correspond to these lingering 6# threads all have TID of (uint64_t)(-1) which is invalid. 7# The work around is turnning off monitoring thread when tqdm is loaded. 8# Since these are unit tests, it is safe to turn off monitor thread. 9try: 10 import tqdm 11 12 tqdm.tqdm.monitor_interval = 0 13except ImportError: 14 None 15 16from typing import Any, Dict 17 18import torch 19import torch.optim 20import torch.utils.data 21import torch.utils.data.datapipes as dp 22from torch.autograd import ( 23 _record_function_with_args_enter, 24 _record_function_with_args_exit, 25) 26from torch.autograd.profiler import profile as _profile 27from torch.profiler import kineto_available, record_function 28from torch.testing._internal.common_utils import run_tests, TestCase 29 30 31Json = Dict[str, Any] 32 33 34class TestRecordFunction(TestCase): 35 def _record_function_with_param(self): 36 u = torch.randn(3, 4, 5, requires_grad=True) 37 with _profile( 38 with_stack=True, use_kineto=kineto_available(), record_shapes=True 39 ) as prof: 40 with record_function("## TEST 1 ##", "1, 2, 3"): 41 rf_handle = _record_function_with_args_enter( 42 "## TEST 2 ##", 1, False, 2.5, [u, u], "hello", u 43 ) 44 _record_function_with_args_exit(rf_handle) 45 with record_function("## TEST 3 ##"): 46 rf_handle = _record_function_with_args_enter("## TEST 4 ##") 47 _record_function_with_args_exit(rf_handle) 48 return prof 49 50 def test_record_function(self): 51 prof_result = self._record_function_with_param() 52 found_test_1 = False 53 found_test_2 = False 54 found_test_3 = False 55 found_test_4 = False 56 for e in prof_result.function_events: 57 if "## TEST 1 ##" == e.name: 58 found_test_1 = True 59 self.assertTrue(e.input_shapes == [[]]) 60 elif "## TEST 2 ##" == e.name: 61 found_test_2 = True 62 self.assertTrue(e.input_shapes == [[], [], [], [], [], [3, 4, 5]]) 63 elif "## TEST 3 ##" == e.name: 64 found_test_3 = True 65 self.assertTrue(e.input_shapes == []) 66 elif "## TEST 4 ##" == e.name: 67 found_test_4 = True 68 self.assertTrue(e.input_shapes == []) 69 self.assertTrue(found_test_1) 70 self.assertTrue(found_test_2) 71 self.assertTrue(found_test_3) 72 self.assertTrue(found_test_4) 73 74 def test_datapipe_with_record_function(self): 75 with _profile( 76 with_stack=True, use_kineto=kineto_available(), record_shapes=True 77 ) as prof: 78 input_dp1 = dp.iter.IterableWrapper(range(4)) 79 input_dp2 = dp.iter.IterableWrapper(range(4, 8)) 80 input_dp3 = dp.iter.IterableWrapper(range(8, 12)) 81 output_dp = input_dp1.mux(input_dp2, input_dp3) 82 output = list(output_dp) 83 84 has_iter = False 85 has_mux = False 86 for e in prof.function_events: 87 if has_iter and has_mux: 88 break 89 90 if not has_iter and "IterableWrapper" in e.name: 91 has_iter = True 92 if not has_mux and "Multiplexer" in e.name: 93 has_mux = True 94 self.assertTrue(has_iter) 95 self.assertTrue(has_mux) 96 97 def test_datapipe_delegation_with_profiler(self): 98 class IDPIterator(torch.utils.data.IterDataPipe): 99 def __init__(self) -> None: 100 self.data = list(range(10)) 101 self._idx = 0 102 103 def __iter__(self): 104 return self 105 106 def __next__(self): 107 if self._idx >= 10: 108 self._idx = 0 109 raise StopIteration 110 self._idx += 1 111 return self.data[self._idx - 1] 112 113 def get_value(self, idx): 114 return self.data[idx] 115 116 dp1 = IDPIterator() # The object itself is an iterator 117 self.assertEqual(5, dp1.get_value(5)) 118 it_dp1 = iter(dp1) # This creates the 1st iterator 119 self.assertEqual(5, it_dp1.get_value(5)) # type: ignore[attr-defined] 120 self.assertEqual(list(range(10)), list(it_dp1)) 121 122 class IDPDelegator(torch.utils.data.IterDataPipe): 123 def __init__(self, datapipe): 124 self.datapipe = datapipe 125 126 def __iter__(self): 127 return iter(self.datapipe) 128 129 dp2 = IDPDelegator(dp1) 130 it_dp2 = iter(dp2) 131 self.assertEqual(5, it_dp2.get_value(5)) 132 self.assertEqual(list(range(10)), list(it_dp2)) 133 134 def test_datapipe_with_record_function_fork(self): 135 with _profile( 136 with_stack=True, use_kineto=kineto_available(), record_shapes=True 137 ) as prof: 138 input_dp = dp.iter.IterableWrapper(range(10)) 139 dp1, dp2, dp3 = input_dp.fork(num_instances=3) 140 output1 = list(dp1) 141 has_iter = False 142 has_child = False 143 for e in prof.function_events: 144 if has_iter and has_child: 145 break 146 147 if not has_iter and "IterableWrapper" in e.name: 148 has_iter = True 149 if not has_child and "_ChildDataPipe" in e.name: 150 has_child = True 151 self.assertTrue(has_iter) 152 self.assertTrue(has_child) 153 154 155if __name__ == "__main__": 156 run_tests() 157