xref: /aosp_15_r20/external/pytorch/test/dynamo/test_recompiles.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2from unittest.mock import patch
3
4import torch
5import torch._dynamo.test_case
6import torch._dynamo.testing
7
8
9class RecompileTests(torch._dynamo.test_case.TestCase):
10    def test_automatic_dynamic_reduce_recompiles(self):
11        # Test the counterfactual, lots of recompiles without this config
12        def foo(x, y):
13            return x * y
14
15        def run_foo_6_times_and_count_recompiles(dynamic=None):
16            cnt = torch._dynamo.testing.CompileCounter()
17
18            x = torch.randn([2])
19            y = torch.randn([2])
20            opt = torch._dynamo.optimize(cnt, dynamic=dynamic)(foo)
21            opt(x, y)
22            x = torch.randn([3])
23            y = torch.randn([3])
24            opt(x, y)
25            x = torch.randn([4])
26            y = torch.randn([4])
27            opt(x, y)
28            opt(x, y)
29            x = torch.randn([5])
30            y = torch.randn([5])
31            opt(x, y)
32            opt(x, y)
33            x = torch.randn([6])
34            y = torch.randn([6])
35            opt(x, y)
36
37            return cnt
38
39        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
40        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
41        def run_without_automatic():
42            return run_foo_6_times_and_count_recompiles()
43
44        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
45        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
46        def run_with_automatic():
47            return run_foo_6_times_and_count_recompiles()
48
49        without = run_without_automatic()
50        self.assertEqual(without.frame_count, 5)
51        self.assertEqual(without.op_count, 5)
52        torch._dynamo.reset()
53        without = run_foo_6_times_and_count_recompiles(dynamic=False)
54        self.assertEqual(without.frame_count, 5)
55        self.assertEqual(without.op_count, 5)
56        torch._dynamo.reset()
57        with_automatic = run_with_automatic()
58        self.assertEqual(with_automatic.frame_count, 2)
59        self.assertEqual(with_automatic.op_count, 2)
60        torch._dynamo.reset()
61        with_automatic = run_foo_6_times_and_count_recompiles(dynamic=None)
62        self.assertEqual(with_automatic.frame_count, 2)
63        self.assertEqual(with_automatic.op_count, 2)
64        torch._dynamo.reset()
65        with_dynamic = run_foo_6_times_and_count_recompiles(dynamic=True)
66        self.assertEqual(with_dynamic.frame_count, 1)
67        self.assertEqual(with_dynamic.op_count, 1)
68
69    @patch.object(torch._dynamo.config, "assume_static_by_default", True)
70    def test_recompiles_true_false_flop(self):
71        # Test the counterfactual, lots of recompiles without this config
72        def foo(x, y):
73            if x:
74                return y * 2
75            else:
76                return y * y
77
78        def run_foo_6_times_and_count_recompiles():
79            cnt = torch._dynamo.testing.CompileCounter()
80
81            opt = torch._dynamo.optimize(cnt, nopython=True)(foo)
82
83            x = True
84            y = torch.randn([2])
85            opt(x, y)
86            x = False
87            y = torch.randn([2])
88            opt(x, y)
89            x = True
90            y = torch.randn([3])
91            opt(x, y)
92            x = True
93            y = torch.randn([4])
94            opt(x, y)
95            x = True
96            y = torch.randn([5])
97            opt(x, y)
98
99            return cnt
100
101        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
102        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
103        def run_without_automatic():
104            return run_foo_6_times_and_count_recompiles()
105
106        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
107        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
108        def run_with_automatic():
109            return run_foo_6_times_and_count_recompiles()
110
111        without = run_without_automatic()
112        self.assertEqual(without.frame_count, 5)
113        self.assertEqual(without.op_count, 5)
114        torch._dynamo.reset()
115        with_automatic = run_with_automatic()
116        self.assertEqual(with_automatic.frame_count, 3)
117        self.assertEqual(with_automatic.op_count, 3)
118
119    def test_automatic_dynamic_tensor_scalar_change(self):
120        # Test the counterfactual, lots of recompiles without this config
121        def foo(x, y):
122            return x * y
123
124        def run_foo_6_times_and_count_recompiles_swap_types():
125            cnt = torch._dynamo.testing.CompileCounter()
126
127            x = torch.randn([2])
128            y = torch.randn([2])
129            opt = torch._dynamo.optimize(cnt)(foo)
130            opt(x, y)
131            x = torch.randn([3])
132            y = 3
133            opt(x, y)
134            x = torch.randn([4])
135            y = torch.randn([4])
136            opt(x, y)
137            opt(x, y)
138            x = torch.randn([5])
139            y = 4
140            opt(x, y)
141            opt(x, y)
142            x = torch.randn([6])
143            y = torch.randn([6])
144            opt(x, y)
145
146            return cnt
147
148        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
149        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
150        def run_without_automatic():
151            return run_foo_6_times_and_count_recompiles_swap_types()
152
153        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
154        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
155        def run_with_automatic():
156            return run_foo_6_times_and_count_recompiles_swap_types()
157
158        without = run_without_automatic()
159        self.assertEqual(without.frame_count, 5)
160        self.assertEqual(without.op_count, 5)
161        torch._dynamo.reset()
162        with_automatic = run_with_automatic()
163        self.assertEqual(with_automatic.frame_count, 3)
164        self.assertEqual(with_automatic.op_count, 3)
165
166    def test_aliasing_guard_failures(self):
167        def foo(a, b, c):
168            a.add_(b)
169            return c + 1
170
171        cnt = torch._dynamo.testing.CompileCounter()
172        compiled_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
173
174        x = torch.randn([3])
175        y = torch.randn([3])
176        z = torch.randn([3])
177        cmp_result = compiled_foo(
178            x.clone().detach(), y.clone().detach(), z.clone().detach()
179        )
180        eager_result = foo(x.clone().detach(), y.clone().detach(), z.clone().detach())
181        self.assertEqual(cmp_result, eager_result)
182        self.assertEqual(cnt.frame_count, 1)
183
184        cmp_result = compiled_foo(
185            z.clone().detach(), y.clone().detach(), x.clone().detach()
186        )
187        eager_result = foo(z.clone().detach(), y.clone().detach(), x.clone().detach())
188        self.assertEqual(cmp_result, eager_result)
189        # No recompile, alias preserved
190        self.assertEqual(cnt.frame_count, 1)
191
192        x_clone = x.clone().detach()
193        cmp_result = compiled_foo(x_clone, y.clone().detach(), x_clone)
194        x_clone = x.clone().detach()
195        eager_result = compiled_foo(x_clone, y.clone().detach(), x_clone)
196        self.assertEqual(cmp_result, eager_result)
197        # Recompile, alias changed
198        self.assertEqual(cnt.frame_count, 2)
199
200    def test_aliasing_guard_failures_with_globals(self):
201        g1 = torch.randn([3])
202        g2 = torch.randn([3])
203
204        def foo(a):
205            a.add_(g1)
206            return g2 + 1
207
208        cnt = torch._dynamo.testing.CompileCounter()
209        compiled_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
210
211        z = torch.randn([3])
212        cmp_result = compiled_foo(z.clone().detach())
213        eager_result = foo(z.clone().detach())
214        self.assertEqual(cmp_result, eager_result)
215        self.assertEqual(cnt.frame_count, 1)
216
217        g1 = g1.clone().detach()
218        cmp_result = compiled_foo(g1)
219        g1 = g1.clone().detach()
220        eager_result = compiled_foo(g1)
221        self.assertEqual(cmp_result, eager_result)
222        # Recompile, alias changed
223        self.assertEqual(cnt.frame_count, 2)
224
225    def test_dynamic_shape_parameter_recompile(self):
226        # Test the matrix multiplication with Parameters.
227        # Without the config assume_parameters_shapes_static_by_default,
228        # the torch.nn.Parameter shapes are assumed to be static which leads to recompilation
229
230        w = torch.nn.Parameter(torch.randn(3, 2))
231
232        def foo(x):
233            return x @ w
234
235        def run_foo_6_times_and_count_recompiles():
236            cnt = torch._dynamo.testing.CompileCounter()
237
238            opt = torch._dynamo.optimize(cnt, nopython=True)(foo)
239
240            x = torch.nn.Parameter(torch.randn(1, 3))
241            opt(x)
242            x = torch.nn.Parameter(torch.randn(10, 3))
243            opt(x)
244            x = torch.nn.Parameter(torch.randn(11, 3))
245            opt(x)
246            x = torch.nn.Parameter(torch.randn(15, 3))
247            opt(x)
248            x = torch.nn.Parameter(torch.randn(15, 3))
249            opt(x)
250
251            return cnt
252
253        @patch.object(torch._dynamo.config, "force_parameter_static_shapes", True)
254        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
255        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
256        def run_static_comp_default_param():
257            return run_foo_6_times_and_count_recompiles()
258
259        @patch.object(torch._dynamo.config, "force_parameter_static_shapes", True)
260        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
261        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
262        def run_dynamic_comp_default_param():
263            return run_foo_6_times_and_count_recompiles()
264
265        @patch.object(torch._dynamo.config, "force_parameter_static_shapes", False)
266        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
267        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
268        def run_static_comp_dynamic_param():
269            return run_foo_6_times_and_count_recompiles()
270
271        @patch.object(torch._dynamo.config, "force_parameter_static_shapes", False)
272        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
273        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
274        def run_dynamic_comp_dynamic_param():
275            return run_foo_6_times_and_count_recompiles()
276
277        torch._dynamo.reset()
278        static_comp_default_param = run_static_comp_default_param()
279        self.assertEqual(static_comp_default_param.frame_count, 4)
280        self.assertEqual(static_comp_default_param.op_count, 4)
281
282        torch._dynamo.reset()
283        dynamic_comp_default_param = run_dynamic_comp_default_param()
284        self.assertEqual(dynamic_comp_default_param.frame_count, 4)
285        self.assertEqual(dynamic_comp_default_param.op_count, 4)
286
287        torch._dynamo.reset()
288        static_comp_dynamic_param = run_static_comp_dynamic_param()
289        self.assertEqual(static_comp_dynamic_param.frame_count, 4)
290        self.assertEqual(static_comp_dynamic_param.op_count, 4)
291
292        torch._dynamo.reset()
293        dynamic_comp_dynamic_param = run_dynamic_comp_dynamic_param()
294        self.assertEqual(dynamic_comp_dynamic_param.frame_count, 2)
295        self.assertEqual(dynamic_comp_dynamic_param.op_count, 2)
296
297    def test_simple_module_recompile(self):
298        class SimpleDropout(torch.nn.Module):
299            def __init__(self) -> None:
300                super().__init__()
301                self.dropout = torch.nn.Dropout(0.5)
302                self.linear = torch.nn.Linear(10, 1)
303
304            def forward(self, x):
305                return self.dropout(self.linear(x))
306
307        model = SimpleDropout()
308        x = torch.randn(10)
309        counter = torch._dynamo.testing.CompileCounter()
310        model = torch.compile(model, backend=counter, fullgraph=True)
311        for _ in range(20):
312            model.eval()
313            model(x)
314            model.train()
315            model(x)
316        self.assertEqual(counter.frame_count, 2)
317
318    @patch.object(torch._dynamo.config, "cache_size_limit", 2)
319    def test_no_recursive_compile_after_cache_limit_hit(self):
320        def f(x, n):
321            x = x + n
322            return g(x, n)
323
324        def g(x, n):
325            x = x + n
326            return h(x, n)
327
328        def h(x, n):
329            return x + n
330
331        counter = torch._dynamo.testing.CompileCounter()
332        opt_f = torch.compile(f, backend=counter, dynamic=False)
333        for i in range(10):
334            opt_f(torch.ones(3), i)
335        self.assertEqual(counter.frame_count, 2)
336
337
if __name__ == "__main__":
    # Run this file's tests through dynamo's own test runner, reached via the
    # module that is already imported at the top of the file.
    torch._dynamo.test_case.run_tests()
342