# Owner(s): ["module: dynamo"] from unittest.mock import patch import torch import torch._dynamo.test_case import torch._dynamo.testing class RecompileTests(torch._dynamo.test_case.TestCase): def test_automatic_dynamic_reduce_recompiles(self): # Test the counterfactual, lots of recompiles without this config def foo(x, y): return x * y def run_foo_6_times_and_count_recompiles(dynamic=None): cnt = torch._dynamo.testing.CompileCounter() x = torch.randn([2]) y = torch.randn([2]) opt = torch._dynamo.optimize(cnt, dynamic=dynamic)(foo) opt(x, y) x = torch.randn([3]) y = torch.randn([3]) opt(x, y) x = torch.randn([4]) y = torch.randn([4]) opt(x, y) opt(x, y) x = torch.randn([5]) y = torch.randn([5]) opt(x, y) opt(x, y) x = torch.randn([6]) y = torch.randn([6]) opt(x, y) return cnt @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_without_automatic(): return run_foo_6_times_and_count_recompiles() @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_with_automatic(): return run_foo_6_times_and_count_recompiles() without = run_without_automatic() self.assertEqual(without.frame_count, 5) self.assertEqual(without.op_count, 5) torch._dynamo.reset() without = run_foo_6_times_and_count_recompiles(dynamic=False) self.assertEqual(without.frame_count, 5) self.assertEqual(without.op_count, 5) torch._dynamo.reset() with_automatic = run_with_automatic() self.assertEqual(with_automatic.frame_count, 2) self.assertEqual(with_automatic.op_count, 2) torch._dynamo.reset() with_automatic = run_foo_6_times_and_count_recompiles(dynamic=None) self.assertEqual(with_automatic.frame_count, 2) self.assertEqual(with_automatic.op_count, 2) torch._dynamo.reset() with_dynamic = run_foo_6_times_and_count_recompiles(dynamic=True) self.assertEqual(with_dynamic.frame_count, 1) self.assertEqual(with_dynamic.op_count, 1) @patch.object(torch._dynamo.config, "assume_static_by_default", True) def test_recompiles_true_false_flop(self): # Test the counterfactual, lots of recompiles without this config def foo(x, y): if x: return y * 2 else: return y * y def run_foo_6_times_and_count_recompiles(): cnt = torch._dynamo.testing.CompileCounter() opt = torch._dynamo.optimize(cnt, nopython=True)(foo) x = True y = torch.randn([2]) opt(x, y) x = False y = torch.randn([2]) opt(x, y) x = True y = torch.randn([3]) opt(x, y) x = True y = torch.randn([4]) opt(x, y) x = True y = torch.randn([5]) opt(x, y) return cnt @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_without_automatic(): return run_foo_6_times_and_count_recompiles() @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_with_automatic(): return run_foo_6_times_and_count_recompiles() without = run_without_automatic() self.assertEqual(without.frame_count, 5) self.assertEqual(without.op_count, 5) torch._dynamo.reset() with_automatic = run_with_automatic() self.assertEqual(with_automatic.frame_count, 3) self.assertEqual(with_automatic.op_count, 3) def test_automatic_dynamic_tensor_scalar_change(self): # Test the counterfactual, lots of recompiles without this config def foo(x, y): return x * y def run_foo_6_times_and_count_recompiles_swap_types(): cnt = torch._dynamo.testing.CompileCounter() x = torch.randn([2]) y = torch.randn([2]) opt = torch._dynamo.optimize(cnt)(foo) opt(x, y) x = torch.randn([3]) y = 3 opt(x, y) x = torch.randn([4]) y = torch.randn([4]) opt(x, y) opt(x, y) x = torch.randn([5]) y = 4 opt(x, y) opt(x, y) x = torch.randn([6]) y = torch.randn([6]) opt(x, y) return cnt @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_without_automatic(): return run_foo_6_times_and_count_recompiles_swap_types() @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_with_automatic(): return run_foo_6_times_and_count_recompiles_swap_types() without = run_without_automatic() self.assertEqual(without.frame_count, 5) self.assertEqual(without.op_count, 5) torch._dynamo.reset() with_automatic = run_with_automatic() self.assertEqual(with_automatic.frame_count, 3) self.assertEqual(with_automatic.op_count, 3) def test_aliasing_guard_failures(self): def foo(a, b, c): a.add_(b) return c + 1 cnt = torch._dynamo.testing.CompileCounter() compiled_foo = torch._dynamo.optimize(cnt, nopython=True)(foo) x = torch.randn([3]) y = torch.randn([3]) z = torch.randn([3]) cmp_result = compiled_foo( x.clone().detach(), y.clone().detach(), z.clone().detach() ) eager_result = foo(x.clone().detach(), y.clone().detach(), z.clone().detach()) self.assertEqual(cmp_result, eager_result) self.assertEqual(cnt.frame_count, 1) cmp_result = compiled_foo( z.clone().detach(), y.clone().detach(), x.clone().detach() ) eager_result = foo(z.clone().detach(), y.clone().detach(), x.clone().detach()) self.assertEqual(cmp_result, eager_result) # No recompile, alias preserved self.assertEqual(cnt.frame_count, 1) x_clone = x.clone().detach() cmp_result = compiled_foo(x_clone, y.clone().detach(), x_clone) x_clone = x.clone().detach() eager_result = compiled_foo(x_clone, y.clone().detach(), x_clone) self.assertEqual(cmp_result, eager_result) # Recompile, alias changed self.assertEqual(cnt.frame_count, 2) def test_aliasing_guard_failures_with_globals(self): g1 = torch.randn([3]) g2 = torch.randn([3]) def foo(a): a.add_(g1) return g2 + 1 cnt = torch._dynamo.testing.CompileCounter() compiled_foo = torch._dynamo.optimize(cnt, nopython=True)(foo) z = torch.randn([3]) cmp_result = compiled_foo(z.clone().detach()) eager_result = foo(z.clone().detach()) self.assertEqual(cmp_result, eager_result) self.assertEqual(cnt.frame_count, 1) g1 = g1.clone().detach() cmp_result = compiled_foo(g1) g1 = g1.clone().detach() eager_result = compiled_foo(g1) self.assertEqual(cmp_result, eager_result) # Recompile, alias changed self.assertEqual(cnt.frame_count, 2) def test_dynamic_shape_parameter_recompile(self): # Test the matrix multiplication with Parameters. # Without the config assume_parameters_shapes_static_by_default, # the torch.nn.Parameter shapes are assumed to be static which leads to recompilation w = torch.nn.Parameter(torch.randn(3, 2)) def foo(x): return x @ w def run_foo_6_times_and_count_recompiles(): cnt = torch._dynamo.testing.CompileCounter() opt = torch._dynamo.optimize(cnt, nopython=True)(foo) x = torch.nn.Parameter(torch.randn(1, 3)) opt(x) x = torch.nn.Parameter(torch.randn(10, 3)) opt(x) x = torch.nn.Parameter(torch.randn(11, 3)) opt(x) x = torch.nn.Parameter(torch.randn(15, 3)) opt(x) x = torch.nn.Parameter(torch.randn(15, 3)) opt(x) return cnt @patch.object(torch._dynamo.config, "force_parameter_static_shapes", True) @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_static_comp_default_param(): return run_foo_6_times_and_count_recompiles() @patch.object(torch._dynamo.config, "force_parameter_static_shapes", True) @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_dynamic_comp_default_param(): return run_foo_6_times_and_count_recompiles() @patch.object(torch._dynamo.config, "force_parameter_static_shapes", False) @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False) @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_static_comp_dynamic_param(): return run_foo_6_times_and_count_recompiles() @patch.object(torch._dynamo.config, "force_parameter_static_shapes", False) @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True) @patch.object(torch._dynamo.config, "assume_static_by_default", True) def run_dynamic_comp_dynamic_param(): return run_foo_6_times_and_count_recompiles() torch._dynamo.reset() static_comp_default_param = run_static_comp_default_param() self.assertEqual(static_comp_default_param.frame_count, 4) self.assertEqual(static_comp_default_param.op_count, 4) torch._dynamo.reset() dynamic_comp_default_param = run_dynamic_comp_default_param() self.assertEqual(dynamic_comp_default_param.frame_count, 4) self.assertEqual(dynamic_comp_default_param.op_count, 4) torch._dynamo.reset() static_comp_dynamic_param = run_static_comp_dynamic_param() self.assertEqual(static_comp_dynamic_param.frame_count, 4) self.assertEqual(static_comp_dynamic_param.op_count, 4) torch._dynamo.reset() dynamic_comp_dynamic_param = run_dynamic_comp_dynamic_param() self.assertEqual(dynamic_comp_dynamic_param.frame_count, 2) self.assertEqual(dynamic_comp_dynamic_param.op_count, 2) def test_simple_module_recompile(self): class SimpleDropout(torch.nn.Module): def __init__(self) -> None: super().__init__() self.dropout = torch.nn.Dropout(0.5) self.linear = torch.nn.Linear(10, 1) def forward(self, x): return self.dropout(self.linear(x)) model = SimpleDropout() x = torch.randn(10) counter = torch._dynamo.testing.CompileCounter() model = torch.compile(model, backend=counter, fullgraph=True) for _ in range(20): model.eval() model(x) model.train() model(x) self.assertEqual(counter.frame_count, 2) @patch.object(torch._dynamo.config, "cache_size_limit", 2) def test_no_recursive_compile_after_cache_limit_hit(self): def f(x, n): x = x + n return g(x, n) def g(x, n): x = x + n return h(x, n) def h(x, n): return x + n counter = torch._dynamo.testing.CompileCounter() opt_f = torch.compile(f, backend=counter, dynamic=False) for i in range(10): opt_f(torch.ones(3), i) self.assertEqual(counter.frame_count, 2) if __name__ == "__main__": from torch._dynamo.test_case import run_tests run_tests()