1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: profiler"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport collections 4*da0073e9SAndroid Build Coastguard Workerimport gc 5*da0073e9SAndroid Build Coastguard Workerimport json 6*da0073e9SAndroid Build Coastguard Workerimport mmap 7*da0073e9SAndroid Build Coastguard Workerimport os 8*da0073e9SAndroid Build Coastguard Workerimport pickle 9*da0073e9SAndroid Build Coastguard Workerimport random 10*da0073e9SAndroid Build Coastguard Workerimport re 11*da0073e9SAndroid Build Coastguard Workerimport struct 12*da0073e9SAndroid Build Coastguard Workerimport subprocess 13*da0073e9SAndroid Build Coastguard Workerimport sys 14*da0073e9SAndroid Build Coastguard Workerimport tempfile 15*da0073e9SAndroid Build Coastguard Workerimport threading 16*da0073e9SAndroid Build Coastguard Workerimport time 17*da0073e9SAndroid Build Coastguard Workerimport unittest 18*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass, field 19*da0073e9SAndroid Build Coastguard Workerfrom typing import List, Optional 20*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Workerimport expecttest 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Workerimport torch 25*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn 26*da0073e9SAndroid Build Coastguard Workerimport torch.optim 27*da0073e9SAndroid Build Coastguard Workerimport torch.utils.data 28*da0073e9SAndroid Build Coastguard Workerfrom torch._C._profiler import _ExperimentalConfig, _ExtraFields_PyCall 29*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd.profiler import KinetoStepTracker, profile as _profile 30*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd.profiler_legacy import profile as _profile_legacy 31*da0073e9SAndroid Build Coastguard Workerfrom 
torch.profiler import ( 32*da0073e9SAndroid Build Coastguard Worker _utils, 33*da0073e9SAndroid Build Coastguard Worker DeviceType, 34*da0073e9SAndroid Build Coastguard Worker kineto_available, 35*da0073e9SAndroid Build Coastguard Worker profile, 36*da0073e9SAndroid Build Coastguard Worker ProfilerAction, 37*da0073e9SAndroid Build Coastguard Worker ProfilerActivity, 38*da0073e9SAndroid Build Coastguard Worker record_function, 39*da0073e9SAndroid Build Coastguard Worker supported_activities, 40*da0073e9SAndroid Build Coastguard Worker) 41*da0073e9SAndroid Build Coastguard Workerfrom torch.profiler._pattern_matcher import ( 42*da0073e9SAndroid Build Coastguard Worker Conv2dBiasFollowedByBatchNorm2dPattern, 43*da0073e9SAndroid Build Coastguard Worker ExtraCUDACopyPattern, 44*da0073e9SAndroid Build Coastguard Worker ForLoopIndexingPattern, 45*da0073e9SAndroid Build Coastguard Worker FP32MatMulPattern, 46*da0073e9SAndroid Build Coastguard Worker GradNotSetToNonePattern, 47*da0073e9SAndroid Build Coastguard Worker MatMulDimInFP16Pattern, 48*da0073e9SAndroid Build Coastguard Worker NamePattern, 49*da0073e9SAndroid Build Coastguard Worker OptimizerSingleTensorPattern, 50*da0073e9SAndroid Build Coastguard Worker Pattern, 51*da0073e9SAndroid Build Coastguard Worker report_all_anti_patterns, 52*da0073e9SAndroid Build Coastguard Worker SynchronizedDataLoaderPattern, 53*da0073e9SAndroid Build Coastguard Worker) 54*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_MULTIGPU 55*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import skipCUDAVersionIn 56*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 57*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests, 58*da0073e9SAndroid Build Coastguard Worker IS_ARM64, 59*da0073e9SAndroid Build Coastguard Worker IS_JETSON, 60*da0073e9SAndroid Build Coastguard Worker IS_LINUX, 61*da0073e9SAndroid 
Build Coastguard Worker IS_WINDOWS, 62*da0073e9SAndroid Build Coastguard Worker parametrize, 63*da0073e9SAndroid Build Coastguard Worker run_tests, 64*da0073e9SAndroid Build Coastguard Worker serialTest, 65*da0073e9SAndroid Build Coastguard Worker skipIfTorchDynamo, 66*da0073e9SAndroid Build Coastguard Worker TemporaryDirectoryName, 67*da0073e9SAndroid Build Coastguard Worker TemporaryFileName, 68*da0073e9SAndroid Build Coastguard Worker TEST_WITH_ASAN, 69*da0073e9SAndroid Build Coastguard Worker TEST_WITH_CROSSREF, 70*da0073e9SAndroid Build Coastguard Worker TEST_WITH_ROCM, 71*da0073e9SAndroid Build Coastguard Worker TestCase, 72*da0073e9SAndroid Build Coastguard Worker) 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker# if tqdm is not shut down properly, it will leave the monitor thread alive. 76*da0073e9SAndroid Build Coastguard Worker# This causes an issue in the multithreading test because we check all events 77*da0073e9SAndroid Build Coastguard Worker# in that test with their tids. The events that correspond to these lingering 78*da0073e9SAndroid Build Coastguard Worker# threads all have TID of (uint64_t)(-1) which is invalid. 79*da0073e9SAndroid Build Coastguard Worker# The workaround is turning off the monitoring thread when tqdm is loaded. 80*da0073e9SAndroid Build Coastguard Worker# Since these are unit tests, it is safe to turn off the monitor thread.
81*da0073e9SAndroid Build Coastguard Workertry: 82*da0073e9SAndroid Build Coastguard Worker import tqdm 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker tqdm.tqdm.monitor_interval = 0 85*da0073e9SAndroid Build Coastguard Workerexcept ImportError: 86*da0073e9SAndroid Build Coastguard Worker pass 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Workertry: 89*da0073e9SAndroid Build Coastguard Worker import psutil 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker HAS_PSUTIL = True 92*da0073e9SAndroid Build Coastguard Workerexcept ModuleNotFoundError: 93*da0073e9SAndroid Build Coastguard Worker HAS_PSUTIL = False 94*da0073e9SAndroid Build Coastguard Worker psutil = None 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker 97*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run") 98*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(TEST_WITH_ASAN, "Cannot test with ASAN") 99*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(IS_WINDOWS, "Test is flaky on Windows") 100*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") 101*da0073e9SAndroid Build Coastguard Workerclass TestProfilerCUDA(TestCase): 102*da0073e9SAndroid Build Coastguard Worker @skipCUDAVersionIn([(11, 5)]) # https://github.com/pytorch/pytorch/issues/69023 103*da0073e9SAndroid Build Coastguard Worker def test_mem_leak(self): 104*da0073e9SAndroid Build Coastguard Worker """Checks that there's no memory leak when using profiler with CUDA""" 105*da0073e9SAndroid Build Coastguard Worker t = torch.rand(1, 1).cuda() 106*da0073e9SAndroid Build Coastguard Worker p = psutil.Process() 107*da0073e9SAndroid Build Coastguard Worker last_rss = collections.deque(maxlen=5) 108*da0073e9SAndroid Build Coastguard Worker for outer_idx in range(10): 109*da0073e9SAndroid Build 
Coastguard Worker with _profile(use_cuda=True): 110*da0073e9SAndroid Build Coastguard Worker for _ in range(1024): 111*da0073e9SAndroid Build Coastguard Worker t = torch.mm(t, t) 112*da0073e9SAndroid Build Coastguard Worker 113*da0073e9SAndroid Build Coastguard Worker gc.collect() 114*da0073e9SAndroid Build Coastguard Worker torch.cuda.empty_cache() 115*da0073e9SAndroid Build Coastguard Worker last_rss.append(p.memory_info().rss) 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker # with CUDA events leaking the increase in memory was ~7 MB between 118*da0073e9SAndroid Build Coastguard Worker # profiler invocations above 119*da0073e9SAndroid Build Coastguard Worker is_increasing = all( 120*da0073e9SAndroid Build Coastguard Worker last_rss[idx] > last_rss[idx - 1] for idx in range(1, len(last_rss)) 121*da0073e9SAndroid Build Coastguard Worker ) 122*da0073e9SAndroid Build Coastguard Worker max_diff = -1 123*da0073e9SAndroid Build Coastguard Worker for idx in range(1, len(last_rss)): 124*da0073e9SAndroid Build Coastguard Worker max_diff = max(max_diff, last_rss[idx] - last_rss[idx - 1]) 125*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 126*da0073e9SAndroid Build Coastguard Worker not (is_increasing and max_diff > 100 * 1024), 127*da0073e9SAndroid Build Coastguard Worker msg=f"memory usage is increasing, {str(last_rss)}", 128*da0073e9SAndroid Build Coastguard Worker ) 129*da0073e9SAndroid Build Coastguard Worker 130*da0073e9SAndroid Build Coastguard Worker def test_custom_module_input_op_ids(self): 131*da0073e9SAndroid Build Coastguard Worker class MyFunc(torch.autograd.Function): 132*da0073e9SAndroid Build Coastguard Worker @staticmethod 133*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 134*da0073e9SAndroid Build Coastguard Worker ctx.save_for_backward(x) 135*da0073e9SAndroid Build Coastguard Worker return x 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker 
@staticmethod 138*da0073e9SAndroid Build Coastguard Worker def backward(ctx, gO): 139*da0073e9SAndroid Build Coastguard Worker (x,) = ctx.saved_tensors 140*da0073e9SAndroid Build Coastguard Worker return x 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Worker def custom_layer(input_ten): 143*da0073e9SAndroid Build Coastguard Worker return MyFunc.apply(input_ten) 144*da0073e9SAndroid Build Coastguard Worker 145*da0073e9SAndroid Build Coastguard Worker # Only testing that emit_nvtx runs when 146*da0073e9SAndroid Build Coastguard Worker # record_shapes option is enabled. 147*da0073e9SAndroid Build Coastguard Worker with torch.autograd.profiler.emit_nvtx(record_shapes=True) as prof: 148*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 10, requires_grad=True) 149*da0073e9SAndroid Build Coastguard Worker y = torch.randn(10, 10, requires_grad=True) 150*da0073e9SAndroid Build Coastguard Worker z = x + y 151*da0073e9SAndroid Build Coastguard Worker s = custom_layer(z) 152*da0073e9SAndroid Build Coastguard Worker q = s.sum() 153*da0073e9SAndroid Build Coastguard Worker q.backward() 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") 156*da0073e9SAndroid Build Coastguard Worker def test_cudagraph_profiling_workaround(self): 157*da0073e9SAndroid Build Coastguard Worker import subprocess 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker # repro taken from #75504 160*da0073e9SAndroid Build Coastguard Worker # Launch in a separate process to catch hanging/illegal memory errors 161*da0073e9SAndroid Build Coastguard Worker # and to make sure CUPTI isn't already initialized. 
162*da0073e9SAndroid Build Coastguard Worker p = subprocess.check_call( 163*da0073e9SAndroid Build Coastguard Worker [ 164*da0073e9SAndroid Build Coastguard Worker sys.executable, 165*da0073e9SAndroid Build Coastguard Worker "-c", 166*da0073e9SAndroid Build Coastguard Worker """ 167*da0073e9SAndroid Build Coastguard Workerimport os 168*da0073e9SAndroid Build Coastguard Workerimport torch 169*da0073e9SAndroid Build Coastguard Workerfrom torch.profiler import ProfilerActivity, profile 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Workerdef add_one(in_: torch.Tensor): 172*da0073e9SAndroid Build Coastguard Worker return in_ + 1 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Workersample_arg = torch.zeros(10, device="cuda").requires_grad_(True) 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard Worker# add this before cuda graphs are created 177*da0073e9SAndroid Build Coastguard Workertorch.profiler._utils._init_for_cuda_graphs() 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Workeradd_one_graphed = torch.cuda.graphs.make_graphed_callables(add_one, sample_args=(sample_arg,)) 180*da0073e9SAndroid Build Coastguard Workerzeros = torch.zeros(10, device="cuda") 181*da0073e9SAndroid Build Coastguard Workerout = add_one_graphed(zeros) 182*da0073e9SAndroid Build Coastguard Workerassert out[0] == 1 183*da0073e9SAndroid Build Coastguard Worker 184*da0073e9SAndroid Build Coastguard Workerwith profile(activities=[ProfilerActivity.CPU]): 185*da0073e9SAndroid Build Coastguard Worker add_one_graphed(zeros) 186*da0073e9SAndroid Build Coastguard Worker 187*da0073e9SAndroid Build Coastguard Workerwith profile(activities=[ProfilerActivity.CUDA]): 188*da0073e9SAndroid Build Coastguard Worker add_one_graphed(zeros) 189*da0073e9SAndroid Build Coastguard Worker""", 190*da0073e9SAndroid Build Coastguard Worker ], 191*da0073e9SAndroid Build Coastguard 
Worker universal_newlines=True, 192*da0073e9SAndroid Build Coastguard Worker timeout=60, 193*da0073e9SAndroid Build Coastguard Worker ) 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker # ^ this will throw an exception if the script fails. 196*da0073e9SAndroid Build Coastguard Worker 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(not torch.profiler.itt.is_available(), "ITT is required") 199*da0073e9SAndroid Build Coastguard Workerclass TestProfilerITT(TestCase): 200*da0073e9SAndroid Build Coastguard Worker def test_custom_module_input_op_ids(self): 201*da0073e9SAndroid Build Coastguard Worker class MyFunc(torch.autograd.Function): 202*da0073e9SAndroid Build Coastguard Worker @staticmethod 203*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 204*da0073e9SAndroid Build Coastguard Worker ctx.save_for_backward(x) 205*da0073e9SAndroid Build Coastguard Worker return x 206*da0073e9SAndroid Build Coastguard Worker 207*da0073e9SAndroid Build Coastguard Worker @staticmethod 208*da0073e9SAndroid Build Coastguard Worker def backward(ctx, gO): 209*da0073e9SAndroid Build Coastguard Worker (x,) = ctx.saved_tensors 210*da0073e9SAndroid Build Coastguard Worker return x 211*da0073e9SAndroid Build Coastguard Worker 212*da0073e9SAndroid Build Coastguard Worker def custom_layer(input_ten): 213*da0073e9SAndroid Build Coastguard Worker return MyFunc.apply(input_ten) 214*da0073e9SAndroid Build Coastguard Worker 215*da0073e9SAndroid Build Coastguard Worker # Only testing that emit_itt runs when 216*da0073e9SAndroid Build Coastguard Worker # record_shapes option is enabled. 
217*da0073e9SAndroid Build Coastguard Worker with torch.autograd.profiler.emit_itt(record_shapes=True) as prof: 218*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 10, requires_grad=True) 219*da0073e9SAndroid Build Coastguard Worker y = torch.randn(10, 10, requires_grad=True) 220*da0073e9SAndroid Build Coastguard Worker z = x + y 221*da0073e9SAndroid Build Coastguard Worker s = custom_layer(z) 222*da0073e9SAndroid Build Coastguard Worker q = s.sum() 223*da0073e9SAndroid Build Coastguard Worker q.backward() 224*da0073e9SAndroid Build Coastguard Worker 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker@instantiate_parametrized_tests 227*da0073e9SAndroid Build Coastguard Workerclass TestProfiler(TestCase): 228*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 229*da0073e9SAndroid Build Coastguard Worker TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite." 230*da0073e9SAndroid Build Coastguard Worker ) 231*da0073e9SAndroid Build Coastguard Worker def test_source(self): 232*da0073e9SAndroid Build Coastguard Worker """Checks that source code attribution works for eager, TS and autograd mode""" 233*da0073e9SAndroid Build Coastguard Worker # avoid automatic inlining 234*da0073e9SAndroid Build Coastguard Worker prev_opt = torch._C._get_graph_executor_optimize() 235*da0073e9SAndroid Build Coastguard Worker torch._C._set_graph_executor_optimize(False) 236*da0073e9SAndroid Build Coastguard Worker 237*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 238*da0073e9SAndroid Build Coastguard Worker def ts_method_2(x, y): 239*da0073e9SAndroid Build Coastguard Worker return torch.matmul(x, y) 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 242*da0073e9SAndroid Build Coastguard Worker def ts_method_1(x, y, z): 243*da0073e9SAndroid Build Coastguard Worker a = x + z 244*da0073e9SAndroid Build Coastguard Worker w = ts_method_2(x, y) + a 
245*da0073e9SAndroid Build Coastguard Worker return w.sum() 246*da0073e9SAndroid Build Coastguard Worker 247*da0073e9SAndroid Build Coastguard Worker class DummyModule(nn.Module): 248*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 249*da0073e9SAndroid Build Coastguard Worker super().__init__() 250*da0073e9SAndroid Build Coastguard Worker self.conv = torch.nn.Conv2d( 251*da0073e9SAndroid Build Coastguard Worker 3, 2, kernel_size=1, stride=2, padding=3, bias=False 252*da0073e9SAndroid Build Coastguard Worker ) 253*da0073e9SAndroid Build Coastguard Worker 254*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 255*da0073e9SAndroid Build Coastguard Worker return self.conv(x) 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker mod = DummyModule() 258*da0073e9SAndroid Build Coastguard Worker 259*da0073e9SAndroid Build Coastguard Worker def call_module(x): 260*da0073e9SAndroid Build Coastguard Worker return mod(x) 261*da0073e9SAndroid Build Coastguard Worker 262*da0073e9SAndroid Build Coastguard Worker with _profile( 263*da0073e9SAndroid Build Coastguard Worker with_stack=True, 264*da0073e9SAndroid Build Coastguard Worker use_kineto=kineto_available(), 265*da0073e9SAndroid Build Coastguard Worker experimental_config=_ExperimentalConfig(verbose=True), 266*da0073e9SAndroid Build Coastguard Worker ) as p: 267*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 10, requires_grad=True) 268*da0073e9SAndroid Build Coastguard Worker y = torch.randn(10, 10, requires_grad=True) 269*da0073e9SAndroid Build Coastguard Worker z = x + y 270*da0073e9SAndroid Build Coastguard Worker w = ts_method_1(x, y, z) 271*da0073e9SAndroid Build Coastguard Worker v = 2 * w 272*da0073e9SAndroid Build Coastguard Worker v.backward() 273*da0073e9SAndroid Build Coastguard Worker a = torch.randn(2, 3, 2, 2, requires_grad=True) 274*da0073e9SAndroid Build Coastguard Worker b = call_module(a) 275*da0073e9SAndroid Build Coastguard 
Worker c = b.sum() 276*da0073e9SAndroid Build Coastguard Worker c.backward() 277*da0073e9SAndroid Build Coastguard Worker 278*da0073e9SAndroid Build Coastguard Worker for e in p.function_events: 279*da0073e9SAndroid Build Coastguard Worker if "aten::add" in e.name or "AddBackward" in e.name: 280*da0073e9SAndroid Build Coastguard Worker self.assertTrue(any("test_profiler" in entry for entry in e.stack)) 281*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 282*da0073e9SAndroid Build Coastguard Worker any( 283*da0073e9SAndroid Build Coastguard Worker ( 284*da0073e9SAndroid Build Coastguard Worker "test_source" in entry 285*da0073e9SAndroid Build Coastguard Worker or "ts_method_1" in entry 286*da0073e9SAndroid Build Coastguard Worker or "ts_method_2" in entry 287*da0073e9SAndroid Build Coastguard Worker ) 288*da0073e9SAndroid Build Coastguard Worker for entry in e.stack 289*da0073e9SAndroid Build Coastguard Worker ) 290*da0073e9SAndroid Build Coastguard Worker ) 291*da0073e9SAndroid Build Coastguard Worker 292*da0073e9SAndroid Build Coastguard Worker # TODO: https://github.com/pytorch/kineto/issues/617 293*da0073e9SAndroid Build Coastguard Worker if kineto_available() and not IS_WINDOWS: 294*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName(mode="w+") as fname: 295*da0073e9SAndroid Build Coastguard Worker p.export_chrome_trace(fname) 296*da0073e9SAndroid Build Coastguard Worker with open(fname) as f: 297*da0073e9SAndroid Build Coastguard Worker events = json.load(f)["traceEvents"] 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker def extract(pattern: str): 300*da0073e9SAndroid Build Coastguard Worker matches = [e for e in events if re.search(pattern, e["name"])] 301*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 302*da0073e9SAndroid Build Coastguard Worker len(matches), 1, repr([e["name"] for e in matches]) 303*da0073e9SAndroid Build Coastguard Worker ) 304*da0073e9SAndroid Build Coastguard 
Worker return matches[0] 305*da0073e9SAndroid Build Coastguard Worker 306*da0073e9SAndroid Build Coastguard Worker module_event = extract(r"DummyModule_0") 307*da0073e9SAndroid Build Coastguard Worker wrapper_event = extract(r"call_module") 308*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 309*da0073e9SAndroid Build Coastguard Worker module_event["args"]["Python parent id"], 310*da0073e9SAndroid Build Coastguard Worker wrapper_event["args"]["Python id"], 311*da0073e9SAndroid Build Coastguard Worker ) 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker torch._C._set_graph_executor_optimize(prev_opt) 314*da0073e9SAndroid Build Coastguard Worker 315*da0073e9SAndroid Build Coastguard Worker @parametrize( 316*da0073e9SAndroid Build Coastguard Worker "name,thread_spec", 317*da0073e9SAndroid Build Coastguard Worker { 318*da0073e9SAndroid Build Coastguard Worker "basic": ((False, False),), 319*da0073e9SAndroid Build Coastguard Worker "multiple_preexisting": ((False, False),) * 2, 320*da0073e9SAndroid Build Coastguard Worker "open_in_scope": ((True, False),), 321*da0073e9SAndroid Build Coastguard Worker "close_in_scope": ((False, True),), 322*da0073e9SAndroid Build Coastguard Worker "complex": ( 323*da0073e9SAndroid Build Coastguard Worker # Large number of background threads 324*da0073e9SAndroid Build Coastguard Worker (False, False), 325*da0073e9SAndroid Build Coastguard Worker (False, False), 326*da0073e9SAndroid Build Coastguard Worker (False, False), 327*da0073e9SAndroid Build Coastguard Worker (False, False), 328*da0073e9SAndroid Build Coastguard Worker # some of which finish during profiling 329*da0073e9SAndroid Build Coastguard Worker (False, True), 330*da0073e9SAndroid Build Coastguard Worker (False, True), 331*da0073e9SAndroid Build Coastguard Worker # And the profiled section is also multithreaded 332*da0073e9SAndroid Build Coastguard Worker (True, False), 333*da0073e9SAndroid Build Coastguard Worker (True, 
True), 334*da0073e9SAndroid Build Coastguard Worker ), 335*da0073e9SAndroid Build Coastguard Worker }.items(), 336*da0073e9SAndroid Build Coastguard Worker name_fn=lambda name, thread_spec: name, 337*da0073e9SAndroid Build Coastguard Worker ) 338*da0073e9SAndroid Build Coastguard Worker @serialTest() 339*da0073e9SAndroid Build Coastguard Worker @parametrize("work_in_main_thread", [True, False]) 340*da0073e9SAndroid Build Coastguard Worker def test_source_multithreaded(self, name, thread_spec, work_in_main_thread): 341*da0073e9SAndroid Build Coastguard Worker """Test various threading configurations. 342*da0073e9SAndroid Build Coastguard Worker 343*da0073e9SAndroid Build Coastguard Worker `thread_spec` is a Tuple[Tuple[bool, bool], ...] where each pair is a 344*da0073e9SAndroid Build Coastguard Worker thread. The first bool indicates if the thread should be started under 345*da0073e9SAndroid Build Coastguard Worker the profiler context and the second is if it should be joined under the 346*da0073e9SAndroid Build Coastguard Worker profiler context. 
347*da0073e9SAndroid Build Coastguard Worker """ 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker timeout = 15 350*da0073e9SAndroid Build Coastguard Worker num_threads = len(thread_spec) + 1 # Main thread 351*da0073e9SAndroid Build Coastguard Worker start_barrier = threading.Barrier(num_threads, timeout=timeout) 352*da0073e9SAndroid Build Coastguard Worker end_barrier = threading.Barrier(num_threads, timeout=timeout) 353*da0073e9SAndroid Build Coastguard Worker 354*da0073e9SAndroid Build Coastguard Worker class Task(threading.Thread): 355*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 356*da0073e9SAndroid Build Coastguard Worker self._end_gate = threading.Event() 357*da0073e9SAndroid Build Coastguard Worker super().__init__(daemon=True) 358*da0073e9SAndroid Build Coastguard Worker self.start() 359*da0073e9SAndroid Build Coastguard Worker self.finished = False 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker def run(self): 362*da0073e9SAndroid Build Coastguard Worker self._run(self._end_gate) 363*da0073e9SAndroid Build Coastguard Worker 364*da0073e9SAndroid Build Coastguard Worker def release(self): 365*da0073e9SAndroid Build Coastguard Worker self._end_gate.set() 366*da0073e9SAndroid Build Coastguard Worker 367*da0073e9SAndroid Build Coastguard Worker @staticmethod 368*da0073e9SAndroid Build Coastguard Worker def _run(end_gate=None): 369*da0073e9SAndroid Build Coastguard Worker def known_preexisting_function(): 370*da0073e9SAndroid Build Coastguard Worker start_barrier.wait() 371*da0073e9SAndroid Build Coastguard Worker 372*da0073e9SAndroid Build Coastguard Worker # Fixed point that we can use to test capture of functions 373*da0073e9SAndroid Build Coastguard Worker # which are already running when profiling is enabled. 
374*da0073e9SAndroid Build Coastguard Worker known_preexisting_function() 375*da0073e9SAndroid Build Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker model = torch.nn.Sequential( 377*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(10, 10), 378*da0073e9SAndroid Build Coastguard Worker torch.nn.ReLU(), 379*da0073e9SAndroid Build Coastguard Worker ) 380*da0073e9SAndroid Build Coastguard Worker 381*da0073e9SAndroid Build Coastguard Worker def invoked_during_run(): 382*da0073e9SAndroid Build Coastguard Worker pass 383*da0073e9SAndroid Build Coastguard Worker 384*da0073e9SAndroid Build Coastguard Worker invoked_during_run() 385*da0073e9SAndroid Build Coastguard Worker 386*da0073e9SAndroid Build Coastguard Worker _ = model(torch.rand(4, 10)) 387*da0073e9SAndroid Build Coastguard Worker end_barrier.wait() 388*da0073e9SAndroid Build Coastguard Worker 389*da0073e9SAndroid Build Coastguard Worker if end_gate is not None: 390*da0073e9SAndroid Build Coastguard Worker end_gate.wait(timeout=timeout) 391*da0073e9SAndroid Build Coastguard Worker 392*da0073e9SAndroid Build Coastguard Worker threads = {} 393*da0073e9SAndroid Build Coastguard Worker 394*da0073e9SAndroid Build Coastguard Worker def add_threads(context: bool): 395*da0073e9SAndroid Build Coastguard Worker for idx, (start_under_profiler, _) in enumerate(thread_spec): 396*da0073e9SAndroid Build Coastguard Worker if start_under_profiler == context: 397*da0073e9SAndroid Build Coastguard Worker assert idx not in threads 398*da0073e9SAndroid Build Coastguard Worker threads[idx] = Task() 399*da0073e9SAndroid Build Coastguard Worker 400*da0073e9SAndroid Build Coastguard Worker def join_threads(context: bool): 401*da0073e9SAndroid Build Coastguard Worker for idx, (_, end_under_profiler) in enumerate(thread_spec): 402*da0073e9SAndroid Build Coastguard Worker if end_under_profiler == context: 403*da0073e9SAndroid Build Coastguard Worker threads[idx].release() 404*da0073e9SAndroid Build Coastguard Worker 
405*da0073e9SAndroid Build Coastguard Worker for idx, (_, end_under_profiler) in enumerate(thread_spec): 406*da0073e9SAndroid Build Coastguard Worker t = threads[idx] 407*da0073e9SAndroid Build Coastguard Worker if end_under_profiler == context: 408*da0073e9SAndroid Build Coastguard Worker t.join(timeout=timeout) 409*da0073e9SAndroid Build Coastguard Worker 410*da0073e9SAndroid Build Coastguard Worker try: 411*da0073e9SAndroid Build Coastguard Worker add_threads(False) 412*da0073e9SAndroid Build Coastguard Worker with torch.profiler.profile(with_stack=True) as prof: 413*da0073e9SAndroid Build Coastguard Worker # Threads added while the profiler are running will not be observed 414*da0073e9SAndroid Build Coastguard Worker # since there is no way to hook into Python's thread start call to 415*da0073e9SAndroid Build Coastguard Worker # register the observer. These are here purely to verify safety. 416*da0073e9SAndroid Build Coastguard Worker add_threads(True) 417*da0073e9SAndroid Build Coastguard Worker 418*da0073e9SAndroid Build Coastguard Worker if work_in_main_thread: 419*da0073e9SAndroid Build Coastguard Worker Task._run() 420*da0073e9SAndroid Build Coastguard Worker else: 421*da0073e9SAndroid Build Coastguard Worker start_barrier.wait() 422*da0073e9SAndroid Build Coastguard Worker end_barrier.wait() 423*da0073e9SAndroid Build Coastguard Worker 424*da0073e9SAndroid Build Coastguard Worker join_threads(True) 425*da0073e9SAndroid Build Coastguard Worker join_threads(False) 426*da0073e9SAndroid Build Coastguard Worker 427*da0073e9SAndroid Build Coastguard Worker finally: 428*da0073e9SAndroid Build Coastguard Worker # It is very important that we clean up everything because the 429*da0073e9SAndroid Build Coastguard Worker # Python tracer will detect ALL active threads. (Even orphans from 430*da0073e9SAndroid Build Coastguard Worker # prior failed tests.) If we don't clean up properly we can 431*da0073e9SAndroid Build Coastguard Worker # contaminate subsequent tests. 
432*da0073e9SAndroid Build Coastguard Worker start_barrier.abort() 433*da0073e9SAndroid Build Coastguard Worker end_barrier.abort() 434*da0073e9SAndroid Build Coastguard Worker for t in threads.values(): 435*da0073e9SAndroid Build Coastguard Worker t.release() 436*da0073e9SAndroid Build Coastguard Worker 437*da0073e9SAndroid Build Coastguard Worker for t in threads.values(): 438*da0073e9SAndroid Build Coastguard Worker t.join(timeout=timeout) 439*da0073e9SAndroid Build Coastguard Worker 440*da0073e9SAndroid Build Coastguard Worker for t in threads.values(): 441*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t.is_alive()) 442*da0073e9SAndroid Build Coastguard Worker 443*da0073e9SAndroid Build Coastguard Worker roots = prof.profiler.kineto_results.experimental_event_tree() 444*da0073e9SAndroid Build Coastguard Worker nodes = [ 445*da0073e9SAndroid Build Coastguard Worker node 446*da0073e9SAndroid Build Coastguard Worker for node in _utils.traverse_dfs(roots) 447*da0073e9SAndroid Build Coastguard Worker if isinstance(node.extra_fields, _ExtraFields_PyCall) 448*da0073e9SAndroid Build Coastguard Worker ] 449*da0073e9SAndroid Build Coastguard Worker tid_counts = collections.Counter([node.start_tid for node in nodes]) 450*da0073e9SAndroid Build Coastguard Worker 451*da0073e9SAndroid Build Coastguard Worker prior_threads = sum( 452*da0073e9SAndroid Build Coastguard Worker not start_under_profiler for start_under_profiler, _ in thread_spec 453*da0073e9SAndroid Build Coastguard Worker ) 454*da0073e9SAndroid Build Coastguard Worker expected_threads = prior_threads + 1 455*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 456*da0073e9SAndroid Build Coastguard Worker len(tid_counts), expected_threads, f"{expected_threads}, {tid_counts}" 457*da0073e9SAndroid Build Coastguard Worker ) 458*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(nodes), sum(tid_counts.values())) 459*da0073e9SAndroid Build Coastguard Worker 460*da0073e9SAndroid Build 
Coastguard Worker # Profiler uses uint64_t max as a placeholder until TID can be determined. 461*da0073e9SAndroid Build Coastguard Worker no_tid = 2**64 - 1 462*da0073e9SAndroid Build Coastguard Worker self.assertFalse(no_tid in tid_counts) 463*da0073e9SAndroid Build Coastguard Worker 464*da0073e9SAndroid Build Coastguard Worker worker_threads = prior_threads + (1 if work_in_main_thread else 0) 465*da0073e9SAndroid Build Coastguard Worker 466*da0073e9SAndroid Build Coastguard Worker observed_preexisting = [ 467*da0073e9SAndroid Build Coastguard Worker node.start_tid 468*da0073e9SAndroid Build Coastguard Worker for node in nodes 469*da0073e9SAndroid Build Coastguard Worker if "known_preexisting_function" in node.name 470*da0073e9SAndroid Build Coastguard Worker ] 471*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(observed_preexisting), worker_threads) 472*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(observed_preexisting), len(set(observed_preexisting))) 473*da0073e9SAndroid Build Coastguard Worker 474*da0073e9SAndroid Build Coastguard Worker observed_during_run = [ 475*da0073e9SAndroid Build Coastguard Worker node.start_tid for node in nodes if "invoked_during_run" in node.name 476*da0073e9SAndroid Build Coastguard Worker ] 477*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(observed_during_run), worker_threads) 478*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(observed_during_run), len(set(observed_during_run))) 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker def payload(self, use_cuda=False): 481*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 10) 482*da0073e9SAndroid Build Coastguard Worker if use_cuda: 483*da0073e9SAndroid Build Coastguard Worker x = x.cuda() 484*da0073e9SAndroid Build Coastguard Worker y = torch.randn(10, 10) 485*da0073e9SAndroid Build Coastguard Worker if use_cuda: 486*da0073e9SAndroid Build Coastguard Worker y = y.cuda() 
487*da0073e9SAndroid Build Coastguard Worker z = torch.mm(x, y) 488*da0073e9SAndroid Build Coastguard Worker z = z + y 489*da0073e9SAndroid Build Coastguard Worker if use_cuda: 490*da0073e9SAndroid Build Coastguard Worker z = z.cpu() 491*da0073e9SAndroid Build Coastguard Worker 492*da0073e9SAndroid Build Coastguard Worker def _check_stats(self, profiler_stats): 493*da0073e9SAndroid Build Coastguard Worker self.assertGreater(profiler_stats.profiling_window_duration_sec, 0) 494*da0073e9SAndroid Build Coastguard Worker self.assertGreater(profiler_stats.number_of_events, 0) 495*da0073e9SAndroid Build Coastguard Worker self.assertGreater(profiler_stats.profiler_prepare_call_duration_us, 0) 496*da0073e9SAndroid Build Coastguard Worker self.assertGreater(profiler_stats.profiler_enable_call_duration_us, 0) 497*da0073e9SAndroid Build Coastguard Worker self.assertGreater(profiler_stats.profiler_disable_call_duration_us, 0) 498*da0073e9SAndroid Build Coastguard Worker self.assertGreater(profiler_stats.parse_kineto_call_duration_us, 0) 499*da0073e9SAndroid Build Coastguard Worker self.assertGreater( 500*da0073e9SAndroid Build Coastguard Worker profiler_stats.function_events_build_tree_call_duration_us, 0 501*da0073e9SAndroid Build Coastguard Worker ) 502*da0073e9SAndroid Build Coastguard Worker 503*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not kineto_available(), "Kineto is required") 504*da0073e9SAndroid Build Coastguard Worker def test_kineto(self): 505*da0073e9SAndroid Build Coastguard Worker use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities() 506*da0073e9SAndroid Build Coastguard Worker with _profile(use_cuda=use_cuda, use_kineto=True): 507*da0073e9SAndroid Build Coastguard Worker self.payload(use_cuda=use_cuda) 508*da0073e9SAndroid Build Coastguard Worker 509*da0073e9SAndroid Build Coastguard Worker # rerun to avoid initial start overhead 510*da0073e9SAndroid Build Coastguard Worker with _profile(use_cuda=use_cuda, use_kineto=True) as 
p: 511*da0073e9SAndroid Build Coastguard Worker self.payload(use_cuda=use_cuda) 512*da0073e9SAndroid Build Coastguard Worker 513*da0073e9SAndroid Build Coastguard Worker self.assertTrue("aten::mm" in str(p)) 514*da0073e9SAndroid Build Coastguard Worker 515*da0073e9SAndroid Build Coastguard Worker output = p.key_averages().table( 516*da0073e9SAndroid Build Coastguard Worker sort_by="self_cuda_time_total" if use_cuda else "self_cpu_time_total", 517*da0073e9SAndroid Build Coastguard Worker row_limit=-1, 518*da0073e9SAndroid Build Coastguard Worker ) 519*da0073e9SAndroid Build Coastguard Worker # print(output) 520*da0073e9SAndroid Build Coastguard Worker found_gemm = False 521*da0073e9SAndroid Build Coastguard Worker found_memcpy = False 522*da0073e9SAndroid Build Coastguard Worker found_mm = False 523*da0073e9SAndroid Build Coastguard Worker for e in p.function_events: 524*da0073e9SAndroid Build Coastguard Worker if "aten::mm" in e.name: 525*da0073e9SAndroid Build Coastguard Worker found_mm = True 526*da0073e9SAndroid Build Coastguard Worker if "gemm" in e.name.lower() or "Cijk" in e.name: 527*da0073e9SAndroid Build Coastguard Worker found_gemm = True 528*da0073e9SAndroid Build Coastguard Worker if "memcpy" in e.name.lower(): 529*da0073e9SAndroid Build Coastguard Worker found_memcpy = True 530*da0073e9SAndroid Build Coastguard Worker if use_cuda: 531*da0073e9SAndroid Build Coastguard Worker self.assertTrue(found_gemm) 532*da0073e9SAndroid Build Coastguard Worker self.assertTrue(found_memcpy) 533*da0073e9SAndroid Build Coastguard Worker else: 534*da0073e9SAndroid Build Coastguard Worker self.assertTrue(found_mm) 535*da0073e9SAndroid Build Coastguard Worker self._check_stats(p._stats) 536*da0073e9SAndroid Build Coastguard Worker # p.export_chrome_trace("/tmp/test_trace.json") 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not kineto_available(), "Kineto is required") 539*da0073e9SAndroid Build Coastguard Worker 
@unittest.skipIf(not TEST_MULTIGPU, "Multiple GPUs needed") 540*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") 541*da0073e9SAndroid Build Coastguard Worker def test_kineto_multigpu(self): 542*da0073e9SAndroid Build Coastguard Worker with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: 543*da0073e9SAndroid Build Coastguard Worker for gpu_id in [0, 1]: 544*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 10).cuda(gpu_id) 545*da0073e9SAndroid Build Coastguard Worker y = torch.randn(10, 10).cuda(gpu_id) 546*da0073e9SAndroid Build Coastguard Worker z = x.matmul(y) 547*da0073e9SAndroid Build Coastguard Worker 548*da0073e9SAndroid Build Coastguard Worker found_gemm_0 = False 549*da0073e9SAndroid Build Coastguard Worker found_gemm_1 = False 550*da0073e9SAndroid Build Coastguard Worker found_cuda = False 551*da0073e9SAndroid Build Coastguard Worker for evt in prof.events(): 552*da0073e9SAndroid Build Coastguard Worker if "gemm" in evt.name.lower() and evt.device_type == DeviceType.CUDA: 553*da0073e9SAndroid Build Coastguard Worker if evt.device_index == 0: 554*da0073e9SAndroid Build Coastguard Worker found_gemm_0 = True 555*da0073e9SAndroid Build Coastguard Worker elif evt.device_index == 1: 556*da0073e9SAndroid Build Coastguard Worker found_gemm_1 = True 557*da0073e9SAndroid Build Coastguard Worker if "cuda" in evt.name.lower() and evt.device_type == DeviceType.CPU: 558*da0073e9SAndroid Build Coastguard Worker found_cuda = True 559*da0073e9SAndroid Build Coastguard Worker 560*da0073e9SAndroid Build Coastguard Worker self.assertTrue(found_gemm_0) 561*da0073e9SAndroid Build Coastguard Worker self.assertTrue(found_gemm_1) 562*da0073e9SAndroid Build Coastguard Worker self.assertTrue(found_cuda) 563*da0073e9SAndroid Build Coastguard Worker self._check_stats(prof._stats()) 564*da0073e9SAndroid Build Coastguard Worker 565*da0073e9SAndroid Build Coastguard Worker def 
test_memory_profiler(self): 566*da0073e9SAndroid Build Coastguard Worker def run_profiler(tensor_creation_fn): 567*da0073e9SAndroid Build Coastguard Worker # collecting allocs / deallocs 568*da0073e9SAndroid Build Coastguard Worker with _profile( 569*da0073e9SAndroid Build Coastguard Worker profile_memory=True, 570*da0073e9SAndroid Build Coastguard Worker record_shapes=True, 571*da0073e9SAndroid Build Coastguard Worker use_kineto=kineto_available(), 572*da0073e9SAndroid Build Coastguard Worker ) as prof: 573*da0073e9SAndroid Build Coastguard Worker x = None 574*da0073e9SAndroid Build Coastguard Worker with record_function("test_user_scope_alloc"): 575*da0073e9SAndroid Build Coastguard Worker x = tensor_creation_fn() 576*da0073e9SAndroid Build Coastguard Worker with record_function("test_user_scope_dealloc"): 577*da0073e9SAndroid Build Coastguard Worker del x 578*da0073e9SAndroid Build Coastguard Worker return prof.key_averages(group_by_input_shape=True) 579*da0073e9SAndroid Build Coastguard Worker 580*da0073e9SAndroid Build Coastguard Worker def check_metrics(stats, metric, allocs=None, deallocs=None): 581*da0073e9SAndroid Build Coastguard Worker stat_metrics = {} 582*da0073e9SAndroid Build Coastguard Worker # print(stats) 583*da0073e9SAndroid Build Coastguard Worker for stat in stats: 584*da0073e9SAndroid Build Coastguard Worker stat_metrics[stat.key] = getattr(stat, metric) 585*da0073e9SAndroid Build Coastguard Worker # print(stat_metrics) 586*da0073e9SAndroid Build Coastguard Worker if allocs is not None: 587*da0073e9SAndroid Build Coastguard Worker for alloc_fn in allocs: 588*da0073e9SAndroid Build Coastguard Worker self.assertTrue(alloc_fn in stat_metrics) 589*da0073e9SAndroid Build Coastguard Worker self.assertGreater( 590*da0073e9SAndroid Build Coastguard Worker stat_metrics[alloc_fn], 0, f"alloc_fn = {alloc_fn}" 591*da0073e9SAndroid Build Coastguard Worker ) 592*da0073e9SAndroid Build Coastguard Worker if deallocs is not None: 593*da0073e9SAndroid Build 
Coastguard Worker for dealloc_fn in deallocs: 594*da0073e9SAndroid Build Coastguard Worker self.assertTrue(dealloc_fn in stat_metrics) 595*da0073e9SAndroid Build Coastguard Worker self.assertLess( 596*da0073e9SAndroid Build Coastguard Worker stat_metrics[dealloc_fn], 0, f"alloc_fn = {dealloc_fn}" 597*da0073e9SAndroid Build Coastguard Worker ) 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard Worker def create_cpu_tensor(): 600*da0073e9SAndroid Build Coastguard Worker return torch.rand(10, 10) 601*da0073e9SAndroid Build Coastguard Worker 602*da0073e9SAndroid Build Coastguard Worker def create_cuda_tensor(): 603*da0073e9SAndroid Build Coastguard Worker return torch.rand(10, 10).cuda() 604*da0073e9SAndroid Build Coastguard Worker 605*da0073e9SAndroid Build Coastguard Worker def create_mkldnn_tensor(): 606*da0073e9SAndroid Build Coastguard Worker return torch.rand(10, 10, dtype=torch.float32).to_mkldnn() 607*da0073e9SAndroid Build Coastguard Worker 608*da0073e9SAndroid Build Coastguard Worker stats = run_profiler(create_cpu_tensor) 609*da0073e9SAndroid Build Coastguard Worker check_metrics( 610*da0073e9SAndroid Build Coastguard Worker stats, 611*da0073e9SAndroid Build Coastguard Worker "cpu_memory_usage", 612*da0073e9SAndroid Build Coastguard Worker allocs=[ 613*da0073e9SAndroid Build Coastguard Worker "aten::empty", 614*da0073e9SAndroid Build Coastguard Worker "aten::rand", 615*da0073e9SAndroid Build Coastguard Worker "test_user_scope_alloc", 616*da0073e9SAndroid Build Coastguard Worker ], 617*da0073e9SAndroid Build Coastguard Worker deallocs=[ 618*da0073e9SAndroid Build Coastguard Worker "test_user_scope_dealloc", 619*da0073e9SAndroid Build Coastguard Worker ], 620*da0073e9SAndroid Build Coastguard Worker ) 621*da0073e9SAndroid Build Coastguard Worker 622*da0073e9SAndroid Build Coastguard Worker if kineto_available(): 623*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName(mode="w+") as fname: 624*da0073e9SAndroid Build 
Coastguard Worker with profile(profile_memory=True) as prof: 625*da0073e9SAndroid Build Coastguard Worker x = None 626*da0073e9SAndroid Build Coastguard Worker with record_function("test_user_scope_alloc"): 627*da0073e9SAndroid Build Coastguard Worker x = create_cpu_tensor() 628*da0073e9SAndroid Build Coastguard Worker with record_function("test_user_scope_dealloc"): 629*da0073e9SAndroid Build Coastguard Worker del x 630*da0073e9SAndroid Build Coastguard Worker prof.export_chrome_trace(fname) 631*da0073e9SAndroid Build Coastguard Worker with open(fname) as f: 632*da0073e9SAndroid Build Coastguard Worker trace = json.load(f) 633*da0073e9SAndroid Build Coastguard Worker assert "traceEvents" in trace 634*da0073e9SAndroid Build Coastguard Worker events = trace["traceEvents"] 635*da0073e9SAndroid Build Coastguard Worker found_memory_events = False 636*da0073e9SAndroid Build Coastguard Worker for evt in events: 637*da0073e9SAndroid Build Coastguard Worker assert "name" in evt 638*da0073e9SAndroid Build Coastguard Worker if evt["name"] == "[memory]": 639*da0073e9SAndroid Build Coastguard Worker found_memory_events = True 640*da0073e9SAndroid Build Coastguard Worker assert "args" in evt 641*da0073e9SAndroid Build Coastguard Worker assert "Addr" in evt["args"] 642*da0073e9SAndroid Build Coastguard Worker assert "Device Type" in evt["args"] 643*da0073e9SAndroid Build Coastguard Worker assert "Device Id" in evt["args"] 644*da0073e9SAndroid Build Coastguard Worker assert "Bytes" in evt["args"] 645*da0073e9SAndroid Build Coastguard Worker 646*da0073e9SAndroid Build Coastguard Worker # Memory should be an instantaneous event. 
647*da0073e9SAndroid Build Coastguard Worker assert "dur" not in evt["args"] 648*da0073e9SAndroid Build Coastguard Worker assert "cat" not in evt["args"] 649*da0073e9SAndroid Build Coastguard Worker assert found_memory_events 650*da0073e9SAndroid Build Coastguard Worker 651*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 652*da0073e9SAndroid Build Coastguard Worker create_cuda_tensor() 653*da0073e9SAndroid Build Coastguard Worker stats = run_profiler(create_cuda_tensor) 654*da0073e9SAndroid Build Coastguard Worker check_metrics( 655*da0073e9SAndroid Build Coastguard Worker stats, 656*da0073e9SAndroid Build Coastguard Worker "device_memory_usage", 657*da0073e9SAndroid Build Coastguard Worker allocs=[ 658*da0073e9SAndroid Build Coastguard Worker "test_user_scope_alloc", 659*da0073e9SAndroid Build Coastguard Worker "aten::to", 660*da0073e9SAndroid Build Coastguard Worker "aten::empty_strided", 661*da0073e9SAndroid Build Coastguard Worker ], 662*da0073e9SAndroid Build Coastguard Worker deallocs=[ 663*da0073e9SAndroid Build Coastguard Worker "test_user_scope_dealloc", 664*da0073e9SAndroid Build Coastguard Worker ], 665*da0073e9SAndroid Build Coastguard Worker ) 666*da0073e9SAndroid Build Coastguard Worker check_metrics( 667*da0073e9SAndroid Build Coastguard Worker stats, 668*da0073e9SAndroid Build Coastguard Worker "cpu_memory_usage", 669*da0073e9SAndroid Build Coastguard Worker allocs=[ 670*da0073e9SAndroid Build Coastguard Worker "aten::rand", 671*da0073e9SAndroid Build Coastguard Worker "aten::empty", 672*da0073e9SAndroid Build Coastguard Worker ], 673*da0073e9SAndroid Build Coastguard Worker ) 674*da0073e9SAndroid Build Coastguard Worker 675*da0073e9SAndroid Build Coastguard Worker if torch.backends.mkldnn.is_available(): 676*da0073e9SAndroid Build Coastguard Worker create_mkldnn_tensor() 677*da0073e9SAndroid Build Coastguard Worker stats = run_profiler(create_mkldnn_tensor) 678*da0073e9SAndroid Build Coastguard Worker check_metrics( 
679*da0073e9SAndroid Build Coastguard Worker stats, 680*da0073e9SAndroid Build Coastguard Worker "cpu_memory_usage", 681*da0073e9SAndroid Build Coastguard Worker allocs=[ 682*da0073e9SAndroid Build Coastguard Worker "test_user_scope_alloc", 683*da0073e9SAndroid Build Coastguard Worker "aten::rand", 684*da0073e9SAndroid Build Coastguard Worker "aten::empty", 685*da0073e9SAndroid Build Coastguard Worker "aten::to_mkldnn", 686*da0073e9SAndroid Build Coastguard Worker ], 687*da0073e9SAndroid Build Coastguard Worker deallocs=[ 688*da0073e9SAndroid Build Coastguard Worker "test_user_scope_dealloc", 689*da0073e9SAndroid Build Coastguard Worker ], 690*da0073e9SAndroid Build Coastguard Worker ) 691*da0073e9SAndroid Build Coastguard Worker 692*da0073e9SAndroid Build Coastguard Worker # check top-level memory events 693*da0073e9SAndroid Build Coastguard Worker with _profile(profile_memory=True, use_kineto=kineto_available()) as prof: 694*da0073e9SAndroid Build Coastguard Worker x = torch.rand(10, 10) 695*da0073e9SAndroid Build Coastguard Worker del x 696*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 697*da0073e9SAndroid Build Coastguard Worker y = torch.rand(10, 10).cuda() 698*da0073e9SAndroid Build Coastguard Worker del y 699*da0073e9SAndroid Build Coastguard Worker gc.collect() 700*da0073e9SAndroid Build Coastguard Worker stats = prof.key_averages(group_by_input_shape=True) 701*da0073e9SAndroid Build Coastguard Worker check_metrics( 702*da0073e9SAndroid Build Coastguard Worker stats, 703*da0073e9SAndroid Build Coastguard Worker "cpu_memory_usage", 704*da0073e9SAndroid Build Coastguard Worker allocs=["aten::rand", "aten::empty"], 705*da0073e9SAndroid Build Coastguard Worker deallocs=["[memory]"], 706*da0073e9SAndroid Build Coastguard Worker ) 707*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 708*da0073e9SAndroid Build Coastguard Worker check_metrics(stats, "device_memory_usage", deallocs=["[memory]"]) 709*da0073e9SAndroid 
Build Coastguard Worker 710*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 711*da0073e9SAndroid Build Coastguard Worker IS_JETSON, "Jetson has a guard against OOM since host and gpu memory are shared" 712*da0073e9SAndroid Build Coastguard Worker ) 713*da0073e9SAndroid Build Coastguard Worker def test_oom_tracing(self): 714*da0073e9SAndroid Build Coastguard Worker def run_profiler(tensor_creation_fn): 715*da0073e9SAndroid Build Coastguard Worker with _profile(profile_memory=True, record_shapes=True) as prof: 716*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, ".*[tT]ried to allocate.*"): 717*da0073e9SAndroid Build Coastguard Worker x = tensor_creation_fn() 718*da0073e9SAndroid Build Coastguard Worker return prof 719*da0073e9SAndroid Build Coastguard Worker 720*da0073e9SAndroid Build Coastguard Worker def create_cuda_tensor_oom(): 721*da0073e9SAndroid Build Coastguard Worker device = torch.device("cuda:0") 722*da0073e9SAndroid Build Coastguard Worker return torch.empty( 723*da0073e9SAndroid Build Coastguard Worker 1024, 1024, 1024, 1024, dtype=torch.float32, device=device 724*da0073e9SAndroid Build Coastguard Worker ) 725*da0073e9SAndroid Build Coastguard Worker 726*da0073e9SAndroid Build Coastguard Worker def check_trace(fname): 727*da0073e9SAndroid Build Coastguard Worker prof.export_chrome_trace(fname) 728*da0073e9SAndroid Build Coastguard Worker with open(fname) as f: 729*da0073e9SAndroid Build Coastguard Worker trace = json.load(f) 730*da0073e9SAndroid Build Coastguard Worker self.assertTrue("traceEvents" in trace) 731*da0073e9SAndroid Build Coastguard Worker events = trace["traceEvents"] 732*da0073e9SAndroid Build Coastguard Worker found_out_of_memory_events = False 733*da0073e9SAndroid Build Coastguard Worker for evt in events: 734*da0073e9SAndroid Build Coastguard Worker self.assertTrue("name" in evt) 735*da0073e9SAndroid Build Coastguard Worker if evt["name"] == "[OutOfMemory]": 736*da0073e9SAndroid Build 
Coastguard Worker found_out_of_memory_events = True 737*da0073e9SAndroid Build Coastguard Worker self.assertTrue("args" in evt) 738*da0073e9SAndroid Build Coastguard Worker self.assertTrue("Device Type" in evt["args"]) 739*da0073e9SAndroid Build Coastguard Worker self.assertTrue("Device Id" in evt["args"]) 740*da0073e9SAndroid Build Coastguard Worker self.assertTrue("Bytes" in evt["args"]) 741*da0073e9SAndroid Build Coastguard Worker 742*da0073e9SAndroid Build Coastguard Worker # Memory should be an instantaneous event. 743*da0073e9SAndroid Build Coastguard Worker self.assertTrue("dur" not in evt["args"]) 744*da0073e9SAndroid Build Coastguard Worker self.assertTrue("cat" not in evt["args"]) 745*da0073e9SAndroid Build Coastguard Worker self.assertTrue(found_out_of_memory_events) 746*da0073e9SAndroid Build Coastguard Worker 747*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 748*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName(mode="w+") as fname: 749*da0073e9SAndroid Build Coastguard Worker prof = run_profiler(create_cuda_tensor_oom) 750*da0073e9SAndroid Build Coastguard Worker check_trace(fname) 751*da0073e9SAndroid Build Coastguard Worker 752*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not kineto_available(), "Kineto is required") 753*da0073e9SAndroid Build Coastguard Worker def test_module_hierarchy(self): 754*da0073e9SAndroid Build Coastguard Worker class A(nn.Module): 755*da0073e9SAndroid Build Coastguard Worker def my_new_method(self, x): 756*da0073e9SAndroid Build Coastguard Worker return x * 3 757*da0073e9SAndroid Build Coastguard Worker 758*da0073e9SAndroid Build Coastguard Worker def forward_impl_(self, x, y): 759*da0073e9SAndroid Build Coastguard Worker return self.my_new_method(x) + y 760*da0073e9SAndroid Build Coastguard Worker 761*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 762*da0073e9SAndroid Build Coastguard Worker y = y - 2 763*da0073e9SAndroid Build Coastguard Worker return 
self.forward_impl_(x, y) 764*da0073e9SAndroid Build Coastguard Worker 765*da0073e9SAndroid Build Coastguard Worker class B(nn.Module): 766*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 767*da0073e9SAndroid Build Coastguard Worker return x + 2 768*da0073e9SAndroid Build Coastguard Worker 769*da0073e9SAndroid Build Coastguard Worker class C(nn.Module): 770*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 771*da0073e9SAndroid Build Coastguard Worker super().__init__() 772*da0073e9SAndroid Build Coastguard Worker self.A0 = A() 773*da0073e9SAndroid Build Coastguard Worker self.B0 = B() 774*da0073e9SAndroid Build Coastguard Worker 775*da0073e9SAndroid Build Coastguard Worker def call_b(self, x): 776*da0073e9SAndroid Build Coastguard Worker return self.B0.forward(x) 777*da0073e9SAndroid Build Coastguard Worker 778*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 779*da0073e9SAndroid Build Coastguard Worker return self.A0.forward(x, y) + self.call_b(x) 780*da0073e9SAndroid Build Coastguard Worker 781*da0073e9SAndroid Build Coastguard Worker model = C() 782*da0073e9SAndroid Build Coastguard Worker model = torch.jit.script(model) 783*da0073e9SAndroid Build Coastguard Worker input_a = torch.rand(128, 128) 784*da0073e9SAndroid Build Coastguard Worker input_b = torch.rand(128, 128) 785*da0073e9SAndroid Build Coastguard Worker op_to_module_hierarchy = {} 786*da0073e9SAndroid Build Coastguard Worker op_to_module_hierarchy["aten::sub"] = ["TOP(C)::forward.A0(A)::forward."] 787*da0073e9SAndroid Build Coastguard Worker op_to_module_hierarchy["aten::mul"] = [ 788*da0073e9SAndroid Build Coastguard Worker "TOP(C)::forward.A0(A)::forward.SELF(A)::forward_impl_.SELF(A)::my_new_method." 
789*da0073e9SAndroid Build Coastguard Worker ] 790*da0073e9SAndroid Build Coastguard Worker op_to_module_hierarchy["aten::add"] = [ 791*da0073e9SAndroid Build Coastguard Worker "TOP(C)::forward.A0(A)::forward.SELF(A)::forward_impl_.", 792*da0073e9SAndroid Build Coastguard Worker "TOP(C)::forward.SELF(C)::call_b.B0(B)::forward.", 793*da0073e9SAndroid Build Coastguard Worker "TOP(C)::forward.", 794*da0073e9SAndroid Build Coastguard Worker ] 795*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName(mode="w+") as fname: 796*da0073e9SAndroid Build Coastguard Worker with profile( 797*da0073e9SAndroid Build Coastguard Worker activities=[torch.profiler.ProfilerActivity.CPU], 798*da0073e9SAndroid Build Coastguard Worker with_modules=True, 799*da0073e9SAndroid Build Coastguard Worker ) as prof: 800*da0073e9SAndroid Build Coastguard Worker model(input_a, input_b) 801*da0073e9SAndroid Build Coastguard Worker prof.export_chrome_trace(fname) 802*da0073e9SAndroid Build Coastguard Worker with open(fname) as f: 803*da0073e9SAndroid Build Coastguard Worker trace = json.load(f) 804*da0073e9SAndroid Build Coastguard Worker assert "traceEvents" in trace 805*da0073e9SAndroid Build Coastguard Worker events = trace["traceEvents"] 806*da0073e9SAndroid Build Coastguard Worker found_memory_events = False 807*da0073e9SAndroid Build Coastguard Worker for evt in events: 808*da0073e9SAndroid Build Coastguard Worker assert "name" in evt 809*da0073e9SAndroid Build Coastguard Worker if "args" in evt: 810*da0073e9SAndroid Build Coastguard Worker op_name = evt["name"] 811*da0073e9SAndroid Build Coastguard Worker if "Module Hierarchy" in evt["args"]: 812*da0073e9SAndroid Build Coastguard Worker hierarchy = evt["args"]["Module Hierarchy"] 813*da0073e9SAndroid Build Coastguard Worker if op_name in op_to_module_hierarchy: 814*da0073e9SAndroid Build Coastguard Worker assert hierarchy in op_to_module_hierarchy[op_name] 815*da0073e9SAndroid Build Coastguard Worker 816*da0073e9SAndroid Build 
Coastguard Worker def test_high_level_trace(self): 817*da0073e9SAndroid Build Coastguard Worker """Checks that python side high level events are recorded.""" 818*da0073e9SAndroid Build Coastguard Worker 819*da0073e9SAndroid Build Coastguard Worker class RepeatedDataset(torch.utils.data.Dataset): 820*da0073e9SAndroid Build Coastguard Worker def __init__(self, N, D_in, D_out): 821*da0073e9SAndroid Build Coastguard Worker self.N = N 822*da0073e9SAndroid Build Coastguard Worker self.x = torch.randn(N, D_in) 823*da0073e9SAndroid Build Coastguard Worker self.y = torch.randn(N, D_out) 824*da0073e9SAndroid Build Coastguard Worker 825*da0073e9SAndroid Build Coastguard Worker def __len__(self): 826*da0073e9SAndroid Build Coastguard Worker return self.N 827*da0073e9SAndroid Build Coastguard Worker 828*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, idx): 829*da0073e9SAndroid Build Coastguard Worker return self.x, self.y 830*da0073e9SAndroid Build Coastguard Worker 831*da0073e9SAndroid Build Coastguard Worker class TwoLayerNet(torch.nn.Module): 832*da0073e9SAndroid Build Coastguard Worker def __init__(self, D_in, H, D_out): 833*da0073e9SAndroid Build Coastguard Worker super().__init__() 834*da0073e9SAndroid Build Coastguard Worker self.linear1 = torch.nn.Linear(D_in, H) 835*da0073e9SAndroid Build Coastguard Worker self.linear2 = torch.nn.Linear(H, D_out) 836*da0073e9SAndroid Build Coastguard Worker 837*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 838*da0073e9SAndroid Build Coastguard Worker h_relu = self.linear1(x).clamp(min=0) 839*da0073e9SAndroid Build Coastguard Worker y_pred = self.linear2(h_relu) 840*da0073e9SAndroid Build Coastguard Worker return y_pred 841*da0073e9SAndroid Build Coastguard Worker 842*da0073e9SAndroid Build Coastguard Worker class CustomSGD(torch.optim.SGD): 843*da0073e9SAndroid Build Coastguard Worker def __init__(self, *args, **kwargs): 844*da0073e9SAndroid Build Coastguard Worker super().__init__(*args, **kwargs) 
845*da0073e9SAndroid Build Coastguard Worker 846*da0073e9SAndroid Build Coastguard Worker def train(): 847*da0073e9SAndroid Build Coastguard Worker for _, data in enumerate(dataloader): 848*da0073e9SAndroid Build Coastguard Worker x, y = data[0], data[1] 849*da0073e9SAndroid Build Coastguard Worker y_pred = model(x) 850*da0073e9SAndroid Build Coastguard Worker loss = criterion(y_pred, y) 851*da0073e9SAndroid Build Coastguard Worker optimizer.zero_grad() 852*da0073e9SAndroid Build Coastguard Worker loss.backward() 853*da0073e9SAndroid Build Coastguard Worker optimizer.step() 854*da0073e9SAndroid Build Coastguard Worker 855*da0073e9SAndroid Build Coastguard Worker N, D_in, H, D_out = 8, 10, 5, 2 856*da0073e9SAndroid Build Coastguard Worker model = TwoLayerNet(D_in, H, D_out) 857*da0073e9SAndroid Build Coastguard Worker criterion = torch.nn.MSELoss(reduction="sum") 858*da0073e9SAndroid Build Coastguard Worker optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) 859*da0073e9SAndroid Build Coastguard Worker ds = RepeatedDataset(N, D_in, D_out) 860*da0073e9SAndroid Build Coastguard Worker dataloader = torch.utils.data.DataLoader(ds, batch_size=1) 861*da0073e9SAndroid Build Coastguard Worker 862*da0073e9SAndroid Build Coastguard Worker try: 863*da0073e9SAndroid Build Coastguard Worker train() 864*da0073e9SAndroid Build Coastguard Worker except Exception: 865*da0073e9SAndroid Build Coastguard Worker self.assertTrue(False, "Expected no exception without profiling.") 866*da0073e9SAndroid Build Coastguard Worker 867*da0073e9SAndroid Build Coastguard Worker # Create multiple instances, expect each func is hooked only one time. 868*da0073e9SAndroid Build Coastguard Worker # Nested wrappers(repeated patching) will make following test fail. 
869*da0073e9SAndroid Build Coastguard Worker optimizer_duplicate = torch.optim.SGD(model.parameters(), lr=1e-4) 870*da0073e9SAndroid Build Coastguard Worker dataloader_duplicate = torch.utils.data.DataLoader(ds, batch_size=1) 871*da0073e9SAndroid Build Coastguard Worker 872*da0073e9SAndroid Build Coastguard Worker def judge(expected_event_count, prof): 873*da0073e9SAndroid Build Coastguard Worker actual_event_count = {} 874*da0073e9SAndroid Build Coastguard Worker for e in prof.function_events: 875*da0073e9SAndroid Build Coastguard Worker if "#" in e.name: 876*da0073e9SAndroid Build Coastguard Worker key = e.name 877*da0073e9SAndroid Build Coastguard Worker if key in expected_event_count.keys(): 878*da0073e9SAndroid Build Coastguard Worker actual_event_count[key] = ( 879*da0073e9SAndroid Build Coastguard Worker actual_event_count.setdefault(key, 0) + 1 880*da0073e9SAndroid Build Coastguard Worker ) 881*da0073e9SAndroid Build Coastguard Worker for key, count in expected_event_count.items(): 882*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 883*da0073e9SAndroid Build Coastguard Worker (key in actual_event_count.keys()) 884*da0073e9SAndroid Build Coastguard Worker and (count == actual_event_count[key]) 885*da0073e9SAndroid Build Coastguard Worker ) 886*da0073e9SAndroid Build Coastguard Worker 887*da0073e9SAndroid Build Coastguard Worker with _profile(use_kineto=kineto_available()) as prof: 888*da0073e9SAndroid Build Coastguard Worker train() 889*da0073e9SAndroid Build Coastguard Worker expected_event_count = { 890*da0073e9SAndroid Build Coastguard Worker # "+1" because the final iteration will enter __next__ but skip the loop body. 
891*da0073e9SAndroid Build Coastguard Worker "enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__": (N + 1), 892*da0073e9SAndroid Build Coastguard Worker "Optimizer.step#SGD.step": N, 893*da0073e9SAndroid Build Coastguard Worker "Optimizer.zero_grad#SGD.zero_grad": N, 894*da0073e9SAndroid Build Coastguard Worker } 895*da0073e9SAndroid Build Coastguard Worker judge(expected_event_count, prof) 896*da0073e9SAndroid Build Coastguard Worker 897*da0073e9SAndroid Build Coastguard Worker # Test on pickle/unpickle. Expect to work in multi-processing. 898*da0073e9SAndroid Build Coastguard Worker optimizer = pickle.loads(pickle.dumps(optimizer)) 899*da0073e9SAndroid Build Coastguard Worker with _profile(use_kineto=kineto_available()) as prof: 900*da0073e9SAndroid Build Coastguard Worker train() 901*da0073e9SAndroid Build Coastguard Worker judge(expected_event_count, prof) 902*da0073e9SAndroid Build Coastguard Worker 903*da0073e9SAndroid Build Coastguard Worker # Test on customized optimizer. 
904*da0073e9SAndroid Build Coastguard Worker optimizer = CustomSGD(model.parameters(), lr=1e-4) 905*da0073e9SAndroid Build Coastguard Worker with _profile(use_kineto=kineto_available()) as prof: 906*da0073e9SAndroid Build Coastguard Worker train() 907*da0073e9SAndroid Build Coastguard Worker expected_event_count = { 908*da0073e9SAndroid Build Coastguard Worker "enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__": (N + 1), 909*da0073e9SAndroid Build Coastguard Worker "Optimizer.step#CustomSGD.step": N, 910*da0073e9SAndroid Build Coastguard Worker "Optimizer.zero_grad#CustomSGD.zero_grad": N, 911*da0073e9SAndroid Build Coastguard Worker } 912*da0073e9SAndroid Build Coastguard Worker judge(expected_event_count, prof) 913*da0073e9SAndroid Build Coastguard Worker 914*da0073e9SAndroid Build Coastguard Worker def test_flops(self): 915*da0073e9SAndroid Build Coastguard Worker model = torch.nn.Sequential( 916*da0073e9SAndroid Build Coastguard Worker nn.Conv2d(16, 33, 18), 917*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 918*da0073e9SAndroid Build Coastguard Worker nn.Linear(243, 243), 919*da0073e9SAndroid Build Coastguard Worker nn.ReLU(), 920*da0073e9SAndroid Build Coastguard Worker ) 921*da0073e9SAndroid Build Coastguard Worker inputs = torch.randn(40, 16, 18, 260) 922*da0073e9SAndroid Build Coastguard Worker nested_tensor = torch.nested.nested_tensor( 923*da0073e9SAndroid Build Coastguard Worker [torch.randn((2, 5)), torch.randn((3, 5))], layout=torch.jagged 924*da0073e9SAndroid Build Coastguard Worker ) 925*da0073e9SAndroid Build Coastguard Worker with _profile( 926*da0073e9SAndroid Build Coastguard Worker record_shapes=True, with_flops=True, use_kineto=kineto_available() 927*da0073e9SAndroid Build Coastguard Worker ) as prof: 928*da0073e9SAndroid Build Coastguard Worker model(inputs) 929*da0073e9SAndroid Build Coastguard Worker # test that nested tensor won't cause exception during flop compute 930*da0073e9SAndroid Build Coastguard Worker nested_tensor 
= nested_tensor + nested_tensor 931*da0073e9SAndroid Build Coastguard Worker profiler_output = prof.key_averages(group_by_input_shape=True).table( 932*da0073e9SAndroid Build Coastguard Worker sort_by="cpu_time_total", row_limit=10 933*da0073e9SAndroid Build Coastguard Worker ) 934*da0073e9SAndroid Build Coastguard Worker self.assertIn("Total MFLOPs", profiler_output) 935*da0073e9SAndroid Build Coastguard Worker if not (kineto_available() and torch.cuda.is_available()): 936*da0073e9SAndroid Build Coastguard Worker return 937*da0073e9SAndroid Build Coastguard Worker 938*da0073e9SAndroid Build Coastguard Worker with profile( 939*da0073e9SAndroid Build Coastguard Worker activities=[ 940*da0073e9SAndroid Build Coastguard Worker torch.profiler.ProfilerActivity.CPU, 941*da0073e9SAndroid Build Coastguard Worker torch.profiler.ProfilerActivity.CUDA, 942*da0073e9SAndroid Build Coastguard Worker ], 943*da0073e9SAndroid Build Coastguard Worker record_shapes=True, 944*da0073e9SAndroid Build Coastguard Worker with_flops=True, 945*da0073e9SAndroid Build Coastguard Worker ) as kineto_profiler: 946*da0073e9SAndroid Build Coastguard Worker model(inputs) 947*da0073e9SAndroid Build Coastguard Worker profiler_output = kineto_profiler.key_averages().table( 948*da0073e9SAndroid Build Coastguard Worker sort_by="self_cuda_time_total", row_limit=-1 949*da0073e9SAndroid Build Coastguard Worker ) 950*da0073e9SAndroid Build Coastguard Worker self.assertIn("Total MFLOPs", profiler_output) 951*da0073e9SAndroid Build Coastguard Worker 952*da0073e9SAndroid Build Coastguard Worker def test_kineto_profiler_api(self): 953*da0073e9SAndroid Build Coastguard Worker called_num = [0] 954*da0073e9SAndroid Build Coastguard Worker 955*da0073e9SAndroid Build Coastguard Worker use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities() 956*da0073e9SAndroid Build Coastguard Worker with profile(activities=supported_activities()): 957*da0073e9SAndroid Build Coastguard Worker 
            self.payload(use_cuda=use_cuda)

        def trace_handler(p):
            # Invoked once per completed RECORD_AND_SAVE cycle; we only count calls.
            output = p.key_averages().table(
                sort_by="self_cuda_time_total" if use_cuda else "self_cpu_time_total",
                row_limit=-1,
            )
            # print(output)
            # p.export_chrome_trace("/tmp/test_trace_" + str(called_num[0]) + ".json")
            called_num[0] += 1

        initial_step = KinetoStepTracker.current_step()

        # wait=1, warmup=1, active=2 over 8 steps -> two full cycles, so the
        # trace handler should fire exactly twice.
        with profile(
            activities=supported_activities(),
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
            on_trace_ready=trace_handler,
        ) as p:
            for idx in range(8):
                self.payload(use_cuda=use_cuda)
                p.step()

        self.assertEqual(called_num[0], 2)
        # Every p.step() should advance the global step tracker by one.
        self.assertEqual(KinetoStepTracker.current_step(), initial_step + 8)

        # case without schedule
        with profile(activities=supported_activities()) as p:
            self.payload(use_cuda=use_cuda)
            self.payload(use_cuda=use_cuda)
        output = p.key_averages().table(
            sort_by="self_cuda_time_total" if use_cuda else "self_cpu_time_total",
            row_limit=-1,
        )
        # print(output)

        # Verify the raw schedule function step-by-step against the expected
        # action sequence (skip_first=2, then two wait/warmup/active cycles).
        test_schedule = torch.profiler.schedule(
            skip_first=2, wait=1, warmup=1, active=2, repeat=2
        )
        test_schedule_expected_outputs = [
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            ProfilerAction.WARMUP,
            ProfilerAction.RECORD,
            ProfilerAction.RECORD_AND_SAVE,
            ProfilerAction.NONE,
            ProfilerAction.WARMUP,
            ProfilerAction.RECORD,
            ProfilerAction.RECORD_AND_SAVE,
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            ProfilerAction.NONE,
        ]
        for step in range(len(test_schedule_expected_outputs)):
            self.assertEqual(test_schedule(step), test_schedule_expected_outputs[step])

    def test_kineto_profiler_multiple_steppers(self):
        """Check that KinetoStepTracker counts steps from multiple requesters
        (the profiler's own p.step() plus a simulated optimizer hook) without
        double-counting."""
        niters = 8
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        net = SimpleNet()
        opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
        opt.zero_grad()
        inputs = torch.rand(10)

        # Warm-up profile run.
        with profile(activities=supported_activities()):
            self.payload(use_cuda=use_cuda)

        def optimizer_step():
            """This simulates a step() hook in the optimizer"""
            KinetoStepTracker.increment_step("yet_another_step")

        initial_step = KinetoStepTracker.current_step()

        def run_batch():
            out = net(inputs)
            loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
            loss.backward()
            opt.step()
            # Manually call the hook. TODO: Remove this once we add the
            # profiler step hooks in the Optimizer class that will get triggered above.
            # See https://github.com/pytorch/pytorch/issues/88446
            optimizer_step()

        # Batches outside the profiler still increment the tracker via the hook.
        for idx in range(niters):
            run_batch()

        with profile(
            activities=supported_activities(),
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
        ) as p:
            for idx in range(niters):
                run_batch()
                p.step()

        # niters hook increments outside the profiler + niters inside it.
        self.assertEqual(KinetoStepTracker.current_step(), initial_step + 2 * niters)

    def test_export_stacks(self):
        """Profile with stack recording enabled and verify export_stacks()
        writes a non-empty file whose every line ends in a positive integer
        (the collapsed-stack sample weight)."""
        with _profile(
            with_stack=True,
            use_kineto=kineto_available(),
            experimental_config=_ExperimentalConfig(verbose=True),
        ) as p:
            x = torch.randn(10, 10)
            y = torch.randn(10, 10)
            z = torch.mm(x, y)
            z = z + y

        with TemporaryFileName(mode="w+") as fname:
            p.export_stacks(fname)
            with open(fname) as f:
                lines = f.readlines()
            assert len(lines) > 0, "Empty stacks file"
            for line in lines:
                is_int = False
                try:
                    # Last whitespace-separated token must parse as a positive int.
                    assert int(line.split(" ")[-1]) > 0, "Invalid stacks record"
                    is_int = True
                except ValueError:
                    pass
                assert is_int, "Invalid stacks record"

    @unittest.skipIf(not kineto_available(), "Kineto is required")
    def test_tensorboard_trace_handler(self):
        """Check that tensorboard_trace_handler writes one correctly-named
        trace file per RECORD_AND_SAVE cycle, in both plain and gzip formats."""
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        with _profile(use_cuda=use_cuda, use_kineto=True):
            self.payload(use_cuda=use_cuda)

        with TemporaryDirectoryName() as dname:
            with profile(
                activities=[torch.profiler.ProfilerActivity.CPU]
                + ([torch.profiler.ProfilerActivity.CUDA] if use_cuda else []),
                schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=3),
                on_trace_ready=torch.profiler.tensorboard_trace_handler(dname),
            ) as p:
                # 18 steps with wait=1/warmup=1/active=2, repeat=3 -> 3 saved traces.
                for _ in range(18):
                    self.payload(use_cuda=use_cuda)
                    p.step()

            self.assertTrue(os.path.exists(dname))
            file_num = 0
            for file_name in os.listdir(dname):
                # Expected name: <worker>.<step>.pt.trace.json
                parts = file_name.split(".")
                self.assertTrue(len(parts) > 4)
                self.assertTrue(
                    parts[-4].isdigit() and int(parts[-4]) > 0,
                    "Wrong tracing file name pattern",
                )
                self.assertEqual(parts[-3:], ["pt", "trace", "json"])
                file_num += 1
            self.assertEqual(file_num, 3)

        # test case for gzip file format
        with TemporaryDirectoryName() as dname:
            p = profile(
                activities=[torch.profiler.ProfilerActivity.CPU]
                + ([torch.profiler.ProfilerActivity.CUDA] if use_cuda else []),
                schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=3),
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    dname, use_gzip=True
                ),
            )
            # Drive the profiler manually (start/step/stop) instead of as a
            # context manager, to cover that API path too.
            p.start()
            for _ in range(18):
                self.payload(use_cuda=use_cuda)
                p.step()
            p.stop()

            self.assertTrue(os.path.exists(dname))
            file_num = 0
            for file_name in os.listdir(dname):
                # Expected name: <worker>.<step>.pt.trace.json.gz
                parts = file_name.split(".")
                self.assertTrue(len(parts) > 4)
                self.assertTrue(
                    parts[-5].isdigit() and int(parts[-5]) > 0,
                    "Wrong tracing file name pattern",
                )
                self.assertEqual(parts[-4:], ["pt", "trace", "json", "gz"])
                file_num += 1
            self.assertEqual(file_num, 3)

    @unittest.skipIf(not kineto_available(), "Kineto is required")
    def test_profiler_metadata(self):
        """Verify add_metadata / add_metadata_json entries appear as top-level
        keys in the exported chrome trace JSON."""
        t1, t2 = torch.ones(1), torch.ones(1)
        with profile() as prof:
            torch.add(t1, t2)
            prof.add_metadata("test_key1", "test_value1")
            prof.add_metadata_json("test_key2", "[1,2,3]")

        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                trace = json.load(f)
                assert "test_key1" in trace
                assert trace["test_key1"] == "test_value1"
                assert "test_key2" in trace
                # add_metadata_json values are parsed as JSON, not kept as strings.
                assert trace["test_key2"] == [1, 2, 3]

    def _test_profiler_tracing(self, use_kineto):
        """Shared helper: export a chrome trace (normal, empty, and CUDA runs)
        and verify the output is valid JSON."""
        with _profile(use_kineto=use_kineto) as prof:
            t1, t2 = torch.ones(1), torch.ones(1)
            torch.add(t1, t2)

        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            # read the trace and expect valid json
            # if the JSON generated by export_chrome_trace is not valid, this will throw and fail the test.
            with open(fname) as f:
                json.load(f)

        # test empty trace
        with _profile(use_kineto=use_kineto) as prof:
            pass
        # saving an empty trace
        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            if use_kineto:
                with open(fname) as f:
                    contents = json.load(f)
                    # Some builds may not have logger observer
                    # so skip if not
                    if "WARNING" in contents:
                        found_empty_warning = False
                        for warning in contents["WARNING"]:
                            if "No Valid Trace Events" in warning:
                                found_empty_warning = True
                        self.assertTrue(found_empty_warning)

        # Same test but for cuda.
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        if not use_cuda:
            return

        device = torch.device("cuda:0")
        with _profile(use_cuda=True, use_kineto=use_kineto) as prof:
            t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device)
            torch.add(t1, t2)

        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            # Now validate the json
            with open(fname) as f:
                json.load(f)

    def test_profiler_tracing(self):
        """Run the tracing helper for the legacy backend, and for kineto when
        it is available."""
        self._test_profiler_tracing(False)
        if kineto_available():
            self._test_profiler_tracing(True)

    def test_profiler_op_event_args(self):
        """Check per-op metadata in the exported trace: input types, concrete
        inputs, input dims, and that every cpu_op carries a record function id."""
        torch._C._profiler._set_record_concrete_inputs_enabled_val(True)
        with _profile(record_shapes=True) as prof:
            a = torch.ones((64, 32), dtype=torch.float32)
            c = torch.cat([a, a]).sin()
        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                j = json.load(f)
                op_events = [
                    e for e in j["traceEvents"] if e.get("cat", "") == "cpu_op"
                ]
                for e in op_events:
                    args = e["args"]
                    if e["name"] == "aten::ones":
                        self.assertEqual(
                            args["Input type"],
                            ["ScalarList", "Scalar", "", "", "Scalar"],
                        )
                        # "6" is the recorded scalar value of the dtype argument
                        # (torch.float32); False is pin_memory.
                        self.assertEqual(
                            args["Concrete Inputs"], ["[64, 32]", "6", "", "", "False"]
                        )

                    if e["name"] == "aten::cat":
                        self.assertEqual(args["Input Dims"], [[[64, 32], [64, 32]], []])
                        self.assertEqual(args["Input type"], ["TensorList", "Scalar"])

                    # check that each op has record function id
                    # NOTE(review): "funciont" below is a typo for "function" in
                    # this assertion message (runtime string, left unchanged here).
                    self.assertGreaterEqual(
                        args.get("Record function id", -1),
                        0,
                        f"Failed finding record funciont for op = {e}",
                    )
    def test_profiler_strides(self):
        """Verify that record_shapes captures the input strides of overlapping
        as_strided views in the exported trace."""
        torch._C._profiler._set_record_concrete_inputs_enabled_val(True)
        base_tensor = torch.randn(1024, dtype=torch.float32)
        # Two non-contiguous views over the same storage with distinct strides.
        a = base_tensor.as_strided((16, 16), (17, 1), 0)
        b = base_tensor.as_strided((16, 16), (25, 2), 272)
        with _profile(record_shapes=True) as prof:
            c = torch.add(a, b)

        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                j = json.load(f)
                op_events = [
                    e for e in j["traceEvents"] if e.get("cat", "") == "cpu_op"
                ]
                for e in op_events:
                    args = e["args"]
                    if e["name"] == "aten::add":
                        # Third entry is the (strideless) alpha scalar.
                        self.assertEqual(args["Input Strides"], [[17, 1], [25, 2], []])

    def test_profiler_fwd_bwd_link(self):
        """Run a forward+backward pass and verify the chrome trace contains
        matched fwdbwd flow-event pairs (ph "s" start / ph "f" finish) anchored
        on the forward ops."""
        with _profile(use_kineto=True) as prof:
            t1, t2 = torch.ones(1, requires_grad=True), torch.ones(
                1, requires_grad=True
            )
            z = torch.add(t1, t2)
            y = torch.ones(1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
            loss.backward()
        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                j = json.load(f)
                events = j["traceEvents"]
                # Map complete ("X") events by timestamp, and flow start/finish
                # timestamps by flow id.
                ts_to_name = {}
                flow_s_to_ts = {}
                flow_f_to_ts = {}
                for e in events:
                    if e["ph"] == "X":
                        ts_to_name[e["ts"]] = e["name"]
                    if (
                        "cat" in e
                        and "name" in e
                        and e["cat"] == "fwdbwd"
                        and e["name"] == "fwdbwd"
                    ):
                        if e["ph"] == "s":
                            flow_s_to_ts[e["id"]] = e["ts"]
                        elif e["ph"] == "f":
                            flow_f_to_ts[e["id"]] = e["ts"]

                # Two fwd/bwd links expected: one per differentiable forward op.
                self.assertEqual(len(flow_s_to_ts), 2)
                self.assertEqual(len(flow_f_to_ts), 2)
                self.assertIn(1, flow_s_to_ts)
                self.assertIn(1, flow_f_to_ts)
                self.assertIn(2, flow_s_to_ts)
                self.assertIn(2, flow_f_to_ts)
                s_ts_1 = flow_s_to_ts[1]
                f_ts_1 = flow_f_to_ts[1]
                s_ts_2 = flow_s_to_ts[2]
                f_ts_2 = flow_f_to_ts[2]
                # Every flow endpoint must coincide with a complete ("X") event.
                self.assertTrue(
                    all(
                        ts in ts_to_name.keys()
                        for ts in [s_ts_1, f_ts_1, s_ts_2, f_ts_2]
                    )
                )
                # Flow starts are anchored on the forward ops themselves.
                self.assertTrue(
                    ts_to_name[s_ts_1] == "aten::binary_cross_entropy_with_logits"
                )
                self.assertTrue(ts_to_name[s_ts_2] == "aten::add")

    def test_profiler_disable_fwd_bwd_link(self):
        """With fwd/bwd linking disabled via the private toggle, the exported
        trace must contain no fwdbwd-category events."""
        try:
            torch._C._profiler._set_fwd_bwd_enabled_val(False)

            with _profile(use_kineto=True) as prof:
                t1, t2 = torch.ones(1, requires_grad=True), torch.ones(
                    1, requires_grad=True
                )
                z = torch.add(t1, t2)
                y = torch.ones(1)
                loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
                loss.backward()

            with TemporaryFileName(mode="w+") as fname:
                prof.export_chrome_trace(fname)
                with open(fname) as f:
                    j = json.load(f)
                    events = j["traceEvents"]

                    for e in events:
                        self.assertNotEqual(e.get("cat", None), "fwdbwd")
        finally:
            # Restore the global toggle even if assertions above fail.
            torch._C._profiler._set_fwd_bwd_enabled_val(True)

    # This test is broken on Windows, the likely reason is that kineto/CUPTI
    # is not supported in that particular environment. Once the CI stabilizes
    # we can narrow the condition so Windows is checked as well (TODO)
    @unittest.skipIf(not kineto_available(), "Kineto is required")
    @unittest.skipIf(IS_WINDOWS, "Test does not work on Windows")
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_profiler_cuda_sync_events(self):
        """Verify cuda_sync events appear in the trace when enabled, both via
        _ExperimentalConfig and via the private _set_cuda_sync_enabled_val toggle."""
        device = torch.device("cuda:0")
        t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device)

        def workload() -> None:
            # add / explicit synchronize / add — the synchronize should be traced.
            torch.add(t1, t2)
            torch.cuda.synchronize()
            torch.add(t1, t2)

        def trace_and_check(exp_config: Optional[_ExperimentalConfig]) -> None:
            with _profile(
                use_kineto=True,
                use_cuda=True,
                experimental_config=exp_config,
            ) as prof:
                workload()

            with TemporaryFileName(mode="w+") as fname:
                # fname = "/tmp/kineto_out.json"
                prof.export_chrome_trace(fname)
                with open(fname) as f:
                    j = json.load(f)
                    cats = {e.get("cat", None) for e in j["traceEvents"]}
                    self.assertTrue(
                        "cuda_sync" in cats,
                        "Expected to find cuda_sync event" f" found = {cats}",
                    )

        print("Testing enable_cuda_sync_events in _ExperimentalConfig")
        trace_and_check(exp_config=_ExperimentalConfig(enable_cuda_sync_events=True))

        print("Testing _profiler._set_cuda_sync_enabled_val()")
        try:
            torch._C._profiler._set_cuda_sync_enabled_val(True)
            trace_and_check(exp_config=None)
        finally:
            # Restore the global toggle regardless of test outcome.
            torch._C._profiler._set_cuda_sync_enabled_val(False)

    def test_profiler_type(self):
        """Check that _profiler_type() reports NONE outside any profiler and
        the correct backend (LEGACY / KINETO) inside each profiler context."""
        profiler_type = torch._C._autograd._profiler_type
        ActiveProfilerType = torch._C._profiler.ActiveProfilerType
        self.assertEqual(profiler_type(), ActiveProfilerType.NONE)

        # Autograd profiler
        with _profile_legacy():
            self.assertEqual(profiler_type(), ActiveProfilerType.LEGACY)

        # Kineto profiler
        with profile():
            self.assertEqual(profiler_type(), ActiveProfilerType.KINETO)

    def test_profiler_correlation_id(self):
        """
        We expect the correlation_id to be unique across multiple invocations of the profiler,
        so we reuse id_uniqueness_set across iterations.
        """
        id_uniqueness_set = set()
        model = torch.nn.Sequential(
            nn.Conv2d(16, 33, 18),
            nn.ReLU(),
            nn.Linear(243, 243),
            nn.ReLU(),
        )
        inputs = torch.randn(40, 16, 18, 260)
        uint32_max = 2**32 - 1
        for i in range(5):
            with profile() as prof:
                model(inputs)
            for event in prof.profiler.kineto_results.events():
                corr_id = event.correlation_id()
                # Only CPU events with a non-zero correlation id are checked.
                if (corr_id) and event.device_type() == DeviceType.CPU:
                    self.assertTrue(corr_id not in id_uniqueness_set)
                    id_uniqueness_set.add(corr_id)
                    # Correlation ids must fit in 32 bits.
                    self.assertTrue(corr_id < uint32_max)

    def test_nested_tensor_with_shapes(self):
        """Profile a linear op over a nested tensor with record_shapes=True and
        check that shape recording produces non-empty input shapes."""
        a = torch.randn(4, 4)
        b = torch.randn(4, 4)
        c = torch.randn(4, 4)
        inp = torch.nested.nested_tensor([a, b])
        with torch.profiler.profile(record_shapes=True) as prof:
            torch.nn.functional.linear(inp, c, None)
        for e in prof.events():
            if e.name in ("aten::mm", "aten::addmm"):
                # intentionally vague tests to protect against possible future changes
                # of mm to addmm or other impl, or changing internal order of args
                self.assertTrue(len(e.input_shapes) > 0)
                self.assertTrue(len(e.input_shapes[0]) > 0)

    @patch.dict(os.environ, {"KINETO_USE_DAEMON": "1"})
    @patch.dict(os.environ, {"KINETO_DAEMON_INIT_DELAY_S": "1"})
    def test_kineto_profiler_with_environment_variable(self):
        script = """
import torch
import torch.nn as nn
from torch.profiler import supported_activities, profile
from torch.autograd.profiler import KinetoStepTracker

Build Coastguard Workerclass SimpleNet(nn.Module): 1441*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1442*da0073e9SAndroid Build Coastguard Worker super().__init__() 1443*da0073e9SAndroid Build Coastguard Worker self.fc1 = nn.Linear(10, 5) 1444*da0073e9SAndroid Build Coastguard Worker self.fc2 = nn.Linear(5, 2) 1445*da0073e9SAndroid Build Coastguard Worker 1446*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1447*da0073e9SAndroid Build Coastguard Worker return self.fc2(self.fc1(x)) 1448*da0073e9SAndroid Build Coastguard Worker 1449*da0073e9SAndroid Build Coastguard Worker 1450*da0073e9SAndroid Build Coastguard Workerdef payload(use_cuda=False): 1451*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 10) 1452*da0073e9SAndroid Build Coastguard Worker if use_cuda: 1453*da0073e9SAndroid Build Coastguard Worker x = x.cuda() 1454*da0073e9SAndroid Build Coastguard Worker y = torch.randn(10, 10) 1455*da0073e9SAndroid Build Coastguard Worker if use_cuda: 1456*da0073e9SAndroid Build Coastguard Worker y = y.cuda() 1457*da0073e9SAndroid Build Coastguard Worker z = torch.mm(x, y) 1458*da0073e9SAndroid Build Coastguard Worker z = z + y 1459*da0073e9SAndroid Build Coastguard Worker if use_cuda: 1460*da0073e9SAndroid Build Coastguard Worker z = z.cpu() 1461*da0073e9SAndroid Build Coastguard Worker 1462*da0073e9SAndroid Build Coastguard Workerniters = 8 1463*da0073e9SAndroid Build Coastguard Workeruse_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities() 1464*da0073e9SAndroid Build Coastguard Workernet = SimpleNet() 1465*da0073e9SAndroid Build Coastguard Workeropt = torch.optim.SGD(net.parameters(), lr=0.01) 1466*da0073e9SAndroid Build Coastguard Workeropt.zero_grad() 1467*da0073e9SAndroid Build Coastguard Workerinputs = torch.rand(10) 1468*da0073e9SAndroid Build Coastguard Worker 1469*da0073e9SAndroid Build Coastguard Workerwith profile(activities=supported_activities()): 1470*da0073e9SAndroid Build Coastguard Worker 
payload(use_cuda=use_cuda) 1471*da0073e9SAndroid Build Coastguard Worker 1472*da0073e9SAndroid Build Coastguard Workerinitial_step = KinetoStepTracker.current_step() 1473*da0073e9SAndroid Build Coastguard Worker 1474*da0073e9SAndroid Build Coastguard Workerdef run_batch(): 1475*da0073e9SAndroid Build Coastguard Worker out = net(inputs) 1476*da0073e9SAndroid Build Coastguard Worker loss = torch.nn.functional.cross_entropy(out, torch.rand(2)) 1477*da0073e9SAndroid Build Coastguard Worker loss.backward() 1478*da0073e9SAndroid Build Coastguard Worker opt.step() 1479*da0073e9SAndroid Build Coastguard Worker 1480*da0073e9SAndroid Build Coastguard Workerfor _ in range(niters): 1481*da0073e9SAndroid Build Coastguard Worker run_batch() 1482*da0073e9SAndroid Build Coastguard Worker 1483*da0073e9SAndroid Build Coastguard Workerwith profile( 1484*da0073e9SAndroid Build Coastguard Worker activities=supported_activities(), 1485*da0073e9SAndroid Build Coastguard Worker schedule=torch.profiler.schedule( 1486*da0073e9SAndroid Build Coastguard Worker wait=1, 1487*da0073e9SAndroid Build Coastguard Worker warmup=1, 1488*da0073e9SAndroid Build Coastguard Worker active=2), 1489*da0073e9SAndroid Build Coastguard Worker) as p: 1490*da0073e9SAndroid Build Coastguard Worker for _ in range(niters): 1491*da0073e9SAndroid Build Coastguard Worker run_batch() 1492*da0073e9SAndroid Build Coastguard Worker p.step() 1493*da0073e9SAndroid Build Coastguard Workerassert KinetoStepTracker.current_step() == initial_step + 2 * niters 1494*da0073e9SAndroid Build Coastguard Worker""" 1495*da0073e9SAndroid Build Coastguard Worker try: 1496*da0073e9SAndroid Build Coastguard Worker subprocess.check_output( 1497*da0073e9SAndroid Build Coastguard Worker [sys.executable, "-W", "always", "-c", script], 1498*da0073e9SAndroid Build Coastguard Worker cwd=os.path.dirname(os.path.realpath(__file__)), 1499*da0073e9SAndroid Build Coastguard Worker ) 1500*da0073e9SAndroid Build Coastguard Worker except 
subprocess.CalledProcessError as e: 1501*da0073e9SAndroid Build Coastguard Worker if e.returncode != 0: 1502*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 1503*da0073e9SAndroid Build Coastguard Worker False, 1504*da0073e9SAndroid Build Coastguard Worker "Kineto is not working properly with the Dynolog environment variable", 1505*da0073e9SAndroid Build Coastguard Worker ) 1506*da0073e9SAndroid Build Coastguard Worker 1507*da0073e9SAndroid Build Coastguard Worker def test_concrete_inputs_profiling(self): 1508*da0073e9SAndroid Build Coastguard Worker x = torch.rand(2, 6) 1509*da0073e9SAndroid Build Coastguard Worker with profile(record_shapes=True) as p: 1510*da0073e9SAndroid Build Coastguard Worker y = x.as_strided([4, 3], [1, 4]) 1511*da0073e9SAndroid Build Coastguard Worker 1512*da0073e9SAndroid Build Coastguard Worker found = False 1513*da0073e9SAndroid Build Coastguard Worker for e in p.events(): 1514*da0073e9SAndroid Build Coastguard Worker if e.name in ("aten::as_strided"): 1515*da0073e9SAndroid Build Coastguard Worker found = True 1516*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(e.input_shapes) > 0) 1517*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(e.concrete_inputs) > 0) 1518*da0073e9SAndroid Build Coastguard Worker self.assertEqual([2, 6], e.input_shapes[0]) 1519*da0073e9SAndroid Build Coastguard Worker self.assertEqual([4, 3], e.concrete_inputs[1]) 1520*da0073e9SAndroid Build Coastguard Worker self.assertEqual([1, 4], e.concrete_inputs[2]) 1521*da0073e9SAndroid Build Coastguard Worker 1522*da0073e9SAndroid Build Coastguard Worker self.assertTrue(found, "Expected to find aten::as_strided but did not") 1523*da0073e9SAndroid Build Coastguard Worker 1524*da0073e9SAndroid Build Coastguard Worker def test_concrete_inputs_profiling_toggling(self): 1525*da0073e9SAndroid Build Coastguard Worker try: 1526*da0073e9SAndroid Build Coastguard Worker for before, after in [(True, False), (False, True)]: 1527*da0073e9SAndroid 
Build Coastguard Worker x = torch.rand(2, 6) 1528*da0073e9SAndroid Build Coastguard Worker torch._C._profiler._set_record_concrete_inputs_enabled_val(before) 1529*da0073e9SAndroid Build Coastguard Worker with profile(record_shapes=True) as p: 1530*da0073e9SAndroid Build Coastguard Worker y = x.as_strided([4, 3], [1, 4]) 1531*da0073e9SAndroid Build Coastguard Worker torch._C._profiler._set_record_concrete_inputs_enabled_val(after) 1532*da0073e9SAndroid Build Coastguard Worker 1533*da0073e9SAndroid Build Coastguard Worker found = False 1534*da0073e9SAndroid Build Coastguard Worker for e in p.events(): 1535*da0073e9SAndroid Build Coastguard Worker if e.name in ("aten::as_strided"): 1536*da0073e9SAndroid Build Coastguard Worker found = True 1537*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(e.input_shapes)) 1538*da0073e9SAndroid Build Coastguard Worker 1539*da0073e9SAndroid Build Coastguard Worker self.assertTrue(found, "Expected to find aten::as_strided but did not") 1540*da0073e9SAndroid Build Coastguard Worker finally: 1541*da0073e9SAndroid Build Coastguard Worker torch._C._profiler._set_record_concrete_inputs_enabled_val(True) 1542*da0073e9SAndroid Build Coastguard Worker 1543*da0073e9SAndroid Build Coastguard Worker def test_record_function_fast(self): 1544*da0073e9SAndroid Build Coastguard Worker x, y = (torch.rand((4, 4)) for _ in range(2)) 1545*da0073e9SAndroid Build Coastguard Worker with profile(record_shapes=True) as p: 1546*da0073e9SAndroid Build Coastguard Worker for _ in range(4): 1547*da0073e9SAndroid Build Coastguard Worker # Test first with no optional args 1548*da0073e9SAndroid Build Coastguard Worker with torch._C._profiler._RecordFunctionFast("add_test_fast_rf1"): 1549*da0073e9SAndroid Build Coastguard Worker x.add(y) 1550*da0073e9SAndroid Build Coastguard Worker 1551*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual( 1552*da0073e9SAndroid Build Coastguard Worker len([e for e in p.events() if e.name == 
"add_test_fast_rf1"]), 4 1553*da0073e9SAndroid Build Coastguard Worker ) 1554*da0073e9SAndroid Build Coastguard Worker for e in p.events(): 1555*da0073e9SAndroid Build Coastguard Worker if e.name == "add_test_fast_rf1": 1556*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e.input_shapes == []) 1557*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e.kwinputs == {}) 1558*da0073e9SAndroid Build Coastguard Worker with profile(record_shapes=True) as p: 1559*da0073e9SAndroid Build Coastguard Worker # add optional args 1560*da0073e9SAndroid Build Coastguard Worker cm = torch._C._profiler._RecordFunctionFast( 1561*da0073e9SAndroid Build Coastguard Worker "add_test_fast_rf2", [x, y], {"stream": 0, "grid": "lambda x : x + 1"} 1562*da0073e9SAndroid Build Coastguard Worker ) 1563*da0073e9SAndroid Build Coastguard Worker for _ in range(4): 1564*da0073e9SAndroid Build Coastguard Worker with cm: 1565*da0073e9SAndroid Build Coastguard Worker x.add(y) 1566*da0073e9SAndroid Build Coastguard Worker 1567*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual( 1568*da0073e9SAndroid Build Coastguard Worker len([e for e in p.events() if e.name == "add_test_fast_rf2"]), 4 1569*da0073e9SAndroid Build Coastguard Worker ) 1570*da0073e9SAndroid Build Coastguard Worker 1571*da0073e9SAndroid Build Coastguard Worker for e in p.events(): 1572*da0073e9SAndroid Build Coastguard Worker if e.name == "add_test_fast_rf2": 1573*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e.input_shapes == [[4, 4], [4, 4]]) 1574*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e.kwinputs == {"stream": 0, "grid": "lambda x : x + 1"}) 1575*da0073e9SAndroid Build Coastguard Worker 1576*da0073e9SAndroid Build Coastguard Worker with profile(record_shapes=True) as p: 1577*da0073e9SAndroid Build Coastguard Worker cm = torch._C._profiler._RecordFunctionFast( 1578*da0073e9SAndroid Build Coastguard Worker "add_test_fast_rf3", input_values=["hi"], keyword_values={"hi": "hello"} 
1579*da0073e9SAndroid Build Coastguard Worker ) 1580*da0073e9SAndroid Build Coastguard Worker for _ in range(4): 1581*da0073e9SAndroid Build Coastguard Worker try: 1582*da0073e9SAndroid Build Coastguard Worker with cm: 1583*da0073e9SAndroid Build Coastguard Worker x.add(y) 1584*da0073e9SAndroid Build Coastguard Worker raise ValueError 1585*da0073e9SAndroid Build Coastguard Worker x.relu() 1586*da0073e9SAndroid Build Coastguard Worker except ValueError: 1587*da0073e9SAndroid Build Coastguard Worker pass 1588*da0073e9SAndroid Build Coastguard Worker 1589*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual( 1590*da0073e9SAndroid Build Coastguard Worker len([e for e in p.events() if e.name == "add_test_fast_rf3"]), 4 1591*da0073e9SAndroid Build Coastguard Worker ) 1592*da0073e9SAndroid Build Coastguard Worker self.assertFalse(any((e.name and "relu" in e.name) for e in p.events())) 1593*da0073e9SAndroid Build Coastguard Worker 1594*da0073e9SAndroid Build Coastguard Worker for e in p.events(): 1595*da0073e9SAndroid Build Coastguard Worker if e.name == "add_test_fast_rf3": 1596*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e.input_shapes == [[]]) 1597*da0073e9SAndroid Build Coastguard Worker 1598*da0073e9SAndroid Build Coastguard Worker with profile() as p: 1599*da0073e9SAndroid Build Coastguard Worker for _ in range(4): 1600*da0073e9SAndroid Build Coastguard Worker with torch._C._profiler._RecordFunctionFast( 1601*da0073e9SAndroid Build Coastguard Worker "add_test_fast_rf4", [x, y] 1602*da0073e9SAndroid Build Coastguard Worker ): 1603*da0073e9SAndroid Build Coastguard Worker x.add(y) 1604*da0073e9SAndroid Build Coastguard Worker with torch._C._profiler._RecordFunctionFast("add_test_fast_rf5"): 1605*da0073e9SAndroid Build Coastguard Worker x.relu() 1606*da0073e9SAndroid Build Coastguard Worker 1607*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual( 1608*da0073e9SAndroid Build Coastguard Worker len([e for e in p.events() if e.name == 
"add_test_fast_rf4"]), 4 1609*da0073e9SAndroid Build Coastguard Worker ) 1610*da0073e9SAndroid Build Coastguard Worker 1611*da0073e9SAndroid Build Coastguard Worker for e in p.events(): 1612*da0073e9SAndroid Build Coastguard Worker if e.name == "add_test_fast_rf4": 1613*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e.input_shapes == []) 1614*da0073e9SAndroid Build Coastguard Worker 1615*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual( 1616*da0073e9SAndroid Build Coastguard Worker len([e for e in p.events() if e.name == "add_test_fast_rf5"]), 4 1617*da0073e9SAndroid Build Coastguard Worker ) 1618*da0073e9SAndroid Build Coastguard Worker 1619*da0073e9SAndroid Build Coastguard Worker with profile(record_shapes=True) as p: 1620*da0073e9SAndroid Build Coastguard Worker # test optional args with tuple 1621*da0073e9SAndroid Build Coastguard Worker cm = torch._C._profiler._RecordFunctionFast( 1622*da0073e9SAndroid Build Coastguard Worker "add_test_fast_rf6", 1623*da0073e9SAndroid Build Coastguard Worker ( 1624*da0073e9SAndroid Build Coastguard Worker x, 1625*da0073e9SAndroid Build Coastguard Worker y, 1626*da0073e9SAndroid Build Coastguard Worker ), 1627*da0073e9SAndroid Build Coastguard Worker ) 1628*da0073e9SAndroid Build Coastguard Worker for _ in range(4): 1629*da0073e9SAndroid Build Coastguard Worker with cm: 1630*da0073e9SAndroid Build Coastguard Worker x.add(y) 1631*da0073e9SAndroid Build Coastguard Worker 1632*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual( 1633*da0073e9SAndroid Build Coastguard Worker len([e for e in p.events() if e.name == "add_test_fast_rf6"]), 4 1634*da0073e9SAndroid Build Coastguard Worker ) 1635*da0073e9SAndroid Build Coastguard Worker 1636*da0073e9SAndroid Build Coastguard Worker for e in p.events(): 1637*da0073e9SAndroid Build Coastguard Worker if e.name == "add_test_fast_rf6": 1638*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e.input_shapes == [[4, 4], [4, 4]]) 1639*da0073e9SAndroid 
Build Coastguard Worker 1640*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("profiler gets ignored if dynamo activated") 1641*da0073e9SAndroid Build Coastguard Worker def test_profiler_op_event_kwargs(self): 1642*da0073e9SAndroid Build Coastguard Worker x, y = (torch.rand((4, 4)) for _ in range(2)) 1643*da0073e9SAndroid Build Coastguard Worker with profile(record_shapes=True) as p: 1644*da0073e9SAndroid Build Coastguard Worker cm = torch._C._profiler._RecordFunctionFast( 1645*da0073e9SAndroid Build Coastguard Worker "add_test_kwinputs", 1646*da0073e9SAndroid Build Coastguard Worker [x, y], 1647*da0073e9SAndroid Build Coastguard Worker {"stream": 0, "grid": "lambda x : x + 1", "debug": 'debug"'}, 1648*da0073e9SAndroid Build Coastguard Worker ) 1649*da0073e9SAndroid Build Coastguard Worker for _ in range(4): 1650*da0073e9SAndroid Build Coastguard Worker with cm: 1651*da0073e9SAndroid Build Coastguard Worker x.add(y) 1652*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName(mode="w+") as fname: 1653*da0073e9SAndroid Build Coastguard Worker p.export_chrome_trace(fname) 1654*da0073e9SAndroid Build Coastguard Worker with open(fname) as f: 1655*da0073e9SAndroid Build Coastguard Worker j = json.load(f) 1656*da0073e9SAndroid Build Coastguard Worker op_events = [ 1657*da0073e9SAndroid Build Coastguard Worker e for e in j["traceEvents"] if e.get("cat", "") == "cpu_op" 1658*da0073e9SAndroid Build Coastguard Worker ] 1659*da0073e9SAndroid Build Coastguard Worker for e in op_events: 1660*da0073e9SAndroid Build Coastguard Worker if e["name"] == "add_test_kwinputs": 1661*da0073e9SAndroid Build Coastguard Worker args = e["args"] 1662*da0073e9SAndroid Build Coastguard Worker self.assertTrue("stream" in args) 1663*da0073e9SAndroid Build Coastguard Worker self.assertTrue("grid" in args) 1664*da0073e9SAndroid Build Coastguard Worker self.assertTrue(args["stream"] == 0) 1665*da0073e9SAndroid Build Coastguard Worker self.assertTrue(args["grid"] == "lambda x : x + 
1") 1666*da0073e9SAndroid Build Coastguard Worker self.assertTrue(args["debug"] == "None") 1667*da0073e9SAndroid Build Coastguard Worker 1668*da0073e9SAndroid Build Coastguard Worker def test_is_profiler_enabled(self): 1669*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.autograd.profiler._is_profiler_enabled) 1670*da0073e9SAndroid Build Coastguard Worker 1671*da0073e9SAndroid Build Coastguard Worker with profile() as p: 1672*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.autograd.profiler._is_profiler_enabled) 1673*da0073e9SAndroid Build Coastguard Worker 1674*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.autograd.profiler._is_profiler_enabled) 1675*da0073e9SAndroid Build Coastguard Worker 1676*da0073e9SAndroid Build Coastguard Worker with torch.autograd.profiler.profile() as p: 1677*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.autograd.profiler._is_profiler_enabled) 1678*da0073e9SAndroid Build Coastguard Worker 1679*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.autograd.profiler._is_profiler_enabled) 1680*da0073e9SAndroid Build Coastguard Worker 1681*da0073e9SAndroid Build Coastguard Worker def test_guarded_record_function_fast(self): 1682*da0073e9SAndroid Build Coastguard Worker x, y = (torch.rand((4, 4)) for _ in range(2)) 1683*da0073e9SAndroid Build Coastguard Worker 1684*da0073e9SAndroid Build Coastguard Worker with profile() as p: 1685*da0073e9SAndroid Build Coastguard Worker cm = torch._C._profiler._RecordFunctionFast("guarded_rff") 1686*da0073e9SAndroid Build Coastguard Worker for _ in range(4): 1687*da0073e9SAndroid Build Coastguard Worker if torch.autograd.profiler._is_profiler_enabled: 1688*da0073e9SAndroid Build Coastguard Worker with cm: 1689*da0073e9SAndroid Build Coastguard Worker x.add(y) 1690*da0073e9SAndroid Build Coastguard Worker else: 1691*da0073e9SAndroid Build Coastguard Worker x.add(y) 1692*da0073e9SAndroid Build Coastguard Worker 1693*da0073e9SAndroid 
Build Coastguard Worker self.assertGreaterEqual( 1694*da0073e9SAndroid Build Coastguard Worker len([e for e in p.events() if e.name == "guarded_rff"]), 4 1695*da0073e9SAndroid Build Coastguard Worker ) 1696*da0073e9SAndroid Build Coastguard Worker 1697*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") 1698*da0073e9SAndroid Build Coastguard Worker def test_event_list(self): 1699*da0073e9SAndroid Build Coastguard Worker # AFAIK event list is part of legacy profiler and/or used when kineto is not available. 1700*da0073e9SAndroid Build Coastguard Worker # This test has basic sanity checks to test against obvious regressions. 1701*da0073e9SAndroid Build Coastguard Worker x, y = (torch.rand((4, 4), requires_grad=True, device="cuda") for _ in range(2)) 1702*da0073e9SAndroid Build Coastguard Worker with profile(with_stack=True) as p: 1703*da0073e9SAndroid Build Coastguard Worker z = (x @ y).relu().sum() 1704*da0073e9SAndroid Build Coastguard Worker z.backward() 1705*da0073e9SAndroid Build Coastguard Worker 1706*da0073e9SAndroid Build Coastguard Worker event_list = torch.autograd.profiler_util.EventList(p.events()) 1707*da0073e9SAndroid Build Coastguard Worker # event_list._build_tree() 1708*da0073e9SAndroid Build Coastguard Worker 1709*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName(mode="w+") as fname: 1710*da0073e9SAndroid Build Coastguard Worker event_list.export_chrome_trace(fname) 1711*da0073e9SAndroid Build Coastguard Worker with open(fname) as f: 1712*da0073e9SAndroid Build Coastguard Worker json.load(f) 1713*da0073e9SAndroid Build Coastguard Worker 1714*da0073e9SAndroid Build Coastguard Worker event_list.table() 1715*da0073e9SAndroid Build Coastguard Worker 1716*da0073e9SAndroid Build Coastguard Worker def _check_all_gpu_present(self, gpu_dict, max_gpu_count): 1717*da0073e9SAndroid Build Coastguard Worker for i in range(0, max_gpu_count): 1718*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(gpu_dict["GPU " + str(i)], 1) 1719*da0073e9SAndroid Build Coastguard Worker 1720*da0073e9SAndroid Build Coastguard Worker # Do json sanity testing. Checks that all events are between profiler start and end 1721*da0073e9SAndroid Build Coastguard Worker # also checks to see that GPU values are present in trace if cuda is used 1722*da0073e9SAndroid Build Coastguard Worker def _validate_basic_json(self, traceEvents, cuda_available=False): 1723*da0073e9SAndroid Build Coastguard Worker MAX_GPU_COUNT = 8 1724*da0073e9SAndroid Build Coastguard Worker PROFILER_IDX = -4 1725*da0073e9SAndroid Build Coastguard Worker RECORD_END = -1 1726*da0073e9SAndroid Build Coastguard Worker RECORD_START = -2 1727*da0073e9SAndroid Build Coastguard Worker traceEventProfiler = traceEvents[PROFILER_IDX] 1728*da0073e9SAndroid Build Coastguard Worker 1729*da0073e9SAndroid Build Coastguard Worker self.assertTrue(traceEventProfiler["name"] == "PyTorch Profiler (0)") 1730*da0073e9SAndroid Build Coastguard Worker self.assertTrue(traceEvents[RECORD_END]["name"] == "Record Window End") 1731*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 1732*da0073e9SAndroid Build Coastguard Worker traceEvents[RECORD_START]["name"] == "Iteration Start: PyTorch Profiler" 1733*da0073e9SAndroid Build Coastguard Worker ) 1734*da0073e9SAndroid Build Coastguard Worker # check that the profiler starts/ends within the record interval 1735*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual( 1736*da0073e9SAndroid Build Coastguard Worker traceEventProfiler["ts"], 1737*da0073e9SAndroid Build Coastguard Worker traceEvents[RECORD_START]["ts"], 1738*da0073e9SAndroid Build Coastguard Worker "Profiler starts before record!", 1739*da0073e9SAndroid Build Coastguard Worker ) 1740*da0073e9SAndroid Build Coastguard Worker self.assertLessEqual( 1741*da0073e9SAndroid Build Coastguard Worker traceEventProfiler["ts"] + traceEventProfiler["dur"], 1742*da0073e9SAndroid Build Coastguard Worker 
traceEvents[RECORD_END]["ts"], 1743*da0073e9SAndroid Build Coastguard Worker "Profiler ends after record end!", 1744*da0073e9SAndroid Build Coastguard Worker ) 1745*da0073e9SAndroid Build Coastguard Worker 1746*da0073e9SAndroid Build Coastguard Worker gpu_dict = collections.defaultdict(int) 1747*da0073e9SAndroid Build Coastguard Worker for i, traceEvent in enumerate(traceEvents): 1748*da0073e9SAndroid Build Coastguard Worker if ( 1749*da0073e9SAndroid Build Coastguard Worker i == len(traceEvents) + RECORD_END 1750*da0073e9SAndroid Build Coastguard Worker or i == len(traceEvents) + RECORD_START 1751*da0073e9SAndroid Build Coastguard Worker ): 1752*da0073e9SAndroid Build Coastguard Worker continue 1753*da0073e9SAndroid Build Coastguard Worker # make sure all valid trace events are within the bounds of the profiler 1754*da0073e9SAndroid Build Coastguard Worker if "ts" in traceEvent: 1755*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual( 1756*da0073e9SAndroid Build Coastguard Worker traceEvent["ts"], 1757*da0073e9SAndroid Build Coastguard Worker traceEventProfiler["ts"], 1758*da0073e9SAndroid Build Coastguard Worker "Trace event is out of bounds", 1759*da0073e9SAndroid Build Coastguard Worker ) 1760*da0073e9SAndroid Build Coastguard Worker # some python events seem to go a little past record end probably because 1761*da0073e9SAndroid Build Coastguard Worker # of some clock inaccuracies so just compare events ending to RECORD_END 1762*da0073e9SAndroid Build Coastguard Worker if "dur" in traceEvent: 1763*da0073e9SAndroid Build Coastguard Worker self.assertLessEqual( 1764*da0073e9SAndroid Build Coastguard Worker traceEvent["ts"] + traceEvent["dur"], 1765*da0073e9SAndroid Build Coastguard Worker traceEvents[RECORD_END]["ts"], 1766*da0073e9SAndroid Build Coastguard Worker "Trace event ends too late!", 1767*da0073e9SAndroid Build Coastguard Worker ) 1768*da0073e9SAndroid Build Coastguard Worker gpu_value = traceEvent.get("args", {}).get("labels", None) 
1769*da0073e9SAndroid Build Coastguard Worker if gpu_value and "GPU" in gpu_value: 1770*da0073e9SAndroid Build Coastguard Worker gpu_dict[gpu_value] += 1 1771*da0073e9SAndroid Build Coastguard Worker # Max PID offset is 5M, based from pytorch/kineto include header: 1772*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/kineto/blob/8681ff11e1fa54da39023076c5c43eddd87b7a8a/libkineto/include/output_base.h#L35 1773*da0073e9SAndroid Build Coastguard Worker kExceedMaxPid = 5000000 1774*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 1775*da0073e9SAndroid Build Coastguard Worker traceEvents[i + 1]["args"]["sort_index"] 1776*da0073e9SAndroid Build Coastguard Worker == kExceedMaxPid + int(gpu_value.split()[1]) 1777*da0073e9SAndroid Build Coastguard Worker ) 1778*da0073e9SAndroid Build Coastguard Worker 1779*da0073e9SAndroid Build Coastguard Worker # TODO add checking gpu count if cpuOnly_ is true or not 1780*da0073e9SAndroid Build Coastguard Worker 1781*da0073e9SAndroid Build Coastguard Worker def _test_chrome_trace_basic_helper(self, with_cuda=False): 1782*da0073e9SAndroid Build Coastguard Worker if with_cuda: 1783*da0073e9SAndroid Build Coastguard Worker device = "cuda" 1784*da0073e9SAndroid Build Coastguard Worker else: 1785*da0073e9SAndroid Build Coastguard Worker device = "cpu" 1786*da0073e9SAndroid Build Coastguard Worker x, y = (torch.rand(4, 4).to(device) for _ in range(2)) 1787*da0073e9SAndroid Build Coastguard Worker 1788*da0073e9SAndroid Build Coastguard Worker with profile(with_stack=True) as p: 1789*da0073e9SAndroid Build Coastguard Worker torch.add(x, y) 1790*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName(mode="w+") as fname: 1791*da0073e9SAndroid Build Coastguard Worker p.export_chrome_trace(fname) 1792*da0073e9SAndroid Build Coastguard Worker with open(fname) as f: 1793*da0073e9SAndroid Build Coastguard Worker report = json.load(f) 1794*da0073e9SAndroid Build Coastguard Worker 
self._validate_basic_json(report["traceEvents"], with_cuda) 1795*da0073e9SAndroid Build Coastguard Worker 1796*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not kineto_available(), "Kineto is required") 1797*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("profiler gets ignored if dynamo activated") 1798*da0073e9SAndroid Build Coastguard Worker def test_basic_chrome_trace(self): 1799*da0073e9SAndroid Build Coastguard Worker self._test_chrome_trace_basic_helper() 1800*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 1801*da0073e9SAndroid Build Coastguard Worker self._test_chrome_trace_basic_helper(with_cuda=True) 1802*da0073e9SAndroid Build Coastguard Worker 1803*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("profiler gets ignored if dynamo activated") 1804*da0073e9SAndroid Build Coastguard Worker def test_profiler_time_scale(self): 1805*da0073e9SAndroid Build Coastguard Worker MARGIN_ERROR = 0.5 1806*da0073e9SAndroid Build Coastguard Worker SEC_TO_US = 1000 * 1000 1807*da0073e9SAndroid Build Coastguard Worker WAIT_TIME = 10 1808*da0073e9SAndroid Build Coastguard Worker with profile() as p: 1809*da0073e9SAndroid Build Coastguard Worker with torch.profiler.record_function("test_span"): 1810*da0073e9SAndroid Build Coastguard Worker for i in range(WAIT_TIME): 1811*da0073e9SAndroid Build Coastguard Worker torch.rand(4, 4) 1812*da0073e9SAndroid Build Coastguard Worker time.sleep(1) 1813*da0073e9SAndroid Build Coastguard Worker events = p.events() 1814*da0073e9SAndroid Build Coastguard Worker 1815*da0073e9SAndroid Build Coastguard Worker # make sure function events are scaled appropriately 1816*da0073e9SAndroid Build Coastguard Worker self.assertTrue(events[0].name == "test_span") 1817*da0073e9SAndroid Build Coastguard Worker test_span = events[0] 1818*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual( 1819*da0073e9SAndroid Build Coastguard Worker test_span.cpu_time / SEC_TO_US, 
1820*da0073e9SAndroid Build Coastguard Worker WAIT_TIME - MARGIN_ERROR, 1821*da0073e9SAndroid Build Coastguard Worker "event out of range", 1822*da0073e9SAndroid Build Coastguard Worker ) 1823*da0073e9SAndroid Build Coastguard Worker self.assertLessEqual( 1824*da0073e9SAndroid Build Coastguard Worker test_span.cpu_time / SEC_TO_US, 1825*da0073e9SAndroid Build Coastguard Worker WAIT_TIME + MARGIN_ERROR, 1826*da0073e9SAndroid Build Coastguard Worker "event out of range", 1827*da0073e9SAndroid Build Coastguard Worker ) 1828*da0073e9SAndroid Build Coastguard Worker 1829*da0073e9SAndroid Build Coastguard Worker # make sure tracing is scaled appropriately 1830*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName(mode="w+") as fname: 1831*da0073e9SAndroid Build Coastguard Worker p.export_chrome_trace(fname) 1832*da0073e9SAndroid Build Coastguard Worker with open(fname) as f: 1833*da0073e9SAndroid Build Coastguard Worker report = json.load(f) 1834*da0073e9SAndroid Build Coastguard Worker events = report["traceEvents"] 1835*da0073e9SAndroid Build Coastguard Worker for event in events: 1836*da0073e9SAndroid Build Coastguard Worker if event["name"] == "test_span": 1837*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual( 1838*da0073e9SAndroid Build Coastguard Worker event["dur"] / SEC_TO_US, 1839*da0073e9SAndroid Build Coastguard Worker WAIT_TIME - MARGIN_ERROR, 1840*da0073e9SAndroid Build Coastguard Worker "profiling out of range", 1841*da0073e9SAndroid Build Coastguard Worker ) 1842*da0073e9SAndroid Build Coastguard Worker self.assertLessEqual( 1843*da0073e9SAndroid Build Coastguard Worker event["dur"] / SEC_TO_US, 1844*da0073e9SAndroid Build Coastguard Worker WAIT_TIME + MARGIN_ERROR, 1845*da0073e9SAndroid Build Coastguard Worker "profiling out of range", 1846*da0073e9SAndroid Build Coastguard Worker ) 1847*da0073e9SAndroid Build Coastguard Worker 1848*da0073e9SAndroid Build Coastguard Worker def _schedule_helper(self, warmup, active, repeat, 
acc_events=True):
        # Profile 100 iterations of a trivial op under the given schedule and
        # return how many aten::add calls the profiler actually recorded.
        with profile(
            schedule=torch.profiler.schedule(
                skip_first=0,
                wait=0,
                warmup=warmup,
                active=active,
                repeat=repeat,
            ),
            acc_events=acc_events,
        ) as prof:
            for i in range(100):
                torch.add(1, 2)
                prof.step()
        # print(prof.key_averages())
        for ev in prof.key_averages():
            if ev.key == "aten::add":
                return ev.count
        # No aten::add was captured (e.g. nothing collected under this schedule).
        return 0

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_schedule_function_count(self):
        """Check that warmup/active/repeat (and acc_events) control exactly how
        many of the 100 aten::add calls end up in the recorded events."""
        self.assertEqual(self._schedule_helper(warmup=0, active=1, repeat=1), 1)
        self.assertEqual(self._schedule_helper(warmup=0, active=5, repeat=0), 100)
        self.assertEqual(self._schedule_helper(warmup=0, active=5, repeat=10), 50)
        self.assertEqual(self._schedule_helper(warmup=1, active=5, repeat=0), 83)
        self.assertEqual(self._schedule_helper(warmup=10, active=10, repeat=4), 40)
        self.assertEqual(self._schedule_helper(warmup=50, active=1, repeat=0), 1)
        # With acc_events=False only the last cycle's events are retained.
        self.assertEqual(
            self._schedule_helper(warmup=0, active=5, repeat=0, acc_events=False), 0
        )
        self.assertEqual(
            self._schedule_helper(warmup=10, active=10, repeat=4, acc_events=False), 10
        )

    def _step_helper_func(self, prof):
        """One profiler step: sleep briefly, record one op, advance the schedule."""
        time.sleep(0.1)
        torch.randn(1, 3, 224, 224)
        prof.step()

    def _partial_overlap(self, prof_step, step_helper_func):
        """Return True iff the two trace events (dicts with "ts"/"dur") overlap
        partially, i.e. each interval extends beyond the other on one side.

        Full containment and mere endpoint touching both return False — the
        strict inequalities exclude shared boundaries.
        """
        p_start = prof_step["ts"]
        p_end = prof_step["ts"] + prof_step["dur"]
        h_start = step_helper_func["ts"]
        h_end = step_helper_func["ts"] + step_helper_func["dur"]

        if p_start < h_start and p_end < h_end and p_end > h_start:
            return True
        if p_start > h_start and p_start < h_end and p_end > h_end:
            return True
        return False

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_cpu_annotation_overlap(self):
        """Export a chrome trace and verify ProfilerStep annotations never
        partially overlap the step-helper events (nesting is allowed,
        straddling a step boundary is not)."""
        with torch.profiler.profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
            with_stack=True,
            schedule=torch.profiler.schedule(wait=0, warmup=0, active=5, repeat=1),
        ) as prof:
            for i in range(5):
                self._step_helper_func(prof)
        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            prof_steps = []
            step_helper_funcs = []
            with open(fname) as f:
                report = json.load(f)
                for event in report["traceEvents"]:
                    if "ProfilerStep" in event["name"]:
                        prof_steps.append(event)
                    if "step_helper_func" in event["name"]:
                        step_helper_funcs.append(event)
            self.assertEqual(len(prof_steps), 5)
            self.assertEqual(len(step_helper_funcs), 5)
            # Every (step, helper) pair must be either disjoint or fully nested.
            for i in range(0, len(step_helper_funcs)):
                for j in range(0, len(step_helper_funcs)):
                    self.assertTrue(
                        not self._partial_overlap(prof_steps[i], step_helper_funcs[j])
                    )

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_user_annotation(self):
        """Only the record_function region should be flagged is_user_annotation."""
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        with profile(activities=supported_activities()) as p:
            with torch.profiler.record_function("test_user_annotation"):
                self.payload(use_cuda=use_cuda)

        for evt in p.key_averages():
            if evt.key == "test_user_annotation":
                self.assertTrue(evt.is_user_annotation)
            else:
                self.assertFalse(evt.is_user_annotation)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_dynamic_toggle(self):
        """toggle_collection_dynamic(False, ...) must suppress events for the
        disabled activities: CUDA off -> no cuda/kernel events; CPU and CUDA
        off -> no events at all."""
        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p:
            with torch.profiler.record_function("test_user_annotation"):
                x, y = (torch.rand(4, 4).to("cuda") for _ in range(2))
                torch.add(x, y)

        self.assertTrue(any("aten" in e.name for e in p.events()))

        self.assertTrue(any("cuda" in e.name for e in p.events()))

        self.assertTrue(any("kernel" in e.name for e in p.events()))

        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p1:
            p1.toggle_collection_dynamic(False, [ProfilerActivity.CUDA])
            with torch.profiler.record_function("test_user_annotation"):
                x, y = (torch.rand(4, 4).to("cuda") for _ in range(2))
                torch.add(x, y)

        self.assertTrue(any("aten" in e.name for e in p1.events()))

        self.assertTrue(all("cuda" not in e.name for e in p1.events()))

        self.assertTrue(all("kernel" not in e.name for e in p1.events()))

        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p2:
            p2.toggle_collection_dynamic(
                False, [ProfilerActivity.CUDA, ProfilerActivity.CPU]
            )
            with torch.profiler.record_function("test_user_annotation"):
                x, y = (torch.rand(4, 4).to("cuda") for _ in range(2))
                torch.add(x, y)
        self.assertTrue(len(p2.events()) == 0)

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_lazy_build_tree(self):
        """The event tree must not be built until events() is first called."""
        with profile() as p:
            self.payload()

        stats = p._stats()
        # Test that the tree is not built
        self.assertEqual(stats.function_events_build_tree_call_duration_us, 0)
        self.assertEqual(stats.number_of_events, 0)

        # Test that the tree is built on demand
        p.events()
        self.assertGreater(stats.function_events_build_tree_call_duration_us, 0)
        self.assertGreater(stats.number_of_events, 0)


class SimpleNet(nn.Module):
    """Minimal two-layer MLP (10 -> 5 -> 2) used as a profiling payload."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        return self.fc2(self.fc1(x))
