1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Workerimport logging 3*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerimport torch 6*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case 7*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing 8*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.utils 9*da0073e9SAndroid Build Coastguard Workerimport torch._logging 10*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.utils import dynamo_timed 11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TemporaryFileName 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Workerclass DynamoProfilerTests(torch._dynamo.test_case.TestCase): 15*da0073e9SAndroid Build Coastguard Worker def test_dynamo_timed_profiling_isolated(self): 16*da0073e9SAndroid Build Coastguard Worker # dynamo_timed functions should appear in profile traces. 17*da0073e9SAndroid Build Coastguard Worker def inner_fn(x): 18*da0073e9SAndroid Build Coastguard Worker with dynamo_timed("inner_fn"): 19*da0073e9SAndroid Build Coastguard Worker return x.sin() 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker def outer_fn(x, y): 22*da0073e9SAndroid Build Coastguard Worker return inner_fn(x) * y 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker x, y = (torch.rand((2, 2)) for _ in range(2)) 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker with torch.profiler.profile(with_stack=False) as prof: 27*da0073e9SAndroid Build Coastguard Worker outer_fn(x, y) 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 30*da0073e9SAndroid Build Coastguard Worker any("inner_fn (dynamo_timed)" in evt.name for evt in prof.events()) 31*da0073e9SAndroid Build Coastguard Worker ) 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker def test_dynamo_timed_profiling_backend_compile(self): 34*da0073e9SAndroid Build Coastguard Worker # dynamo_timed functions should appear in profile traces. 35*da0073e9SAndroid Build Coastguard Worker # this checks whether these actually appear in actual dynamo execution. 36*da0073e9SAndroid Build Coastguard Worker # "backend_compile" is just chosen as an example; if it gets renamed 37*da0073e9SAndroid Build Coastguard Worker # this test can be replaced or deleted 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker fn_name = "call_user_compiler" 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 42*da0073e9SAndroid Build Coastguard Worker return x.sin() * y.cos() 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker x, y = (torch.rand((2, 2)) for _ in range(2)) 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker with torch.profiler.profile(with_stack=False) as prof: 47*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize("aot_eager")(fn)(x, y) 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 50*da0073e9SAndroid Build Coastguard Worker any(f"{fn_name} (dynamo_timed)" in evt.name for evt in prof.events()) 51*da0073e9SAndroid Build Coastguard Worker ) 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "assume_static_by_default", False) 54*da0073e9SAndroid Build Coastguard Worker def test_profile_dynamic_shapes_runtime(self): 55*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z): 56*da0073e9SAndroid Build Coastguard Worker return x @ y + z 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("aot_eager", dynamic=True, nopython=True)(fn) 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Worker inputs = [ 61*da0073e9SAndroid Build Coastguard Worker (torch.rand(a, b), torch.rand(b, c), torch.rand(a, c)) 62*da0073e9SAndroid Build Coastguard Worker for (a, b, c) in [(15, 16, 17), (15, 15, 16), (16, 16, 16)] 63*da0073e9SAndroid Build Coastguard Worker ] 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker opt_fn(*inputs[0]) 66*da0073e9SAndroid Build Coastguard Worker opt_fn(*inputs[1]) 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker with torch.profiler.profile(record_shapes=True): 69*da0073e9SAndroid Build Coastguard Worker opt_fn(*inputs[2]) 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "assume_static_by_default", False) 72*da0073e9SAndroid Build Coastguard Worker def test_profile_dynamic_shapes_compilation(self): 73*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z): 74*da0073e9SAndroid Build Coastguard Worker return x @ y + z 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("aot_eager", dynamic=True, nopython=True)(fn) 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker inputs = (torch.rand(15, 16), torch.rand(16, 17), torch.rand(15, 17)) 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker with torch.profiler.profile(record_shapes=True): 81*da0073e9SAndroid Build Coastguard Worker opt_fn(*inputs) 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "assume_static_by_default", False) 84*da0073e9SAndroid Build Coastguard Worker def test_profile_dynamic_shapes_list_compilation(self): 85*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z): 86*da0073e9SAndroid Build Coastguard Worker return torch.cat([x, y], dim=0) + z 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("aot_eager", dynamic=True, nopython=True)(fn) 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker inputs = (torch.rand(4, 16), torch.rand(12, 16), torch.rand(16, 16)) 91*da0073e9SAndroid Build Coastguard Worker 92*da0073e9SAndroid Build Coastguard Worker with torch.profiler.profile(record_shapes=True): 93*da0073e9SAndroid Build Coastguard Worker opt_fn(*inputs) 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker def test_execution_trace_dynamic_shapes(self): 96*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z): 97*da0073e9SAndroid Build Coastguard Worker return x @ y + z 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Worker et = torch.profiler.ExecutionTraceObserver() 100*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, dynamic=True, backend="aot_eager") 101*da0073e9SAndroid Build Coastguard Worker inputs = [torch.rand((4, 4)) for _ in range(3)] 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 104*da0073e9SAndroid Build Coastguard Worker et.register_callback(fname) 105*da0073e9SAndroid Build Coastguard Worker et.start() 106*da0073e9SAndroid Build Coastguard Worker out = opt_fn(*inputs) 107*da0073e9SAndroid Build Coastguard Worker et.stop() 108*da0073e9SAndroid Build Coastguard Worker et.unregister_callback() 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker def test_profiler_cache_lookup(self): 111*da0073e9SAndroid Build Coastguard Worker def fn(x): 112*da0073e9SAndroid Build Coastguard Worker y = x**2 113*da0073e9SAndroid Build Coastguard Worker y = y + 2 114*da0073e9SAndroid Build Coastguard Worker z = y**3 115*da0073e9SAndroid Build Coastguard Worker return z 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker for profiler, get_events in ( 118*da0073e9SAndroid Build Coastguard Worker (torch.autograd.profiler.profile, lambda prof: prof.function_events), 119*da0073e9SAndroid Build Coastguard Worker (torch.profiler.profiler.profile, lambda prof: prof.events()), 120*da0073e9SAndroid Build Coastguard Worker ): 121*da0073e9SAndroid Build Coastguard Worker x = torch.randn((2, 2), requires_grad=True) 122*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 123*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="aot_eager") 124*da0073e9SAndroid Build Coastguard Worker 125*da0073e9SAndroid Build Coastguard Worker # warmup 126*da0073e9SAndroid Build Coastguard Worker opt_fn(x) 127*da0073e9SAndroid Build Coastguard Worker 128*da0073e9SAndroid Build Coastguard Worker with profiler() as prof: 129*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 130*da0073e9SAndroid Build Coastguard Worker events = list( 131*da0073e9SAndroid Build Coastguard Worker filter( 132*da0073e9SAndroid Build Coastguard Worker lambda event: "TorchDynamo Cache Lookup" in event.name, 133*da0073e9SAndroid Build Coastguard Worker get_events(prof), 134*da0073e9SAndroid Build Coastguard Worker ) 135*da0073e9SAndroid Build Coastguard Worker ) 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 138*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 139*da0073e9SAndroid Build Coastguard Worker len(events) == 1, 140*da0073e9SAndroid Build Coastguard Worker "Expected one lookup profiler event for one opt_fn run", 141*da0073e9SAndroid Build Coastguard Worker ) 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker def test_profiler_cache_lookup_profiler_step(self): 144*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z): 145*da0073e9SAndroid Build Coastguard Worker return torch.add(torch.sub(x, y), z) 146*da0073e9SAndroid Build Coastguard Worker 147*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("aot_eager")(fn) 148*da0073e9SAndroid Build Coastguard Worker 149*da0073e9SAndroid Build Coastguard Worker ( 150*da0073e9SAndroid Build Coastguard Worker x, 151*da0073e9SAndroid Build Coastguard Worker y, 152*da0073e9SAndroid Build Coastguard Worker z, 153*da0073e9SAndroid Build Coastguard Worker ) = (torch.rand(4, 4) for _ in range(3)) 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker prof = torch.profiler.profile( 156*da0073e9SAndroid Build Coastguard Worker schedule=torch.profiler.schedule(wait=2, warmup=2, active=2, repeat=1) 157*da0073e9SAndroid Build Coastguard Worker ) 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 160*da0073e9SAndroid Build Coastguard Worker opt_fn(x, y, z) 161*da0073e9SAndroid Build Coastguard Worker prof.step() 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 164*da0073e9SAndroid Build Coastguard Worker any(e.name == "TorchDynamo Cache Lookup" for e in prof.events()) 165*da0073e9SAndroid Build Coastguard Worker ) 166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker def test_profiler_dynamo_compiled_region(self): 168*da0073e9SAndroid Build Coastguard Worker torch._logging.set_logs(dynamo=logging.INFO) 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 171*da0073e9SAndroid Build Coastguard Worker r = y.sum(dim=1) 172*da0073e9SAndroid Build Coastguard Worker print(r.shape) 173*da0073e9SAndroid Build Coastguard Worker return x * r 174*da0073e9SAndroid Build Coastguard Worker 175*da0073e9SAndroid Build Coastguard Worker fn_c = torch.compile(fn) 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Worker with torch.profiler.profile(record_shapes=True) as prof: 178*da0073e9SAndroid Build Coastguard Worker fn_c( 179*da0073e9SAndroid Build Coastguard Worker torch.randn(10), 180*da0073e9SAndroid Build Coastguard Worker torch.randn(10, 10), 181*da0073e9SAndroid Build Coastguard Worker ) 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Worker fn_c( 184*da0073e9SAndroid Build Coastguard Worker torch.randn(10), 185*da0073e9SAndroid Build Coastguard Worker torch.randn(10, 15), 186*da0073e9SAndroid Build Coastguard Worker ) 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker for e in prof.events(): 189*da0073e9SAndroid Build Coastguard Worker if e.name == "Torch-Compiled Region": 190*da0073e9SAndroid Build Coastguard Worker print(e.kwinputs) 191*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 192*da0073e9SAndroid Build Coastguard Worker any( 193*da0073e9SAndroid Build Coastguard Worker e.name == "Torch-Compiled Region" and e.kwinputs["context"] == "0/0_1" 194*da0073e9SAndroid Build Coastguard Worker for e in prof.events() 195*da0073e9SAndroid Build Coastguard Worker ) 196*da0073e9SAndroid Build Coastguard Worker ) 197*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 198*da0073e9SAndroid Build Coastguard Worker any( 199*da0073e9SAndroid Build Coastguard Worker e.name == "Torch-Compiled Region" and e.kwinputs["context"] == "1/0" 200*da0073e9SAndroid Build Coastguard Worker for e in prof.events() 201*da0073e9SAndroid Build Coastguard Worker ) 202*da0073e9SAndroid Build Coastguard Worker ) 203*da0073e9SAndroid Build Coastguard Worker 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 206*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.test_case import run_tests 207*da0073e9SAndroid Build Coastguard Worker 208*da0073e9SAndroid Build Coastguard Worker run_tests() 209