# Owner(s): ["module: dynamo"]
from unittest.mock import patch

import torch
import torch._dynamo.test_case
import torch._dynamo.testing


class RecompileTests(torch._dynamo.test_case.TestCase):
    def test_automatic_dynamic_reduce_recompiles(self):
        # Test the counterfactual, lots of recompiles without this config
        def foo(x, y):
            return x * y

        def run_foo_6_times_and_count_recompiles(dynamic=None):
            cnt = torch._dynamo.testing.CompileCounter()

            x = torch.randn([2])
            y = torch.randn([2])
            opt = torch._dynamo.optimize(cnt, dynamic=dynamic)(foo)
            opt(x, y)
            x = torch.randn([3])
            y = torch.randn([3])
            opt(x, y)
            x = torch.randn([4])
            y = torch.randn([4])
            opt(x, y)
            opt(x, y)
            x = torch.randn([5])
            y = torch.randn([5])
            opt(x, y)
            opt(x, y)
            x = torch.randn([6])
            y = torch.randn([6])
            opt(x, y)

            return cnt

        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_without_automatic():
            return run_foo_6_times_and_count_recompiles()

        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_with_automatic():
            return run_foo_6_times_and_count_recompiles()

        without = run_without_automatic()
        self.assertEqual(without.frame_count, 5)
        self.assertEqual(without.op_count, 5)
        torch._dynamo.reset()
        without = run_foo_6_times_and_count_recompiles(dynamic=False)
        self.assertEqual(without.frame_count, 5)
        self.assertEqual(without.op_count, 5)
        torch._dynamo.reset()
        with_automatic = run_with_automatic()
        self.assertEqual(with_automatic.frame_count, 2)
        self.assertEqual(with_automatic.op_count, 2)
        torch._dynamo.reset()
        with_automatic = run_foo_6_times_and_count_recompiles(dynamic=None)
        self.assertEqual(with_automatic.frame_count, 2)
        self.assertEqual(with_automatic.op_count, 2)
        torch._dynamo.reset()
        with_dynamic = run_foo_6_times_and_count_recompiles(dynamic=True)
        self.assertEqual(with_dynamic.frame_count, 1)
        self.assertEqual(with_dynamic.op_count, 1)

    @patch.object(torch._dynamo.config, "assume_static_by_default", True)
    def test_recompiles_true_false_flop(self):
        # Test the counterfactual, lots of recompiles without this config
        def foo(x, y):
            if x:
                return y * 2
            else:
                return y * y

        def run_foo_6_times_and_count_recompiles():
            cnt = torch._dynamo.testing.CompileCounter()

            opt = torch._dynamo.optimize(cnt, nopython=True)(foo)

            x = True
            y = torch.randn([2])
            opt(x, y)
            x = False
            y = torch.randn([2])
            opt(x, y)
            x = True
            y = torch.randn([3])
            opt(x, y)
            x = True
            y = torch.randn([4])
            opt(x, y)
            x = True
            y = torch.randn([5])
            opt(x, y)

            return cnt

        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_without_automatic():
            return run_foo_6_times_and_count_recompiles()

        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_with_automatic():
            return run_foo_6_times_and_count_recompiles()

        without = run_without_automatic()
        self.assertEqual(without.frame_count, 5)
        self.assertEqual(without.op_count, 5)
        torch._dynamo.reset()
        with_automatic = run_with_automatic()
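        # With automatic dynamic shapes, three graphs are expected: one per
        # value of the bool x, plus one recompile when y's size first changes
        # and its dimension is marked dynamic.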
        self.assertEqual(with_automatic.frame_count, 3)
        self.assertEqual(with_automatic.op_count, 3)

    def test_automatic_dynamic_tensor_scalar_change(self):
        # Test the counterfactual, lots of recompiles without this config
        def foo(x, y):
            return x * y

        def run_foo_6_times_and_count_recompiles_swap_types():
            cnt = torch._dynamo.testing.CompileCounter()

            x = torch.randn([2])
            y = torch.randn([2])
            opt = torch._dynamo.optimize(cnt)(foo)
            opt(x, y)
            x = torch.randn([3])
            y = 3
            opt(x, y)
            x = torch.randn([4])
            y = torch.randn([4])
            opt(x, y)
            opt(x, y)
            x = torch.randn([5])
            y = 4
            opt(x, y)
            opt(x, y)
            x = torch.randn([6])
            y = torch.randn([6])
            opt(x, y)

            return cnt

        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_without_automatic():
            return run_foo_6_times_and_count_recompiles_swap_types()

        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_with_automatic():
            return run_foo_6_times_and_count_recompiles_swap_types()

        without = run_without_automatic()
        self.assertEqual(without.frame_count, 5)
        self.assertEqual(without.op_count, 5)
        torch._dynamo.reset()
        with_automatic = run_with_automatic()
        self.assertEqual(with_automatic.frame_count, 3)
        self.assertEqual(with_automatic.op_count, 3)

    def test_aliasing_guard_failures(self):
        def foo(a, b, c):
            a.add_(b)
            return c + 1

        cnt = torch._dynamo.testing.CompileCounter()
        compiled_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)

        x = torch.randn([3])
        y = torch.randn([3])
        z = torch.randn([3])
        cmp_result = compiled_foo(
            x.clone().detach(), y.clone().detach(), z.clone().detach()
        )
        eager_result = foo(x.clone().detach(), y.clone().detach(), z.clone().detach())
        self.assertEqual(cmp_result, eager_result)
        self.assertEqual(cnt.frame_count, 1)

        cmp_result = compiled_foo(
            z.clone().detach(), y.clone().detach(), x.clone().detach()
        )
        eager_result = foo(z.clone().detach(), y.clone().detach(), x.clone().detach())
        self.assertEqual(cmp_result, eager_result)
        # No recompile, alias preserved
        self.assertEqual(cnt.frame_count, 1)

        x_clone = x.clone().detach()
        cmp_result = compiled_foo(x_clone, y.clone().detach(), x_clone)
        x_clone = x.clone().detach()
        eager_result = foo(x_clone, y.clone().detach(), x_clone)
        self.assertEqual(cmp_result, eager_result)
        # Recompile, alias changed
        self.assertEqual(cnt.frame_count, 2)

    def test_aliasing_guard_failures_with_globals(self):
        g1 = torch.randn([3])
        g2 = torch.randn([3])

        def foo(a):
            a.add_(g1)
            return g2 + 1

        cnt = torch._dynamo.testing.CompileCounter()
        compiled_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)

        z = torch.randn([3])
        cmp_result = compiled_foo(z.clone().detach())
        eager_result = foo(z.clone().detach())
        self.assertEqual(cmp_result, eager_result)
        self.assertEqual(cnt.frame_count, 1)

        g1 = g1.clone().detach()
        cmp_result = compiled_foo(g1)
        g1 = g1.clone().detach()
        eager_result = foo(g1)
        self.assertEqual(cmp_result, eager_result)
        # Recompile, alias changed
        self.assertEqual(cnt.frame_count, 2)

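    # The next test sweeps all four combinations of force_parameter_static_shapes
    # and automatic_dynamic_shapes. Per its assertions, recompiles only drop when
    # force_parameter_static_shapes is disabled and automatic_dynamic_shapes is
    # enabled, letting the Parameter input take on a dynamic dimension.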
    def test_dynamic_shape_parameter_recompile(self):
        # Test matrix multiplication with Parameters.
        # When force_parameter_static_shapes is set (the default), torch.nn.Parameter
        # shapes are assumed to be static, which leads to a recompile on every new
        # input shape.

        w = torch.nn.Parameter(torch.randn(3, 2))

        def foo(x):
            return x @ w

        def run_foo_6_times_and_count_recompiles():
            cnt = torch._dynamo.testing.CompileCounter()

            opt = torch._dynamo.optimize(cnt, nopython=True)(foo)

            x = torch.nn.Parameter(torch.randn(1, 3))
            opt(x)
            x = torch.nn.Parameter(torch.randn(10, 3))
            opt(x)
            x = torch.nn.Parameter(torch.randn(11, 3))
            opt(x)
            x = torch.nn.Parameter(torch.randn(15, 3))
            opt(x)
            x = torch.nn.Parameter(torch.randn(15, 3))
            opt(x)

            return cnt

        @patch.object(torch._dynamo.config, "force_parameter_static_shapes", True)
        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_static_comp_default_param():
            return run_foo_6_times_and_count_recompiles()

        @patch.object(torch._dynamo.config, "force_parameter_static_shapes", True)
        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_dynamic_comp_default_param():
            return run_foo_6_times_and_count_recompiles()

        @patch.object(torch._dynamo.config, "force_parameter_static_shapes", False)
        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_static_comp_dynamic_param():
            return run_foo_6_times_and_count_recompiles()

        @patch.object(torch._dynamo.config, "force_parameter_static_shapes", False)
        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
        def run_dynamic_comp_dynamic_param():
            return run_foo_6_times_and_count_recompiles()

        torch._dynamo.reset()
        static_comp_default_param = run_static_comp_default_param()
        self.assertEqual(static_comp_default_param.frame_count, 4)
        self.assertEqual(static_comp_default_param.op_count, 4)

        torch._dynamo.reset()
        dynamic_comp_default_param = run_dynamic_comp_default_param()
        self.assertEqual(dynamic_comp_default_param.frame_count, 4)
        self.assertEqual(dynamic_comp_default_param.op_count, 4)

        torch._dynamo.reset()
        static_comp_dynamic_param = run_static_comp_dynamic_param()
        self.assertEqual(static_comp_dynamic_param.frame_count, 4)
        self.assertEqual(static_comp_dynamic_param.op_count, 4)

        torch._dynamo.reset()
        dynamic_comp_dynamic_param = run_dynamic_comp_dynamic_param()
        self.assertEqual(dynamic_comp_dynamic_param.frame_count, 2)
        self.assertEqual(dynamic_comp_dynamic_param.op_count, 2)

    def test_simple_module_recompile(self):
        class SimpleDropout(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dropout = torch.nn.Dropout(0.5)
                self.linear = torch.nn.Linear(10, 1)

            def forward(self, x):
                return self.dropout(self.linear(x))

        model = SimpleDropout()
        x = torch.randn(10)
        counter = torch._dynamo.testing.CompileCounter()
        model = torch.compile(model, backend=counter, fullgraph=True)
        for _ in range(20):
            model.eval()
            model(x)
            model.train()
            model(x)
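        # Two graphs are expected despite 40 calls: one for eval mode and one
        # for train mode, since Dropout's behavior depends on self.training.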
        self.assertEqual(counter.frame_count, 2)

    @patch.object(torch._dynamo.config, "cache_size_limit", 2)
    def test_no_recursive_compile_after_cache_limit_hit(self):
        # Once f exhausts the cache size limit, Dynamo should fall back to
        # eager for the whole call chain rather than separately compiling g
        # and h, so only the two compilations of f are counted.
        def f(x, n):
            x = x + n
            return g(x, n)

        def g(x, n):
            x = x + n
            return h(x, n)

        def h(x, n):
            return x + n

        counter = torch._dynamo.testing.CompileCounter()
        opt_f = torch.compile(f, backend=counter, dynamic=False)
        for i in range(10):
            opt_f(torch.ones(3), i)
        self.assertEqual(counter.frame_count, 2)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()