@dataclass(frozen=True)
class MockKinetoEvent:
    """Stand-in for a kineto event; stores times in microseconds and exposes
    the nanosecond accessor methods the real event provides."""

    _name: str
    _start_us: int
    _duration_us: int
    _linked_correlation_id: int
    _device_type: int  # 1 -> CUDA, anything else -> CPU (see device_type())

    @property
    def name(self) -> str:
        return self._name

    def start_ns(self) -> int:
        return self._start_us * 1000

    def duration_ns(self) -> int:
        return self._duration_us * 1000

    def linked_correlation_id(self) -> int:
        return self._linked_correlation_id

    def device_type(self) -> DeviceType:
        return DeviceType.CUDA if self._device_type == 1 else DeviceType.CPU


@dataclass(frozen=True)
class MockProfilerEvent:
    """Stand-in for a CPU-side profiler event node with parent/children links.

    Frozen dataclass; the tree links are patched in after construction via
    __post__init__ below.
    """

    _name: str
    id: int
    start_time_ns: int
    duration_time_ns: int
    correlation_id: int = 0
    children: List["MockProfilerEvent"] = field(default_factory=list)
    parent: Optional["MockProfilerEvent"] = None

    @property
    def end_time_ns(self):
        return self.start_time_ns + self.duration_time_ns

    @property
    def name(self) -> str:
        return self._name

    def __post__init__(self, parent, children):
        # NOTE: deliberately NOT the dataclass __post_init__ hook (extra
        # underscore) — callers invoke it explicitly once all nodes exist,
        # using object.__setattr__ to bypass the frozen dataclass.
        object.__setattr__(self, "parent", parent)
        object.__setattr__(self, "children", children)


