1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch 3*da0073e9SAndroid Build Coastguard Worker 4*da0073e9SAndroid Build Coastguard Workerimport torch 5*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case 6*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerclass RecompileTests(torch._dynamo.test_case.TestCase): 10*da0073e9SAndroid Build Coastguard Worker def test_automatic_dynamic_reduce_recompiles(self): 11*da0073e9SAndroid Build Coastguard Worker # Test the counterfactual, lots of recompiles without this config 12*da0073e9SAndroid Build Coastguard Worker def foo(x, y): 13*da0073e9SAndroid Build Coastguard Worker return x * y 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Worker def run_foo_6_times_and_count_recompiles(dynamic=None): 16*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Worker x = torch.randn([2]) 19*da0073e9SAndroid Build Coastguard Worker y = torch.randn([2]) 20*da0073e9SAndroid Build Coastguard Worker opt = torch._dynamo.optimize(cnt, dynamic=dynamic)(foo) 21*da0073e9SAndroid Build Coastguard Worker opt(x, y) 22*da0073e9SAndroid Build Coastguard Worker x = torch.randn([3]) 23*da0073e9SAndroid Build Coastguard Worker y = torch.randn([3]) 24*da0073e9SAndroid Build Coastguard Worker opt(x, y) 25*da0073e9SAndroid Build Coastguard Worker x = torch.randn([4]) 26*da0073e9SAndroid Build Coastguard Worker y = torch.randn([4]) 27*da0073e9SAndroid Build Coastguard Worker opt(x, y) 28*da0073e9SAndroid Build Coastguard Worker opt(x, y) 29*da0073e9SAndroid Build Coastguard Worker x = torch.randn([5]) 30*da0073e9SAndroid Build Coastguard Worker y = torch.randn([5]) 31*da0073e9SAndroid Build Coastguard Worker opt(x, y) 32*da0073e9SAndroid Build Coastguard Worker opt(x, y) 33*da0073e9SAndroid Build Coastguard Worker x = torch.randn([6]) 34*da0073e9SAndroid Build Coastguard Worker y = torch.randn([6]) 35*da0073e9SAndroid Build Coastguard Worker opt(x, y) 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Worker return cnt 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) 40*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "assume_static_by_default", True) 41*da0073e9SAndroid Build Coastguard Worker def run_without_automatic(): 42*da0073e9SAndroid Build Coastguard Worker return run_foo_6_times_and_count_recompiles() 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) 45*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "assume_static_by_default", True) 46*da0073e9SAndroid Build Coastguard Worker def run_with_automatic(): 47*da0073e9SAndroid Build Coastguard Worker return run_foo_6_times_and_count_recompiles() 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker without = run_without_automatic() 50*da0073e9SAndroid Build Coastguard Worker self.assertEqual(without.frame_count, 5) 51*da0073e9SAndroid Build Coastguard Worker self.assertEqual(without.op_count, 5) 52*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 53*da0073e9SAndroid Build Coastguard Worker without = run_foo_6_times_and_count_recompiles(dynamic=False) 54*da0073e9SAndroid Build Coastguard Worker self.assertEqual(without.frame_count, 5) 55*da0073e9SAndroid Build Coastguard Worker self.assertEqual(without.op_count, 5) 56*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 57*da0073e9SAndroid Build Coastguard Worker with_automatic = run_with_automatic() 58*da0073e9SAndroid Build Coastguard Worker self.assertEqual(with_automatic.frame_count, 2) 59*da0073e9SAndroid Build Coastguard Worker self.assertEqual(with_automatic.op_count, 2) 60*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 61*da0073e9SAndroid Build Coastguard Worker with_automatic = run_foo_6_times_and_count_recompiles(dynamic=None) 62*da0073e9SAndroid Build Coastguard Worker self.assertEqual(with_automatic.frame_count, 2) 63*da0073e9SAndroid Build Coastguard Worker self.assertEqual(with_automatic.op_count, 2) 64*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 65*da0073e9SAndroid Build Coastguard Worker with_dynamic = run_foo_6_times_and_count_recompiles(dynamic=True) 66*da0073e9SAndroid Build Coastguard Worker self.assertEqual(with_dynamic.frame_count, 1) 67*da0073e9SAndroid Build Coastguard Worker self.assertEqual(with_dynamic.op_count, 1) 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "assume_static_by_default", True) 70*da0073e9SAndroid Build Coastguard Worker def test_recompiles_true_false_flop(self): 71*da0073e9SAndroid Build Coastguard Worker # Test the counterfactual, lots of recompiles without this config 72*da0073e9SAndroid Build Coastguard Worker def foo(x, y): 73*da0073e9SAndroid Build Coastguard Worker if x: 74*da0073e9SAndroid Build Coastguard Worker return y * 2 75*da0073e9SAndroid Build Coastguard Worker else: 76*da0073e9SAndroid Build Coastguard Worker return y * y 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker def run_foo_6_times_and_count_recompiles(): 79*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker opt = torch._dynamo.optimize(cnt, nopython=True)(foo) 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker x = True 84*da0073e9SAndroid Build Coastguard Worker y = torch.randn([2]) 85*da0073e9SAndroid Build Coastguard Worker opt(x, y) 86*da0073e9SAndroid Build Coastguard Worker x = False 87*da0073e9SAndroid Build Coastguard Worker y = torch.randn([2]) 88*da0073e9SAndroid Build Coastguard Worker opt(x, y) 89*da0073e9SAndroid Build Coastguard Worker x = True 90*da0073e9SAndroid Build Coastguard Worker y = torch.randn([3]) 91*da0073e9SAndroid Build Coastguard Worker opt(x, y) 92*da0073e9SAndroid Build Coastguard Worker x = True 93*da0073e9SAndroid Build Coastguard Worker y = torch.randn([4]) 94*da0073e9SAndroid Build Coastguard Worker opt(x, y) 95*da0073e9SAndroid Build Coastguard Worker x = True 96*da0073e9SAndroid Build Coastguard Worker y = torch.randn([5]) 97*da0073e9SAndroid Build Coastguard Worker opt(x, y) 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Worker return cnt 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) 102*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "assume_static_by_default", True) 103*da0073e9SAndroid Build Coastguard Worker def run_without_automatic(): 104*da0073e9SAndroid Build Coastguard Worker return run_foo_6_times_and_count_recompiles() 105*da0073e9SAndroid Build Coastguard Worker 106*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) 107*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "assume_static_by_default", True) 108*da0073e9SAndroid Build Coastguard Worker def run_with_automatic(): 109*da0073e9SAndroid Build Coastguard Worker return run_foo_6_times_and_count_recompiles() 110*da0073e9SAndroid Build Coastguard Worker 111*da0073e9SAndroid Build Coastguard Worker without = run_without_automatic() 112*da0073e9SAndroid Build Coastguard Worker self.assertEqual(without.frame_count, 5) 113*da0073e9SAndroid Build Coastguard Worker self.assertEqual(without.op_count, 5) 114*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 115*da0073e9SAndroid Build Coastguard Worker with_automatic = run_with_automatic() 116*da0073e9SAndroid Build Coastguard Worker self.assertEqual(with_automatic.frame_count, 3) 117*da0073e9SAndroid Build Coastguard Worker self.assertEqual(with_automatic.op_count, 3) 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker def test_automatic_dynamic_tensor_scalar_change(self): 120*da0073e9SAndroid Build Coastguard Worker # Test the counterfactual, lots of recompiles without this config 121*da0073e9SAndroid Build Coastguard Worker def foo(x, y): 122*da0073e9SAndroid Build Coastguard Worker return x * y 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker def run_foo_6_times_and_count_recompiles_swap_types(): 125*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 126*da0073e9SAndroid Build Coastguard Worker 127*da0073e9SAndroid Build Coastguard Worker x = torch.randn([2]) 128*da0073e9SAndroid Build Coastguard Worker y = torch.randn([2]) 129*da0073e9SAndroid Build Coastguard Worker opt = torch._dynamo.optimize(cnt)(foo) 130*da0073e9SAndroid Build Coastguard Worker opt(x, y) 131*da0073e9SAndroid Build Coastguard Worker x = torch.randn([3]) 132*da0073e9SAndroid Build Coastguard Worker y = 3 133*da0073e9SAndroid Build Coastguard Worker opt(x, y) 134*da0073e9SAndroid Build Coastguard Worker x = torch.randn([4]) 135*da0073e9SAndroid Build Coastguard Worker y = torch.randn([4]) 136*da0073e9SAndroid Build Coastguard Worker opt(x, y) 137*da0073e9SAndroid Build Coastguard Worker opt(x, y) 138*da0073e9SAndroid Build Coastguard Worker x = torch.randn([5]) 139*da0073e9SAndroid Build Coastguard Worker y = 4 140*da0073e9SAndroid Build Coastguard Worker opt(x, y) 141*da0073e9SAndroid Build Coastguard Worker opt(x, y) 142*da0073e9SAndroid Build Coastguard Worker x = torch.randn([6]) 143*da0073e9SAndroid Build Coastguard Worker y = torch.randn([6]) 144*da0073e9SAndroid Build Coastguard Worker opt(x, y) 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Worker return cnt 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) 149*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "assume_static_by_default", True) 150*da0073e9SAndroid Build Coastguard Worker def run_without_automatic(): 151*da0073e9SAndroid Build Coastguard Worker return run_foo_6_times_and_count_recompiles_swap_types() 152*da0073e9SAndroid Build Coastguard Worker 153*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) 154*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "assume_static_by_default", True) 155*da0073e9SAndroid Build Coastguard Worker def run_with_automatic(): 156*da0073e9SAndroid Build Coastguard Worker return run_foo_6_times_and_count_recompiles_swap_types() 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Worker without = run_without_automatic() 159*da0073e9SAndroid Build Coastguard Worker self.assertEqual(without.frame_count, 5) 160*da0073e9SAndroid Build Coastguard Worker self.assertEqual(without.op_count, 5) 161*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 162*da0073e9SAndroid Build Coastguard Worker with_automatic = run_with_automatic() 163*da0073e9SAndroid Build Coastguard Worker self.assertEqual(with_automatic.frame_count, 3) 164*da0073e9SAndroid Build Coastguard Worker self.assertEqual(with_automatic.op_count, 3) 165*da0073e9SAndroid Build Coastguard Worker 166*da0073e9SAndroid Build Coastguard Worker def test_aliasing_guard_failures(self): 167*da0073e9SAndroid Build Coastguard Worker def foo(a, b, c): 168*da0073e9SAndroid Build Coastguard Worker a.add_(b) 169*da0073e9SAndroid Build Coastguard Worker return c + 1 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 172*da0073e9SAndroid Build Coastguard Worker compiled_foo = torch._dynamo.optimize(cnt, nopython=True)(foo) 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker x = torch.randn([3]) 175*da0073e9SAndroid Build Coastguard Worker y = torch.randn([3]) 176*da0073e9SAndroid Build Coastguard Worker z = torch.randn([3]) 177*da0073e9SAndroid Build Coastguard Worker cmp_result = compiled_foo( 178*da0073e9SAndroid Build Coastguard Worker x.clone().detach(), y.clone().detach(), z.clone().detach() 179*da0073e9SAndroid Build Coastguard Worker ) 180*da0073e9SAndroid Build Coastguard Worker eager_result = foo(x.clone().detach(), y.clone().detach(), z.clone().detach()) 181*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cmp_result, eager_result) 182*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 183*da0073e9SAndroid Build Coastguard Worker 184*da0073e9SAndroid Build Coastguard Worker cmp_result = compiled_foo( 185*da0073e9SAndroid Build Coastguard Worker z.clone().detach(), y.clone().detach(), x.clone().detach() 186*da0073e9SAndroid Build Coastguard Worker ) 187*da0073e9SAndroid Build Coastguard Worker eager_result = foo(z.clone().detach(), y.clone().detach(), x.clone().detach()) 188*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cmp_result, eager_result) 189*da0073e9SAndroid Build Coastguard Worker # No recompile, alias preserved 190*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 191*da0073e9SAndroid Build Coastguard Worker 192*da0073e9SAndroid Build Coastguard Worker x_clone = x.clone().detach() 193*da0073e9SAndroid Build Coastguard Worker cmp_result = compiled_foo(x_clone, y.clone().detach(), x_clone) 194*da0073e9SAndroid Build Coastguard Worker x_clone = x.clone().detach() 195*da0073e9SAndroid Build Coastguard Worker eager_result = compiled_foo(x_clone, y.clone().detach(), x_clone) 196*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cmp_result, eager_result) 197*da0073e9SAndroid Build Coastguard Worker # Recompile, alias changed 198*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 2) 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Worker def test_aliasing_guard_failures_with_globals(self): 201*da0073e9SAndroid Build Coastguard Worker g1 = torch.randn([3]) 202*da0073e9SAndroid Build Coastguard Worker g2 = torch.randn([3]) 203*da0073e9SAndroid Build Coastguard Worker 204*da0073e9SAndroid Build Coastguard Worker def foo(a): 205*da0073e9SAndroid Build Coastguard Worker a.add_(g1) 206*da0073e9SAndroid Build Coastguard Worker return g2 + 1 207*da0073e9SAndroid Build Coastguard Worker 208*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 209*da0073e9SAndroid Build Coastguard Worker compiled_foo = torch._dynamo.optimize(cnt, nopython=True)(foo) 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker z = torch.randn([3]) 212*da0073e9SAndroid Build Coastguard Worker cmp_result = compiled_foo(z.clone().detach()) 213*da0073e9SAndroid Build Coastguard Worker eager_result = foo(z.clone().detach()) 214*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cmp_result, eager_result) 215*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 216*da0073e9SAndroid Build Coastguard Worker 217*da0073e9SAndroid Build Coastguard Worker g1 = g1.clone().detach() 218*da0073e9SAndroid Build Coastguard Worker cmp_result = compiled_foo(g1) 219*da0073e9SAndroid Build Coastguard Worker g1 = g1.clone().detach() 220*da0073e9SAndroid Build Coastguard Worker eager_result = compiled_foo(g1) 221*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cmp_result, eager_result) 222*da0073e9SAndroid Build Coastguard Worker # Recompile, alias changed 223*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 2) 224*da0073e9SAndroid Build Coastguard Worker 225*da0073e9SAndroid Build Coastguard Worker def test_dynamic_shape_parameter_recompile(self): 226*da0073e9SAndroid Build Coastguard Worker # Test the matrix multiplication with Parameters. 227*da0073e9SAndroid Build Coastguard Worker # Without the config assume_parameters_shapes_static_by_default, 228*da0073e9SAndroid Build Coastguard Worker # the torch.nn.Parameter shapes are assumed to be static which leads to recompilation 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker w = torch.nn.Parameter(torch.randn(3, 2)) 231*da0073e9SAndroid Build Coastguard Worker 232*da0073e9SAndroid Build Coastguard Worker def foo(x): 233*da0073e9SAndroid Build Coastguard Worker return x @ w 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Worker def run_foo_6_times_and_count_recompiles(): 236*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 237*da0073e9SAndroid Build Coastguard Worker 238*da0073e9SAndroid Build Coastguard Worker opt = torch._dynamo.optimize(cnt, nopython=True)(foo) 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker x = torch.nn.Parameter(torch.randn(1, 3)) 241*da0073e9SAndroid Build Coastguard Worker opt(x) 242*da0073e9SAndroid Build Coastguard Worker x = torch.nn.Parameter(torch.randn(10, 3)) 243*da0073e9SAndroid Build Coastguard Worker opt(x) 244*da0073e9SAndroid Build Coastguard Worker x = torch.nn.Parameter(torch.randn(11, 3)) 245*da0073e9SAndroid Build Coastguard Worker opt(x) 246*da0073e9SAndroid Build Coastguard Worker x = torch.nn.Parameter(torch.randn(15, 3)) 247*da0073e9SAndroid Build Coastguard Worker opt(x) 248*da0073e9SAndroid Build Coastguard Worker x = torch.nn.Parameter(torch.randn(15, 3)) 249*da0073e9SAndroid Build Coastguard Worker opt(x) 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker return cnt 252*da0073e9SAndroid Build Coastguard Worker 253*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "force_parameter_static_shapes", True) 254*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) 255*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "assume_static_by_default", True) 256*da0073e9SAndroid Build Coastguard Worker def run_static_comp_default_param(): 257*da0073e9SAndroid Build Coastguard Worker return run_foo_6_times_and_count_recompiles() 258*da0073e9SAndroid Build Coastguard Worker 259*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "force_parameter_static_shapes", True) 260*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) 261*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "assume_static_by_default", True) 262*da0073e9SAndroid Build Coastguard Worker def run_dynamic_comp_default_param(): 263*da0073e9SAndroid Build Coastguard Worker return run_foo_6_times_and_count_recompiles() 264*da0073e9SAndroid Build Coastguard Worker 265*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "force_parameter_static_shapes", False) 266*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) 267*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "assume_static_by_default", True) 268*da0073e9SAndroid Build Coastguard Worker def run_static_comp_dynamic_param(): 269*da0073e9SAndroid Build Coastguard Worker return run_foo_6_times_and_count_recompiles() 270*da0073e9SAndroid Build Coastguard Worker 271*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "force_parameter_static_shapes", False) 272*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) 273*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "assume_static_by_default", True) 274*da0073e9SAndroid Build Coastguard Worker def run_dynamic_comp_dynamic_param(): 275*da0073e9SAndroid Build Coastguard Worker return run_foo_6_times_and_count_recompiles() 276*da0073e9SAndroid Build Coastguard Worker 277*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 278*da0073e9SAndroid Build Coastguard Worker static_comp_default_param = run_static_comp_default_param() 279*da0073e9SAndroid Build Coastguard Worker self.assertEqual(static_comp_default_param.frame_count, 4) 280*da0073e9SAndroid Build Coastguard Worker self.assertEqual(static_comp_default_param.op_count, 4) 281*da0073e9SAndroid Build Coastguard Worker 282*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 283*da0073e9SAndroid Build Coastguard Worker dynamic_comp_default_param = run_dynamic_comp_default_param() 284*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dynamic_comp_default_param.frame_count, 4) 285*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dynamic_comp_default_param.op_count, 4) 286*da0073e9SAndroid Build Coastguard Worker 287*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 288*da0073e9SAndroid Build Coastguard Worker static_comp_dynamic_param = run_static_comp_dynamic_param() 289*da0073e9SAndroid Build Coastguard Worker self.assertEqual(static_comp_dynamic_param.frame_count, 4) 290*da0073e9SAndroid Build Coastguard Worker self.assertEqual(static_comp_dynamic_param.op_count, 4) 291*da0073e9SAndroid Build Coastguard Worker 292*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 293*da0073e9SAndroid Build Coastguard Worker dynamic_comp_dynamic_param = run_dynamic_comp_dynamic_param() 294*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dynamic_comp_dynamic_param.frame_count, 2) 295*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dynamic_comp_dynamic_param.op_count, 2) 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker def test_simple_module_recompile(self): 298*da0073e9SAndroid Build Coastguard Worker class SimpleDropout(torch.nn.Module): 299*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 300*da0073e9SAndroid Build Coastguard Worker super().__init__() 301*da0073e9SAndroid Build Coastguard Worker self.dropout = torch.nn.Dropout(0.5) 302*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(10, 1) 303*da0073e9SAndroid Build Coastguard Worker 304*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 305*da0073e9SAndroid Build Coastguard Worker return self.dropout(self.linear(x)) 306*da0073e9SAndroid Build Coastguard Worker 307*da0073e9SAndroid Build Coastguard Worker model = SimpleDropout() 308*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10) 309*da0073e9SAndroid Build Coastguard Worker counter = torch._dynamo.testing.CompileCounter() 310*da0073e9SAndroid Build Coastguard Worker model = torch.compile(model, backend=counter, fullgraph=True) 311*da0073e9SAndroid Build Coastguard Worker for _ in range(20): 312*da0073e9SAndroid Build Coastguard Worker model.eval() 313*da0073e9SAndroid Build Coastguard Worker model(x) 314*da0073e9SAndroid Build Coastguard Worker model.train() 315*da0073e9SAndroid Build Coastguard Worker model(x) 316*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 2) 317*da0073e9SAndroid Build Coastguard Worker 318*da0073e9SAndroid Build Coastguard Worker @patch.object(torch._dynamo.config, "cache_size_limit", 2) 319*da0073e9SAndroid Build Coastguard Worker def test_no_recursive_compile_after_cache_limit_hit(self): 320*da0073e9SAndroid Build Coastguard Worker def f(x, n): 321*da0073e9SAndroid Build Coastguard Worker x = x + n 322*da0073e9SAndroid Build Coastguard Worker return g(x, n) 323*da0073e9SAndroid Build Coastguard Worker 324*da0073e9SAndroid Build Coastguard Worker def g(x, n): 325*da0073e9SAndroid Build Coastguard Worker x = x + n 326*da0073e9SAndroid Build Coastguard Worker return h(x, n) 327*da0073e9SAndroid Build Coastguard Worker 328*da0073e9SAndroid Build Coastguard Worker def h(x, n): 329*da0073e9SAndroid Build Coastguard Worker return x + n 330*da0073e9SAndroid Build Coastguard Worker 331*da0073e9SAndroid Build Coastguard Worker counter = torch._dynamo.testing.CompileCounter() 332*da0073e9SAndroid Build Coastguard Worker opt_f = torch.compile(f, backend=counter, dynamic=False) 333*da0073e9SAndroid Build Coastguard Worker for i in range(10): 334*da0073e9SAndroid Build Coastguard Worker opt_f(torch.ones(3), i) 335*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 2) 336*da0073e9SAndroid Build Coastguard Worker 337*da0073e9SAndroid Build Coastguard Worker 338*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 339*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.test_case import run_tests 340*da0073e9SAndroid Build Coastguard Worker 341*da0073e9SAndroid Build Coastguard Worker run_tests() 342