# Owner(s): ["oncall: profiler"]

# If tqdm is not shut down properly, it will leave the monitor thread alive.
# This causes an issue in the multithreading test because we check all events
# in that test with their tids. The events that correspond to these lingering
# threads all have a TID of (uint64_t)(-1), which is invalid.
# The workaround is to turn off the monitoring thread when tqdm is loaded.
# Since these are unit tests, it is safe to turn off the monitor thread.
try:
    import tqdm

    tqdm.tqdm.monitor_interval = 0
except ImportError:
    pass

import json
import sys
import tempfile
import unittest
from typing import Any, Dict, List

import torch
import torch.nn as nn
from torch import _dynamo as torchdynamo
from torch.autograd import (
    _record_function_with_args_enter,
    _record_function_with_args_exit,
)
from torch.profiler import (
    ExecutionTraceObserver,
    kineto_available,
    profile,
    record_function,
    supported_activities,
)
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.utils._triton import has_triton


Json = Dict[str, Any]


class TestExecutionTrace(TestCase):
    def payload(self, use_cuda=False):
        """Run a small workload wrapped in record_function annotations to generate profiler events."""
        u = torch.randn(3, 4, 5, requires_grad=True)
        with record_function("## TEST 1 ##", "1, 2, 3"):
            inf_val = float("inf")
            neg_inf_val = float("-inf")
            nan_val = float("nan")
            rf_handle = _record_function_with_args_enter(
                "## TEST 2 ##",
                1,
                False,
                2.5,
                [u, u],
                (u, u),
                "hello",
                u,
                inf_val,
                neg_inf_val,
                nan_val,
            )
            x = torch.randn(10, 10, requires_grad=True)
            if use_cuda:
                x = x.cuda()
            y = torch.randn(10, 10, requires_grad=True)
            if use_cuda:
                y = y.cuda()
            z = x + y + x * y + x * y
            z.backward(z)
            gelu = nn.GELU()
            m = torch.randn(2)
            _ = gelu(m)
            if use_cuda:
                z = z.cpu()
            _record_function_with_args_exit(rf_handle)

    def get_execution_trace_root(self, output_file_name) -> List[Json]:
        """Load the execution trace JSON file and return its list of nodes."""
        nodes = []
        with open(output_file_name) as f:
            et_graph = json.load(f)
            assert "nodes" in et_graph
            nodes = et_graph["nodes"]
        return nodes

    def get_execution_trace_rf_ids(self, nodes: List[Json]) -> List[int]:
        """Returns a sorted list of rf_id (record function ids) in the execution trace."""

        def get_rf_id(node):
            attrs = node["attrs"]
            for a in attrs:
                if a["name"] == "rf_id":
                    return a["value"]
            return None

        rf_ids_ = (
            get_rf_id(n)
            for n in nodes
            if n["name"] != "[pytorch|profiler|execution_trace|process]"
            and n["name"] != "[pytorch|profiler|execution_trace|thread]"
        )
        return sorted(rf_id for rf_id in rf_ids_ if rf_id is not None)

    def get_kineto_rf_ids(self, events: List[Json]) -> List[int]:
        """Returns a sorted list of record function IDs for CPU operators and user annotations."""
        ops_and_annotations = (
            e for e in events if e.get("cat", "") in ["cpu_op", "user_annotation"]
        )
        return sorted(
            e.get("args", {}).get("Record function id", -1) for e in ops_and_annotations
        )

    @unittest.skipIf(not kineto_available(), "Kineto is required")
    def test_execution_trace_with_kineto(self):
        """Collect an execution trace alongside a Kineto trace and check that their rf_ids match."""
        trace_called_num = 0

        def trace_handler(p):
            nonlocal trace_called_num
            trace_called_num += 1

        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        # Create temp files to save the execution trace and Kineto data.
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()
        kt = tempfile.NamedTemporaryFile(
            mode="w+t", suffix=".kineto.json", delete=False
        )
        kt.close()

        with profile(
            activities=supported_activities(),
            schedule=torch.profiler.schedule(
                skip_first=3, wait=1, warmup=1, active=2, repeat=1
            ),
            on_trace_ready=trace_handler,
            execution_trace_observer=(
                ExecutionTraceObserver().register_callback(fp.name)
            ),
        ) as p:
            for idx in range(10):
                with record_function(f"## LOOP {idx} ##"):
                    self.payload(use_cuda=use_cuda)
                p.step()
        self.assertEqual(fp.name, p.execution_trace_observer.get_output_file_path())

        # Uncomment for debugging
        # print("Output kineto = ", kt.name)
        # print("Output ET = ", fp.name)

        p.export_chrome_trace(kt.name)
        self.assertEqual(trace_called_num, 1)

        nodes = self.get_execution_trace_root(fp.name)
        loop_count = 0
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                found_root_node = True
            if n["name"].startswith("## LOOP "):
                loop_count += 1
        self.assertTrue(found_root_node)
        # Since the profiler trace is active for 2 iterations, expect 2 loop events.
        self.assertEqual(loop_count, 2)

        # Compare the collected execution trace and Kineto trace
        # in terms of record function IDs (rf_id) and external IDs;
        # both of these should match for the same trace window.

        with open(kt.name) as f:
            kineto = json.load(f)
            events = kineto["traceEvents"]

        # Look up rf_ids in both the execution trace and the Kineto trace as two lists.
        rf_ids_et = self.get_execution_trace_rf_ids(nodes)
        rf_ids_kineto = self.get_kineto_rf_ids(events)

        self.assertCountEqual(rf_ids_et, rf_ids_kineto)
        self.assertListEqual(
            rf_ids_et,
            rf_ids_kineto,
            msg=f"ET and kineto rf_id should exactly match\n"
            f" rf_ids_et = {rf_ids_et}\n"
            f" rf_ids_kineto = {rf_ids_kineto}\n",
        )

    def test_execution_trace_alone(self):
        """Collect an execution trace with a standalone observer (no profiler) and verify its contents."""
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        # Create a temp file to save execution trace data.
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()
        expected_loop_events = 0

        et = ExecutionTraceObserver().register_callback(fp.name)

        et.start()
        for idx in range(5):
            expected_loop_events += 1
            with record_function(f"## LOOP {idx} ##"):
                self.payload(use_cuda=use_cuda)
        et.stop()

        assert fp.name == et.get_output_file_path()
        et.unregister_callback()
        nodes = self.get_execution_trace_root(fp.name)
        loop_count = 0
        # Expected tensor object tuple size, in the form of:
        # [tensor_id, storage_id, offset, numel, itemsize, device_str]
        tensor_tuple_size = 6
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                found_root_node = True
            if n["name"].startswith("## LOOP "):
                loop_count += 1
            # Check that the tensor tuple representation size is correct.
            if n["name"] == "## TEST 2 ##":
                assert len(n["inputs"]["values"][3][0]) == tensor_tuple_size
        assert found_root_node
        assert loop_count == expected_loop_events

    @unittest.skipIf(IS_WINDOWS, "torch.compile does not support Windows")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on Python 3.12+"
    )
    @unittest.skipIf(not TEST_CUDA or not has_triton(), "need CUDA and triton to run")
    def test_execution_trace_with_pt2(self):
        """Check that Triton kernels compiled via PT2/Inductor are captured in the execution trace."""

        @torchdynamo.optimize("inductor")
        def fn(a, b, c):
            x = torch.nn.functional.linear(a, b)
            x = x + c
            return x.cos()

        a, b, c = (torch.randn(4, 4, requires_grad=True).to("cuda") for _ in range(3))

        inputs = [a, b, c]
        with torch._inductor.config.patch(compile_threads=1):
            fn(*inputs)

        # Create a temp file to save execution trace data.
        fp = tempfile.NamedTemporaryFile("w+t", suffix="_et.json", delete=False)
        fp.close()

        with profile(
            activities=torch.profiler.supported_activities(),
            record_shapes=True,
            schedule=torch.profiler.schedule(
                skip_first=3, wait=1, warmup=1, active=2, repeat=1
            ),
            execution_trace_observer=(
                ExecutionTraceObserver().register_callback(fp.name)
            ),
        ) as p:
            for idx in range(10):
                with record_function(f"## LOOP {idx} ##"):
                    fn(*inputs)
                p.step()

        nodes = self.get_execution_trace_root(fp.name)
        found_captured_triton_kernel_node = False
        for n in nodes:
            assert "name" in n
            if "triton_" in n["name"]:
                for attr in n["attrs"]:
                    if attr["name"] == "kernel_file" and attr["value"] != "":
                        found_captured_triton_kernel_node = True
                        assert len(n["inputs"]["values"]) > 0
                        assert len(n["outputs"]["values"]) == 0
        assert found_captured_triton_kernel_node

    def test_execution_trace_start_stop(self):
        """Verify that only iterations executed while the observer is running are captured."""
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        # Create a temp file to save execution trace data.
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()
        expected_loop_events = 0
        et = ExecutionTraceObserver().register_callback(fp.name)
        for idx in range(10):
            if idx == 3:
                et.start()
            elif idx == 5:
                et.stop()
            elif idx == 8:
                et.start()
            elif idx == 9:
                et.stop()
            if et._execution_trace_running:
                expected_loop_events += 1
            with record_function(f"## LOOP {idx} ##"):
                self.payload(use_cuda=use_cuda)

        assert fp.name == et.get_output_file_path()
        et.unregister_callback()
        nodes = self.get_execution_trace_root(fp.name)
        loop_count = 0
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                found_root_node = True
            if n["name"].startswith("## LOOP "):
                loop_count += 1
        assert found_root_node
        assert loop_count == expected_loop_events

    def test_execution_trace_repeat_in_loop(self):
        """Register and unregister the observer repeatedly in a loop, writing a separate trace file each time."""
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        iter_list = {3, 4, 6, 8}
        expected_loop_events = len(iter_list)
        output_files = []
        for idx in range(10):
            if idx in iter_list:
                # Create a temp file to save execution trace data.
                fp = tempfile.NamedTemporaryFile(
                    "w+t", suffix=".et.json", delete=False
                )
                fp.close()
                output_files.append(fp.name)
                et = ExecutionTraceObserver().register_callback(fp.name)
                et.start()
            with record_function(f"## LOOP {idx} ##"):
                self.payload(use_cuda=use_cuda)
            if idx in iter_list:
                et.stop()
                et.unregister_callback()

        event_count = 0
        for et_file in output_files:
            nodes = self.get_execution_trace_root(et_file)
            found_root_node = False
            for n in nodes:
                assert "name" in n
                if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                    assert n["id"] == 1
                    found_root_node = True
                if n["name"].startswith("## LOOP "):
                    event_count += 1
            assert found_root_node
        assert event_count == expected_loop_events

    def test_execution_trace_no_capture(self):
        """An observer that is registered but never started should still produce a trace with a root process node."""
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()
        et = ExecutionTraceObserver().register_callback(fp.name)

        assert fp.name == et.get_output_file_path()
        et.unregister_callback()
        nodes = self.get_execution_trace_root(fp.name)
        found_root_node = False
        for n in nodes:
            assert "name" in n
            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
                found_root_node = True
        assert found_root_node

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/124500")
    def test_execution_trace_nested_tensor(self):
        """Check that ops on nested (jagged) tensors are captured in the execution trace."""
        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
        fp.close()

        observer = ExecutionTraceObserver().register_callback(fp.name)

        def fn(nt):
            return nt.sin().cos()

        with torch.profiler.profile(execution_trace_observer=observer) as prof:
            for i in range(3):
                values = torch.rand((8 + i, 4 + i))
                offsets = torch.tensor([0, 2, 4, 6, 8 + i])
                nt = torch.nested.nested_tensor_from_jagged(values, offsets)
                fn(nt)

        nodes = self.get_execution_trace_root(fp.name)
        found_cos = False
        for n in nodes:
            assert "name" in n
            if "cos" in n["name"]:
                found_cos = True
        assert found_cos


if __name__ == "__main__":
    run_tests()