class MockNode:
    """Tree node built recursively from a nested {name: children_dict} mapping."""

    def __init__(self, name, children) -> None:
        self.name = name
        # The loop variable intentionally shadows the `name` parameter; each
        # dict entry becomes a child subtree.
        self.children = [MockNode(name, i) for name, i in children.items()]


class TestExperimentalUtils(TestCase):
    def make_tree(self) -> List[MockNode]:
        # Two-root forest used by the DFS/BFS traversal tests below.
        tree = {
            "root_0": {
                "1": {"2": {}},
                "3": {
                    "4": {},
                    "5": {},
                },
            },
            "root_1": {
                "6": {},
                "7": {},
                "8": {
                    "9": {"10": {}},
                },
            },
        }
        return [MockNode(name, i) for name, i in tree.items()]

    def test_dfs(self) -> None:
        """Depth-first traversal visits each root's subtree fully before the next."""
        self.assertEqual(
            " ".join(i.name for i in _utils.traverse_dfs(self.make_tree())),
            "root_0 1 2 3 4 5 root_1 6 7 8 9 10",
        )

    def test_bfs(self) -> None:
        """Breadth-first traversal visits the forest level by level."""
        self.assertEqual(
            " ".join(i.name for i in _utils.traverse_bfs(self.make_tree())),
            "root_0 root_1 1 3 6 7 8 2 4 5 9 10",
        )

    @staticmethod
    def generate_mock_profile():
        """Build a Mock profiler with a fixed, deterministic event timeline:
        six cudaLaunchKernel + six GPU kineto events (times in us) and twelve
        CPU profiler events (times in ns)."""
        cuda_events = [
            MockKinetoEvent("cudaLaunchKernel", 400, 100, 1, 0),
            MockKinetoEvent("cudaLaunchKernel", 500, 100, 2, 0),
            MockKinetoEvent("cudaLaunchKernel", 600, 100, 3, 0),
            MockKinetoEvent("cudaLaunchKernel", 700, 100, 4, 0),
            MockKinetoEvent("cudaLaunchKernel", 800, 100, 5, 0),
            MockKinetoEvent("cudaLaunchKernel", 1500, 100, 6, 0),
            MockKinetoEvent("GPU", 900, 100, 1, 1),
            MockKinetoEvent("GPU", 1000, 100, 2, 1),
            MockKinetoEvent("GPU", 1100, 100, 3, 1),
            MockKinetoEvent("GPU", 1200, 100, 4, 1),
            MockKinetoEvent("GPU", 1300, 100, 5, 1),
            MockKinetoEvent("GPU", 1700, 100, 6, 1),
        ]
        cpu_events = [
            MockProfilerEvent("CPU (Before cudaLaunchKernel)", 1, 0, 100000),
            MockProfilerEvent("CPU (Before cudaLaunchKernel)", 2, 100000, 100000),
            MockProfilerEvent("CPU (Before cudaLaunchKernel)", 3, 200000, 100000),
            MockProfilerEvent("CPU (Before cudaLaunchKernel)", 4, 300000, 100000),
            MockProfilerEvent("CPU (After cudaLaunchKernel)", 5, 400000, 100000),
            MockProfilerEvent("CPU (After cudaLaunchKernel)", 6, 500000, 100000),
            MockProfilerEvent("CPU (After cudaLaunchKernel)", 7, 600000, 100000),
            MockProfilerEvent("CPU (After cudaLaunchKernel)", 8, 700000, 100000),
            MockProfilerEvent("CPU (After GPU)", 9, 800000, 100000),
            MockProfilerEvent("CPU (After GPU)", 10, 900000, 100000),
            MockProfilerEvent("CPU (After GPU)", 11, 1100000, 100000),
            MockProfilerEvent("CPU (After GPU)", 12, 1200000, 500000),
        ]

        # Only the two accessors BasicEvaluation reads are mocked out.
        profiler = unittest.mock.Mock()
        profiler.kineto_results = unittest.mock.Mock()
        profiler.kineto_results.events = unittest.mock.Mock(return_value=cuda_events)
        profiler.kineto_results.experimental_event_tree = unittest.mock.Mock(
            return_value=cpu_events
        )
        return profiler

    @staticmethod
    def load_mock_profile():
        """Load mock events from the golden JSON next to this file; when
        expecttest ACCEPT mode is on and CUDA is available, first regenerate
        that JSON from a real profiling run."""
        accept = expecttest.ACCEPT
        json_file_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            "profiler_utils_mock_events.json",
        )
        if accept and torch.cuda.is_available():

            def garbage_code(x):
                for i in range(5):
                    x[0, i] = i

            x = torch.ones((4096, 4096), device="cuda")
            x = x @ x
            with profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True,
                with_stack=True,
            ) as prof:
                for _ in range(5):
                    x = x @ x
                garbage_code(x)
                for _ in range(5):
                    x = x @ x

            # Serialize kineto events as plain dicts.
            # NOTE(review): keys here are *_ns but MockKinetoEvent's fields are
            # *_us, and reload below passes values positionally — so start_ns()
            # re-scales by 1000 on the reloaded objects. Confirm intended.
            kineto_events = [
                {
                    "_name": e.name,
                    "_start_ns": e.start_ns(),
                    "_duration_ns": e.duration_ns(),
                    "_linked_correlation_id": e.linked_correlation_id(),
                    "_device_type": 1 if e.device_type() == DeviceType.CUDA else 0,
                }
                for e in prof.profiler.kineto_results.events()
            ]

            def EventTreeDFS(event_tree):
                # Stack-based DFS over the event tree (yields every node).
                from collections import deque

                stack = deque(event_tree)
                while stack:
                    curr_event = stack.pop()
                    yield curr_event
                    for child_event in curr_event.children:
                        stack.append(child_event)

            # Flatten the tree: links are stored as ids and re-wired on load.
            profiler_events = [
                {
                    "_name": e.name,
                    "id": e.id,
                    "start_time_ns": e.start_time_ns,
                    "duration_time_ns": e.duration_time_ns,
                    "correlation_id": e.correlation_id,
                    "children": [child.id for child in e.children],
                    "parent": e.parent.id if e.parent else None,
                }
                for e in EventTreeDFS(
                    prof.profiler.kineto_results.experimental_event_tree()
                )
            ]

            with open(json_file_path, "w") as f:
                json.dump([kineto_events, profiler_events], f)

        assert os.path.exists(json_file_path)
        with open(json_file_path) as f:
            kineto_events, profiler_events = json.load(f)

        # Rebuild mock events; dict insertion order supplies the positional args.
        cuda_events = [MockKinetoEvent(*event.values()) for event in kineto_events]
        cpu_events = []
        id_map = {}
        for e in profiler_events:
            event = MockProfilerEvent(**e)
            id_map[event.id] = event
            cpu_events.append(event)
        # Second pass: replace id references with actual node objects.
        for event in cpu_events:
            parent = None if event.parent is None else id_map[event.parent]
            children = [id_map[child] for child in event.children]
            event.__post__init__(parent, children)
        # Keep only roots — the tree accessor returns top-level events.
        cpu_events = [event for event in cpu_events if event.parent is None]
        profiler = unittest.mock.Mock()
        profiler.kineto_results = unittest.mock.Mock()
        profiler.kineto_results.events = unittest.mock.Mock(return_value=cuda_events)
        profiler.kineto_results.experimental_event_tree = unittest.mock.Mock(
            return_value=cpu_events
        )
        return profiler

    def test_utils_compute_self_time(self):
        """Self time of every event equals its duration minus its direct
        children's durations."""
        with profile() as prof:
            t1, t2 = torch.ones(1, requires_grad=True), torch.ones(
                1, requires_grad=True
            )
2222*da0073e9SAndroid Build Coastguard Worker z = torch.add(t1, t2) 2223*da0073e9SAndroid Build Coastguard Worker y = torch.ones(1) 2224*da0073e9SAndroid Build Coastguard Worker loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y) 2225*da0073e9SAndroid Build Coastguard Worker loss.backward() 2226*da0073e9SAndroid Build Coastguard Worker basic_eval = _utils.BasicEvaluation(prof.profiler) 2227*da0073e9SAndroid Build Coastguard Worker metrics = basic_eval.metrics 2228*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(metrics) > 0) 2229*da0073e9SAndroid Build Coastguard Worker for event_key, event_metrics in metrics.items(): 2230*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2231*da0073e9SAndroid Build Coastguard Worker event_metrics.self_time_ns, 2232*da0073e9SAndroid Build Coastguard Worker event_key.event.duration_time_ns 2233*da0073e9SAndroid Build Coastguard Worker - sum(child.duration_time_ns for child in event_key.event.children), 2234*da0073e9SAndroid Build Coastguard Worker ) 2235*da0073e9SAndroid Build Coastguard Worker 2236*da0073e9SAndroid Build Coastguard Worker def test_utils_intervals_overlap(self): 2237*da0073e9SAndroid Build Coastguard Worker event = _utils.EventKey(MockProfilerEvent("Event 1", 1, 5, 5)) 2238*da0073e9SAndroid Build Coastguard Worker intervals = [ 2239*da0073e9SAndroid Build Coastguard Worker _utils.Interval(0, 9), 2240*da0073e9SAndroid Build Coastguard Worker _utils.Interval(1, 2), 2241*da0073e9SAndroid Build Coastguard Worker _utils.Interval(2, 3), 2242*da0073e9SAndroid Build Coastguard Worker _utils.Interval(3, 4), 2243*da0073e9SAndroid Build Coastguard Worker _utils.Interval(4, 5), 2244*da0073e9SAndroid Build Coastguard Worker _utils.Interval(8, 12), 2245*da0073e9SAndroid Build Coastguard Worker ] 2246*da0073e9SAndroid Build Coastguard Worker print(event.intervals_overlap(intervals)) 2247*da0073e9SAndroid Build Coastguard Worker self.assertEqual(event.intervals_overlap(intervals), 5) 
    def test_utils_compute_queue_depth(self):
        """Queue depth rises by one per cudaLaunchKernel and falls by one per
        GPU event; checked against the deterministic mock timeline."""

        def format_queue_depth(queue_depth_list, events):
            res = ""
            for data, event in zip(queue_depth_list, events):
                res += f"{data.queue_depth} [{event.name}]\n"
            return res

        # We have to use Mock because time series data is too flaky to test
        profiler = self.generate_mock_profile()
        basic_evaluation = _utils.BasicEvaluation(profiler)
        self.assertExpectedInline(
            format_queue_depth(
                basic_evaluation.queue_depth_list, basic_evaluation.cuda_events
            ),
            """\
1 [cudaLaunchKernel]
2 [cudaLaunchKernel]
3 [cudaLaunchKernel]
4 [cudaLaunchKernel]
5 [cudaLaunchKernel]
4 [GPU]
3 [GPU]
2 [GPU]
1 [GPU]
0 [GPU]
1 [cudaLaunchKernel]
0 [GPU]
""",
        )
        # Same check, but sampled at each CPU event instead of each CUDA event.
        self.assertExpectedInline(
            format_queue_depth(
                [basic_evaluation.metrics[k] for k in basic_evaluation.event_keys],
                basic_evaluation.events,
            ),
            """\
0 [CPU (Before cudaLaunchKernel)]
0 [CPU (Before cudaLaunchKernel)]
0 [CPU (Before cudaLaunchKernel)]
0 [CPU (Before cudaLaunchKernel)]
1 [CPU (After cudaLaunchKernel)]
2 [CPU (After cudaLaunchKernel)]
3 [CPU (After cudaLaunchKernel)]
4 [CPU (After cudaLaunchKernel)]
5 [CPU (After GPU)]
4 [CPU (After GPU)]
2 [CPU (After GPU)]
1 [CPU (After GPU)]
""",
        )

    def test_utils_compute_queue_depth_when_no_cuda_events(self):
        # For traces with only cpu events, we expect empty queue depth list
        x = torch.ones((1024, 1024))
        with profile() as prof:
            for _ in range(5):
                x = x @ x
2305*da0073e9SAndroid Build Coastguard Worker basic_evaluation = _utils.BasicEvaluation(prof.profiler) 2306*da0073e9SAndroid Build Coastguard Worker self.assertFalse(basic_evaluation.compute_queue_depth()) 2307*da0073e9SAndroid Build Coastguard Worker 2308*da0073e9SAndroid Build Coastguard Worker def test_utils_compute_idle_time(self): 2309*da0073e9SAndroid Build Coastguard Worker profiler = self.generate_mock_profile() 2310*da0073e9SAndroid Build Coastguard Worker basic_evaluation = _utils.BasicEvaluation(profiler) 2311*da0073e9SAndroid Build Coastguard Worker expected_output = "\n".join( 2312*da0073e9SAndroid Build Coastguard Worker [ 2313*da0073e9SAndroid Build Coastguard Worker f"{basic_evaluation.metrics[event_key].idle_time_ns} [{event_key.event.name}]" 2314*da0073e9SAndroid Build Coastguard Worker for event_key in basic_evaluation.event_keys 2315*da0073e9SAndroid Build Coastguard Worker ] 2316*da0073e9SAndroid Build Coastguard Worker ) 2317*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2318*da0073e9SAndroid Build Coastguard Worker expected_output, 2319*da0073e9SAndroid Build Coastguard Worker """\ 2320*da0073e9SAndroid Build Coastguard Worker100000 [CPU (Before cudaLaunchKernel)] 2321*da0073e9SAndroid Build Coastguard Worker100000 [CPU (Before cudaLaunchKernel)] 2322*da0073e9SAndroid Build Coastguard Worker100000 [CPU (Before cudaLaunchKernel)] 2323*da0073e9SAndroid Build Coastguard Worker100000 [CPU (Before cudaLaunchKernel)] 2324*da0073e9SAndroid Build Coastguard Worker0 [CPU (After cudaLaunchKernel)] 2325*da0073e9SAndroid Build Coastguard Worker0 [CPU (After cudaLaunchKernel)] 2326*da0073e9SAndroid Build Coastguard Worker0 [CPU (After cudaLaunchKernel)] 2327*da0073e9SAndroid Build Coastguard Worker0 [CPU (After cudaLaunchKernel)] 2328*da0073e9SAndroid Build Coastguard Worker0 [CPU (After GPU)] 2329*da0073e9SAndroid Build Coastguard Worker0 [CPU (After GPU)] 2330*da0073e9SAndroid Build Coastguard Worker0 [CPU (After GPU)] 
    @unittest.skipIf(IS_JETSON, "JSON not behaving as expected on Jetson")
    def test_utils_get_optimizable_events(self):
        """The top-k optimizable events of a stored mock profile are stable.

        load_mock_profile() is defined elsewhere in this class (not visible
        here); the two expected event names are pinned to that fixture.
        """
        basic_evaluation = _utils.BasicEvaluation(self.load_mock_profile())
        # Ask for the 2 most promising events, suppressing console output.
        optimizable_events = basic_evaluation.get_optimizable_events(
            2, print_enable=False
        )
        expected_output = "\n".join(
            [f"{event_key.event.name}" for event_key in optimizable_events]
        )
        self.assertExpectedInline(
            expected_output,
            """\
<built-in function _cuda_synchronize>
aten::copy_""",
        )

    def test_profiler_name_pattern(self):
        """NamePattern must match exactly the profiled ops with that name."""
        x = torch.ones((4096, 4096))
        with profile() as prof:
            # Five matmuls -> five aten::mm events; the adds must not match.
            for _ in range(5):
                x = x @ x
                x = x + x
        matched_events = NamePattern(prof, "aten::mm").matched_events()
        output = "\n".join([f"{event.name}" for event in matched_events])
        self.assertExpectedInline(
            output,
            """\
aten::mm
aten::mm
aten::mm
aten::mm
aten::mm""",
        )

    # TODO: Add logic for CUDA version of test
    @unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
    def test_profiler_pattern_match_helper(self):
        """Exercise the tree-navigation helpers on the Pattern base class:
        siblings_of / root_of / next_of / prev_of over the event tree.
        """
        x = torch.ones((100, 100))
        with profile() as prof:
            for _ in range(5):
                x = x @ x
                x = x + x
        event_tree = prof.profiler.kineto_results.experimental_event_tree()
        pattern = Pattern(prof)
        # siblings_of(node) returns a pair: for the first node the leading
        # part is empty and the trailing part is every later sibling.
        self.assertEqual([], pattern.siblings_of(event_tree[0])[0])
        self.assertEqual(event_tree[1:], pattern.siblings_of(event_tree[0])[1])
        child_nodes = event_tree[0].children
        self.assertEqual([], pattern.siblings_of(child_nodes[0])[0])
        self.assertEqual(child_nodes[1:], pattern.siblings_of(child_nodes[0])[1])
        # root_of climbs from a grandchild back to its top-level ancestor.
        self.assertEqual(
            event_tree[0], pattern.root_of(event_tree[0].children[0].children[0])
        )
        # next_of past the end yields None; next_of/prev_of are inverses.
        self.assertEqual(None, pattern.next_of(event_tree[-1]))
        self.assertEqual(event_tree[1], pattern.next_of(event_tree[0]))
        self.assertEqual(event_tree[0], pattern.prev_of(event_tree[1]))
    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_profiler_extra_cuda_copy_pattern(self):
        """ExtraCUDACopyPattern flags tensors built on CPU and then copied to
        CUDA (expected count 1), but not tensors created directly on the
        device or plain dtype casts (expected count 0).

        Each case is (expected_match_count, callable); the callable body is
        the fixture the pattern matcher sees, so do not restructure it.
        """
        cases = (
            (0, lambda: torch.ones((100, 100), device="cuda")),
            (1, lambda: torch.ones((100, 100)).to("cuda")),
            (1, lambda: torch.zeros((100, 100)).to("cuda")),
            (1, lambda: torch.empty((100, 100)).fill_(5).to("cuda")),
            (1, lambda: torch.ones((100, 100)).cuda()),
            (1, lambda: torch.zeros((100, 100)).cuda()),
            (1, lambda: torch.empty((100, 100)).fill_(5).cuda()),
            (1, lambda: torch.rand((100, 100)).cuda()),
            (1, lambda: torch.randn((100, 100)).cuda()),
            (1, lambda: torch.full((100, 100), 10).cuda()),
            (0, lambda: torch.rand((100, 100)).to(dtype=torch.float16)),
            (0, lambda: torch.rand((100, 100)).half()),
            (0, lambda: torch.rand((100, 100), device="cuda").half()),
        )
        num_matched = []
        for _, fn in cases:
            with profile(with_stack=True, record_shapes=True) as prof:
                fn()
            pattern = ExtraCUDACopyPattern(prof)
            num_matched.append(len(pattern.matched_events()))
        # Compare against the expected count recorded with each case.
        self.assertEqual(num_matched, [i for i, _ in cases])

    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    def test_profiler_for_loop_indexing_pattern(self):
        """ForLoopIndexingPattern should flag Python loops that touch a
        tensor element-by-element (cases 1, 2, 3, 5) but not a loop of
        whole-tensor ops (case 4).
        """
        x = torch.ones((100, 100))

        def case1():
            for i in range(100):
                x[i] = i

        def case2():
            y = 0
            for i in range(100):
                y += x[i]

        def case3():
            y = 1
            for i in range(100):
                y *= x[i]

        def case4():
            y = x
            for _ in range(100):
                y = y @ x

        def case5():
            for i in range(100):
                x[i, :] = torch.arange(100) + i

        cases = ((1, case1), (1, case2), (1, case3), (0, case4), (1, case5))
        num_matched = []
        for _, fn in cases:
            with profile(with_stack=True) as prof:
                fn()
            pattern = ForLoopIndexingPattern(prof)
            num_matched.append(len(pattern.matched_events()))
        self.assertEqual(num_matched, [i for i, _ in cases])
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_profiler_fp32_matmul_pattern(self):
        """FP32MatMulPattern fires once for an fp32 CUDA matmul, but only on
        hardware where the pattern is applicable (pattern.skip otherwise).
        """
        x = torch.ones((100, 100), device="cuda")
        with profile(with_stack=True) as prof:
            x = x @ x
        pattern = FP32MatMulPattern(prof)
        # pattern.skip is set when the check does not apply on this device,
        # in which case zero matches are expected.
        has_tf32 = 0 if pattern.skip else 1
        num_matched = len(pattern.matched_events())
        self.assertEqual(num_matched, has_tf32)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_profiler_extra_cuda_copy_pattern_benchmark(self):
        """pattern.benchmark() returns one entry per distinct input shape
        among the matched events -- two shapes here (100x100 and 50x50).
        """
        with profile(with_stack=True, record_shapes=True) as prof:
            x = torch.ones((100, 100)).to("cuda")
            x = torch.ones((50, 50)).to("cuda")
        pattern = ExtraCUDACopyPattern(prof)
        shapes_factor_map = pattern.benchmark(pattern.matched_events())
        self.assertEqual(len(shapes_factor_map), 2)

    def test_profiler_optimizer_single_tensor_pattern(self):
        """OptimizerSingleTensorPattern flags the default (single-tensor)
        Adam/SGD/AdamW implementations but not their foreach variants.

        NOTE: the lambdas close over `model`, which is (re)built inside the
        loop before each lambda runs -- the late binding is intentional.
        """
        x = torch.ones((100, 100))
        cases = (
            (1, lambda: torch.optim.Adam(model.parameters())),
            (1, lambda: torch.optim.SGD(model.parameters(), lr=0.01)),
            (1, lambda: torch.optim.AdamW(model.parameters())),
            (0, lambda: torch.optim.Adam(model.parameters(), foreach=True)),
            (0, lambda: torch.optim.SGD(model.parameters(), lr=0.01, foreach=True)),
            (0, lambda: torch.optim.AdamW(model.parameters(), foreach=True)),
        )
        num_matched = []
        for _, fn in cases:
            with profile(with_stack=True) as prof:
                model = nn.Sequential(
                    nn.Linear(100, 100),
                    nn.ReLU(),
                    nn.Linear(100, 10),
                )
                optimizer = fn()
                optimizer.zero_grad()
                y_hat = model(x)
                loss = torch.nn.functional.cross_entropy(
                    y_hat, torch.randint(0, 10, (100,))
                )
                loss.backward()
                optimizer.step()
            pattern = OptimizerSingleTensorPattern(prof)
            num_matched.append(len(pattern.matched_events()))
        self.assertEqual(num_matched, [i for i, _ in cases])
    def test_profiler_synchronized_dataloader_pattern(self):
        """SynchronizedDataLoaderPattern should flag only the single-process
        DataLoader (default num_workers=0), not the 4-worker one.
        """
        dataset = torch.rand((100, 100))
        sync_dataloader = torch.utils.data.DataLoader(dataset, batch_size=10)
        async_dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=10, num_workers=4
        )
        with profile(with_stack=True) as prof:
            next(iter(sync_dataloader))
            next(iter(async_dataloader))
        pattern = SynchronizedDataLoaderPattern(prof)
        num_matched = len(pattern.matched_events())
        # Exactly one of the two loaders should be reported.
        self.assertEqual(num_matched, 1)

    @skipIfTorchDynamo(
        "pattern checks for aten::_zero op which might not be there with torch.compile'd graph"
    )
    def test_profiler_grad_not_set_to_none_pattern(self):
        """GradNotSetToNonePattern flags zero_grad(set_to_none=False) (which
        zero-fills the grads) but not the default set_to_none path.
        """
        x = torch.ones((100, 100))
        model = nn.Sequential(
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 10),
        )
        optimizer = torch.optim.Adam(model.parameters())
        cases = (
            (0, lambda: optimizer.zero_grad()),
            (0, lambda: model.zero_grad()),
            (1, lambda: optimizer.zero_grad(set_to_none=False)),
            (1, lambda: model.zero_grad(set_to_none=False)),
        )
        num_matched = []
        for _, fn in cases:
            with profile(with_stack=True) as prof:
                # A real forward/backward/step is profiled so grads exist
                # before the zero_grad variant under test is invoked.
                y_hat = model(x)
                loss = torch.nn.functional.cross_entropy(
                    y_hat, torch.randint(0, 10, (100,))
                )
                loss.backward()
                optimizer.step()
                fn()
            pattern = GradNotSetToNonePattern(prof)
            num_matched.append(len(pattern.matched_events()))
        self.assertEqual(num_matched, [i for i, _ in cases])

    def test_profiler_conv2d_bias_followed_by_batchnorm2d_pattern(self):
        """The pattern fires only when a biased Conv2d feeds a BatchNorm2d
        (the bias is redundant there); bias=False or no BN must not match.
        """
        x = torch.randn((1, 3, 32, 32))
        cases = (
            (1, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1), nn.BatchNorm2d(3))),
            (0, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1, bias=False), nn.BatchNorm2d(3))),
            (0, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1))),
        )
        num_matched = []
        for _, model in cases:
            with profile(with_stack=True, record_shapes=True) as prof:
                model(x)
            pattern = Conv2dBiasFollowedByBatchNorm2dPattern(prof)
            num_matched.append(len(pattern.matched_events()))
        self.assertEqual(num_matched, [i for i, _ in cases])
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_profiler_matmul_dim_fp16_pattern(self):
        """MatMulDimInFP16Pattern flags fp16 matmuls whose dims are not
        multiples of 8 (201, 97) and ignores aligned shapes (200).
        """
        cases = (
            (1, torch.randn((201, 201), device="cuda", dtype=torch.float16)),
            (1, torch.randn((3, 97, 97), device="cuda", dtype=torch.float16)),
            (0, torch.randn((200, 200), device="cuda", dtype=torch.float16)),
            (0, torch.randn((3, 200, 200), device="cuda", dtype=torch.float16)),
        )
        num_matched = []
        for _, x in cases:
            with profile(with_stack=True, record_shapes=True) as prof:
                x @ x
            pattern = MatMulDimInFP16Pattern(prof)
            num_matched.append(len(pattern.matched_events()))
        self.assertEqual(num_matched, [i for i, _ in cases])

    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_profiler_pattern_matcher_json_report(self):
        """report_all_anti_patterns(json_report_dir=...) writes a
        torchtidy_report.json keyed by source-file path, where each entry
        carries exactly the fields line_number / name / url / message.
        """
        x = torch.ones((100, 100))
        model = nn.Sequential(
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 10),
        )
        optimizer = torch.optim.Adam(model.parameters())
        with profile(with_stack=True, record_shapes=True) as prof:
            y_hat = model(x)
            loss = torch.nn.functional.cross_entropy(
                y_hat, torch.randint(0, 10, (100,))
            )
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        with tempfile.TemporaryDirectory() as tmpdir:
            report_all_anti_patterns(prof, json_report_dir=tmpdir, print_enable=False)

            with open(os.path.join(tmpdir, "torchtidy_report.json")) as f:
                report = json.load(f)

            # It is platform dependent whether the path will include "profiler/"
            keys = [k for k in report.keys() if k.endswith("test_profiler.py")]
            self.assertEqual(len(keys), 1, f"{keys}")
            entry = report[keys[0]]

            self.assertTrue(len(entry) > 0)
            expected_fields = sorted(["line_number", "name", "url", "message"])
            for event in entry:
                actual_fields = sorted(event.keys())
                self.assertEqual(expected_fields, actual_fields)
Worker report = json.load(f) 2603*da0073e9SAndroid Build Coastguard Worker 2604*da0073e9SAndroid Build Coastguard Worker # It is platform dependent whether the path will include "profiler/" 2605*da0073e9SAndroid Build Coastguard Worker keys = [k for k in report.keys() if k.endswith("test_profiler.py")] 2606*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(keys), 1, f"{keys}") 2607*da0073e9SAndroid Build Coastguard Worker entry = report[keys[0]] 2608*da0073e9SAndroid Build Coastguard Worker 2609*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(entry) > 0) 2610*da0073e9SAndroid Build Coastguard Worker expected_fields = sorted(["line_number", "name", "url", "message"]) 2611*da0073e9SAndroid Build Coastguard Worker for event in entry: 2612*da0073e9SAndroid Build Coastguard Worker actual_fields = sorted(event.keys()) 2613*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_fields, actual_fields) 2614*da0073e9SAndroid Build Coastguard Worker 2615*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding") 2616*da0073e9SAndroid Build Coastguard Worker def test_fuzz_symbolize(self): 2617*da0073e9SAndroid Build Coastguard Worker # generate some random addresses in the text section and make sure the 2618*da0073e9SAndroid Build Coastguard Worker # symbolizers do not throw exceptions/crash 2619*da0073e9SAndroid Build Coastguard Worker def get_text_sections(): 2620*da0073e9SAndroid Build Coastguard Worker text_sections = [] 2621*da0073e9SAndroid Build Coastguard Worker seen = set() 2622*da0073e9SAndroid Build Coastguard Worker for filename in os.listdir("/proc/self/map_files"): 2623*da0073e9SAndroid Build Coastguard Worker library = os.readlink("/proc/self/map_files/" + filename) 2624*da0073e9SAndroid Build Coastguard Worker if ".so" not in library or library in seen: 2625*da0073e9SAndroid Build Coastguard Worker continue 2626*da0073e9SAndroid Build Coastguard Worker 
seen.add(library) 2627*da0073e9SAndroid Build Coastguard Worker with open(os.path.join("/proc/self/map_files", library), "rb") as f: 2628*da0073e9SAndroid Build Coastguard Worker mm = mmap.mmap(f.fileno(), 0, prot=mmap.PROT_READ) 2629*da0073e9SAndroid Build Coastguard Worker 2630*da0073e9SAndroid Build Coastguard Worker def unpack(fmt, offset): 2631*da0073e9SAndroid Build Coastguard Worker return struct.unpack( 2632*da0073e9SAndroid Build Coastguard Worker fmt, mm[offset : offset + struct.calcsize(fmt)] 2633*da0073e9SAndroid Build Coastguard Worker ) 2634*da0073e9SAndroid Build Coastguard Worker 2635*da0073e9SAndroid Build Coastguard Worker if mm[:4] != b"\x7fELF": 2636*da0073e9SAndroid Build Coastguard Worker continue 2637*da0073e9SAndroid Build Coastguard Worker (section_headers_start,) = unpack("Q", 40) 2638*da0073e9SAndroid Build Coastguard Worker (section_header_size,) = unpack("H", 58) 2639*da0073e9SAndroid Build Coastguard Worker (num_section_headers,) = unpack("H", 60) 2640*da0073e9SAndroid Build Coastguard Worker (shstrndx,) = unpack("H", 62) 2641*da0073e9SAndroid Build Coastguard Worker (shstrtab_offset,) = unpack( 2642*da0073e9SAndroid Build Coastguard Worker "Q", section_headers_start + shstrndx * section_header_size + 24 2643*da0073e9SAndroid Build Coastguard Worker ) 2644*da0073e9SAndroid Build Coastguard Worker for i in range(num_section_headers): 2645*da0073e9SAndroid Build Coastguard Worker (section_name_offset,) = unpack( 2646*da0073e9SAndroid Build Coastguard Worker "I", section_headers_start + i * section_header_size 2647*da0073e9SAndroid Build Coastguard Worker ) 2648*da0073e9SAndroid Build Coastguard Worker name_start = shstrtab_offset + section_name_offset 2649*da0073e9SAndroid Build Coastguard Worker section_name = mm[name_start : name_start + 6] 2650*da0073e9SAndroid Build Coastguard Worker if section_name != b".text\0": 2651*da0073e9SAndroid Build Coastguard Worker continue 2652*da0073e9SAndroid Build Coastguard Worker (section_offset,) = 
unpack( 2653*da0073e9SAndroid Build Coastguard Worker "Q", section_headers_start + i * section_header_size + 24 2654*da0073e9SAndroid Build Coastguard Worker ) 2655*da0073e9SAndroid Build Coastguard Worker (section_size,) = unpack( 2656*da0073e9SAndroid Build Coastguard Worker "Q", section_headers_start + i * section_header_size + 32 2657*da0073e9SAndroid Build Coastguard Worker ) 2658*da0073e9SAndroid Build Coastguard Worker start = int(filename.split("-")[0], 16) + section_offset 2659*da0073e9SAndroid Build Coastguard Worker text_sections.append((start, section_size)) 2660*da0073e9SAndroid Build Coastguard Worker break 2661*da0073e9SAndroid Build Coastguard Worker mm.close() 2662*da0073e9SAndroid Build Coastguard Worker return text_sections 2663*da0073e9SAndroid Build Coastguard Worker 2664*da0073e9SAndroid Build Coastguard Worker r = random.Random() 2665*da0073e9SAndroid Build Coastguard Worker r.seed(1) 2666*da0073e9SAndroid Build Coastguard Worker text_sections = get_text_sections() 2667*da0073e9SAndroid Build Coastguard Worker addrs = [] 2668*da0073e9SAndroid Build Coastguard Worker for i in range(200): 2669*da0073e9SAndroid Build Coastguard Worker s = r.randrange(0, len(text_sections)) 2670*da0073e9SAndroid Build Coastguard Worker start, size = text_sections[s] 2671*da0073e9SAndroid Build Coastguard Worker addr = r.randrange(start, start + size) 2672*da0073e9SAndroid Build Coastguard Worker addrs.append(addr) 2673*da0073e9SAndroid Build Coastguard Worker fast = torch._C._profiler.symbolize_addresses(addrs, "fast") 2674*da0073e9SAndroid Build Coastguard Worker dladdr = torch._C._profiler.symbolize_addresses(addrs, "dladdr") 2675*da0073e9SAndroid Build Coastguard Worker addr2line = torch._C._profiler.symbolize_addresses(addrs, "addr2line") 2676*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(fast), len(addrs)) 2677*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(addr2line), len(fast)) 2678*da0073e9SAndroid Build Coastguard Worker 
2679*da0073e9SAndroid Build Coastguard Worker 2680*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 2681*da0073e9SAndroid Build Coastguard Worker run_tests() 2682