# Source: /aosp_15_r20/external/pytorch/test/dynamo/test_recompiles.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
6*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerclass RecompileTests(torch._dynamo.test_case.TestCase):
10*da0073e9SAndroid Build Coastguard Worker    def test_automatic_dynamic_reduce_recompiles(self):
11*da0073e9SAndroid Build Coastguard Worker        # Test the counterfactual, lots of recompiles without this config
12*da0073e9SAndroid Build Coastguard Worker        def foo(x, y):
13*da0073e9SAndroid Build Coastguard Worker            return x * y
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker        def run_foo_6_times_and_count_recompiles(dynamic=None):
16*da0073e9SAndroid Build Coastguard Worker            cnt = torch._dynamo.testing.CompileCounter()
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker            x = torch.randn([2])
19*da0073e9SAndroid Build Coastguard Worker            y = torch.randn([2])
20*da0073e9SAndroid Build Coastguard Worker            opt = torch._dynamo.optimize(cnt, dynamic=dynamic)(foo)
21*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
22*da0073e9SAndroid Build Coastguard Worker            x = torch.randn([3])
23*da0073e9SAndroid Build Coastguard Worker            y = torch.randn([3])
24*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
25*da0073e9SAndroid Build Coastguard Worker            x = torch.randn([4])
26*da0073e9SAndroid Build Coastguard Worker            y = torch.randn([4])
27*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
28*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
29*da0073e9SAndroid Build Coastguard Worker            x = torch.randn([5])
30*da0073e9SAndroid Build Coastguard Worker            y = torch.randn([5])
31*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
32*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
33*da0073e9SAndroid Build Coastguard Worker            x = torch.randn([6])
34*da0073e9SAndroid Build Coastguard Worker            y = torch.randn([6])
35*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker            return cnt
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
40*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
41*da0073e9SAndroid Build Coastguard Worker        def run_without_automatic():
42*da0073e9SAndroid Build Coastguard Worker            return run_foo_6_times_and_count_recompiles()
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
45*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
46*da0073e9SAndroid Build Coastguard Worker        def run_with_automatic():
47*da0073e9SAndroid Build Coastguard Worker            return run_foo_6_times_and_count_recompiles()
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker        without = run_without_automatic()
50*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(without.frame_count, 5)
51*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(without.op_count, 5)
52*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
53*da0073e9SAndroid Build Coastguard Worker        without = run_foo_6_times_and_count_recompiles(dynamic=False)
54*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(without.frame_count, 5)
55*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(without.op_count, 5)
56*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
57*da0073e9SAndroid Build Coastguard Worker        with_automatic = run_with_automatic()
58*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(with_automatic.frame_count, 2)
59*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(with_automatic.op_count, 2)
60*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
61*da0073e9SAndroid Build Coastguard Worker        with_automatic = run_foo_6_times_and_count_recompiles(dynamic=None)
62*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(with_automatic.frame_count, 2)
63*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(with_automatic.op_count, 2)
64*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
65*da0073e9SAndroid Build Coastguard Worker        with_dynamic = run_foo_6_times_and_count_recompiles(dynamic=True)
66*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(with_dynamic.frame_count, 1)
67*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(with_dynamic.op_count, 1)
68*da0073e9SAndroid Build Coastguard Worker
69*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "assume_static_by_default", True)
70*da0073e9SAndroid Build Coastguard Worker    def test_recompiles_true_false_flop(self):
71*da0073e9SAndroid Build Coastguard Worker        # Test the counterfactual, lots of recompiles without this config
72*da0073e9SAndroid Build Coastguard Worker        def foo(x, y):
73*da0073e9SAndroid Build Coastguard Worker            if x:
74*da0073e9SAndroid Build Coastguard Worker                return y * 2
75*da0073e9SAndroid Build Coastguard Worker            else:
76*da0073e9SAndroid Build Coastguard Worker                return y * y
77*da0073e9SAndroid Build Coastguard Worker
78*da0073e9SAndroid Build Coastguard Worker        def run_foo_6_times_and_count_recompiles():
79*da0073e9SAndroid Build Coastguard Worker            cnt = torch._dynamo.testing.CompileCounter()
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker            opt = torch._dynamo.optimize(cnt, nopython=True)(foo)
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker            x = True
84*da0073e9SAndroid Build Coastguard Worker            y = torch.randn([2])
85*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
86*da0073e9SAndroid Build Coastguard Worker            x = False
87*da0073e9SAndroid Build Coastguard Worker            y = torch.randn([2])
88*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
89*da0073e9SAndroid Build Coastguard Worker            x = True
90*da0073e9SAndroid Build Coastguard Worker            y = torch.randn([3])
91*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
92*da0073e9SAndroid Build Coastguard Worker            x = True
93*da0073e9SAndroid Build Coastguard Worker            y = torch.randn([4])
94*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
95*da0073e9SAndroid Build Coastguard Worker            x = True
96*da0073e9SAndroid Build Coastguard Worker            y = torch.randn([5])
97*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
98*da0073e9SAndroid Build Coastguard Worker
99*da0073e9SAndroid Build Coastguard Worker            return cnt
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
102*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
103*da0073e9SAndroid Build Coastguard Worker        def run_without_automatic():
104*da0073e9SAndroid Build Coastguard Worker            return run_foo_6_times_and_count_recompiles()
105*da0073e9SAndroid Build Coastguard Worker
106*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
107*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
108*da0073e9SAndroid Build Coastguard Worker        def run_with_automatic():
109*da0073e9SAndroid Build Coastguard Worker            return run_foo_6_times_and_count_recompiles()
110*da0073e9SAndroid Build Coastguard Worker
111*da0073e9SAndroid Build Coastguard Worker        without = run_without_automatic()
112*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(without.frame_count, 5)
113*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(without.op_count, 5)
114*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
115*da0073e9SAndroid Build Coastguard Worker        with_automatic = run_with_automatic()
116*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(with_automatic.frame_count, 3)
117*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(with_automatic.op_count, 3)
118*da0073e9SAndroid Build Coastguard Worker
119*da0073e9SAndroid Build Coastguard Worker    def test_automatic_dynamic_tensor_scalar_change(self):
120*da0073e9SAndroid Build Coastguard Worker        # Test the counterfactual, lots of recompiles without this config
121*da0073e9SAndroid Build Coastguard Worker        def foo(x, y):
122*da0073e9SAndroid Build Coastguard Worker            return x * y
123*da0073e9SAndroid Build Coastguard Worker
124*da0073e9SAndroid Build Coastguard Worker        def run_foo_6_times_and_count_recompiles_swap_types():
125*da0073e9SAndroid Build Coastguard Worker            cnt = torch._dynamo.testing.CompileCounter()
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker            x = torch.randn([2])
128*da0073e9SAndroid Build Coastguard Worker            y = torch.randn([2])
129*da0073e9SAndroid Build Coastguard Worker            opt = torch._dynamo.optimize(cnt)(foo)
130*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
131*da0073e9SAndroid Build Coastguard Worker            x = torch.randn([3])
132*da0073e9SAndroid Build Coastguard Worker            y = 3
133*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
134*da0073e9SAndroid Build Coastguard Worker            x = torch.randn([4])
135*da0073e9SAndroid Build Coastguard Worker            y = torch.randn([4])
136*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
137*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
138*da0073e9SAndroid Build Coastguard Worker            x = torch.randn([5])
139*da0073e9SAndroid Build Coastguard Worker            y = 4
140*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
141*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
142*da0073e9SAndroid Build Coastguard Worker            x = torch.randn([6])
143*da0073e9SAndroid Build Coastguard Worker            y = torch.randn([6])
144*da0073e9SAndroid Build Coastguard Worker            opt(x, y)
145*da0073e9SAndroid Build Coastguard Worker
146*da0073e9SAndroid Build Coastguard Worker            return cnt
147*da0073e9SAndroid Build Coastguard Worker
148*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
149*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
150*da0073e9SAndroid Build Coastguard Worker        def run_without_automatic():
151*da0073e9SAndroid Build Coastguard Worker            return run_foo_6_times_and_count_recompiles_swap_types()
152*da0073e9SAndroid Build Coastguard Worker
153*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
154*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
155*da0073e9SAndroid Build Coastguard Worker        def run_with_automatic():
156*da0073e9SAndroid Build Coastguard Worker            return run_foo_6_times_and_count_recompiles_swap_types()
157*da0073e9SAndroid Build Coastguard Worker
158*da0073e9SAndroid Build Coastguard Worker        without = run_without_automatic()
159*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(without.frame_count, 5)
160*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(without.op_count, 5)
161*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
162*da0073e9SAndroid Build Coastguard Worker        with_automatic = run_with_automatic()
163*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(with_automatic.frame_count, 3)
164*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(with_automatic.op_count, 3)
165*da0073e9SAndroid Build Coastguard Worker
166*da0073e9SAndroid Build Coastguard Worker    def test_aliasing_guard_failures(self):
167*da0073e9SAndroid Build Coastguard Worker        def foo(a, b, c):
168*da0073e9SAndroid Build Coastguard Worker            a.add_(b)
169*da0073e9SAndroid Build Coastguard Worker            return c + 1
170*da0073e9SAndroid Build Coastguard Worker
171*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
172*da0073e9SAndroid Build Coastguard Worker        compiled_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
173*da0073e9SAndroid Build Coastguard Worker
174*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([3])
175*da0073e9SAndroid Build Coastguard Worker        y = torch.randn([3])
176*da0073e9SAndroid Build Coastguard Worker        z = torch.randn([3])
177*da0073e9SAndroid Build Coastguard Worker        cmp_result = compiled_foo(
178*da0073e9SAndroid Build Coastguard Worker            x.clone().detach(), y.clone().detach(), z.clone().detach()
179*da0073e9SAndroid Build Coastguard Worker        )
180*da0073e9SAndroid Build Coastguard Worker        eager_result = foo(x.clone().detach(), y.clone().detach(), z.clone().detach())
181*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cmp_result, eager_result)
182*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
183*da0073e9SAndroid Build Coastguard Worker
184*da0073e9SAndroid Build Coastguard Worker        cmp_result = compiled_foo(
185*da0073e9SAndroid Build Coastguard Worker            z.clone().detach(), y.clone().detach(), x.clone().detach()
186*da0073e9SAndroid Build Coastguard Worker        )
187*da0073e9SAndroid Build Coastguard Worker        eager_result = foo(z.clone().detach(), y.clone().detach(), x.clone().detach())
188*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cmp_result, eager_result)
189*da0073e9SAndroid Build Coastguard Worker        # No recompile, alias preserved
190*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
191*da0073e9SAndroid Build Coastguard Worker
192*da0073e9SAndroid Build Coastguard Worker        x_clone = x.clone().detach()
193*da0073e9SAndroid Build Coastguard Worker        cmp_result = compiled_foo(x_clone, y.clone().detach(), x_clone)
194*da0073e9SAndroid Build Coastguard Worker        x_clone = x.clone().detach()
195*da0073e9SAndroid Build Coastguard Worker        eager_result = compiled_foo(x_clone, y.clone().detach(), x_clone)
196*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cmp_result, eager_result)
197*da0073e9SAndroid Build Coastguard Worker        # Recompile, alias changed
198*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 2)
199*da0073e9SAndroid Build Coastguard Worker
200*da0073e9SAndroid Build Coastguard Worker    def test_aliasing_guard_failures_with_globals(self):
201*da0073e9SAndroid Build Coastguard Worker        g1 = torch.randn([3])
202*da0073e9SAndroid Build Coastguard Worker        g2 = torch.randn([3])
203*da0073e9SAndroid Build Coastguard Worker
204*da0073e9SAndroid Build Coastguard Worker        def foo(a):
205*da0073e9SAndroid Build Coastguard Worker            a.add_(g1)
206*da0073e9SAndroid Build Coastguard Worker            return g2 + 1
207*da0073e9SAndroid Build Coastguard Worker
208*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
209*da0073e9SAndroid Build Coastguard Worker        compiled_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker        z = torch.randn([3])
212*da0073e9SAndroid Build Coastguard Worker        cmp_result = compiled_foo(z.clone().detach())
213*da0073e9SAndroid Build Coastguard Worker        eager_result = foo(z.clone().detach())
214*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cmp_result, eager_result)
215*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
216*da0073e9SAndroid Build Coastguard Worker
217*da0073e9SAndroid Build Coastguard Worker        g1 = g1.clone().detach()
218*da0073e9SAndroid Build Coastguard Worker        cmp_result = compiled_foo(g1)
219*da0073e9SAndroid Build Coastguard Worker        g1 = g1.clone().detach()
220*da0073e9SAndroid Build Coastguard Worker        eager_result = compiled_foo(g1)
221*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cmp_result, eager_result)
222*da0073e9SAndroid Build Coastguard Worker        # Recompile, alias changed
223*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 2)
224*da0073e9SAndroid Build Coastguard Worker
225*da0073e9SAndroid Build Coastguard Worker    def test_dynamic_shape_parameter_recompile(self):
226*da0073e9SAndroid Build Coastguard Worker        # Test the matrix multiplication with Parameters.
227*da0073e9SAndroid Build Coastguard Worker        # Without the config assume_parameters_shapes_static_by_default,
228*da0073e9SAndroid Build Coastguard Worker        # the torch.nn.Parameter shapes are assumed to be static which leads to recompilation
229*da0073e9SAndroid Build Coastguard Worker
230*da0073e9SAndroid Build Coastguard Worker        w = torch.nn.Parameter(torch.randn(3, 2))
231*da0073e9SAndroid Build Coastguard Worker
232*da0073e9SAndroid Build Coastguard Worker        def foo(x):
233*da0073e9SAndroid Build Coastguard Worker            return x @ w
234*da0073e9SAndroid Build Coastguard Worker
235*da0073e9SAndroid Build Coastguard Worker        def run_foo_6_times_and_count_recompiles():
236*da0073e9SAndroid Build Coastguard Worker            cnt = torch._dynamo.testing.CompileCounter()
237*da0073e9SAndroid Build Coastguard Worker
238*da0073e9SAndroid Build Coastguard Worker            opt = torch._dynamo.optimize(cnt, nopython=True)(foo)
239*da0073e9SAndroid Build Coastguard Worker
240*da0073e9SAndroid Build Coastguard Worker            x = torch.nn.Parameter(torch.randn(1, 3))
241*da0073e9SAndroid Build Coastguard Worker            opt(x)
242*da0073e9SAndroid Build Coastguard Worker            x = torch.nn.Parameter(torch.randn(10, 3))
243*da0073e9SAndroid Build Coastguard Worker            opt(x)
244*da0073e9SAndroid Build Coastguard Worker            x = torch.nn.Parameter(torch.randn(11, 3))
245*da0073e9SAndroid Build Coastguard Worker            opt(x)
246*da0073e9SAndroid Build Coastguard Worker            x = torch.nn.Parameter(torch.randn(15, 3))
247*da0073e9SAndroid Build Coastguard Worker            opt(x)
248*da0073e9SAndroid Build Coastguard Worker            x = torch.nn.Parameter(torch.randn(15, 3))
249*da0073e9SAndroid Build Coastguard Worker            opt(x)
250*da0073e9SAndroid Build Coastguard Worker
251*da0073e9SAndroid Build Coastguard Worker            return cnt
252*da0073e9SAndroid Build Coastguard Worker
253*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "force_parameter_static_shapes", True)
254*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
255*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
256*da0073e9SAndroid Build Coastguard Worker        def run_static_comp_default_param():
257*da0073e9SAndroid Build Coastguard Worker            return run_foo_6_times_and_count_recompiles()
258*da0073e9SAndroid Build Coastguard Worker
259*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "force_parameter_static_shapes", True)
260*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
261*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
262*da0073e9SAndroid Build Coastguard Worker        def run_dynamic_comp_default_param():
263*da0073e9SAndroid Build Coastguard Worker            return run_foo_6_times_and_count_recompiles()
264*da0073e9SAndroid Build Coastguard Worker
265*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "force_parameter_static_shapes", False)
266*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
267*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
268*da0073e9SAndroid Build Coastguard Worker        def run_static_comp_dynamic_param():
269*da0073e9SAndroid Build Coastguard Worker            return run_foo_6_times_and_count_recompiles()
270*da0073e9SAndroid Build Coastguard Worker
271*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "force_parameter_static_shapes", False)
272*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
273*da0073e9SAndroid Build Coastguard Worker        @patch.object(torch._dynamo.config, "assume_static_by_default", True)
274*da0073e9SAndroid Build Coastguard Worker        def run_dynamic_comp_dynamic_param():
275*da0073e9SAndroid Build Coastguard Worker            return run_foo_6_times_and_count_recompiles()
276*da0073e9SAndroid Build Coastguard Worker
277*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
278*da0073e9SAndroid Build Coastguard Worker        static_comp_default_param = run_static_comp_default_param()
279*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(static_comp_default_param.frame_count, 4)
280*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(static_comp_default_param.op_count, 4)
281*da0073e9SAndroid Build Coastguard Worker
282*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
283*da0073e9SAndroid Build Coastguard Worker        dynamic_comp_default_param = run_dynamic_comp_default_param()
284*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dynamic_comp_default_param.frame_count, 4)
285*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dynamic_comp_default_param.op_count, 4)
286*da0073e9SAndroid Build Coastguard Worker
287*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
288*da0073e9SAndroid Build Coastguard Worker        static_comp_dynamic_param = run_static_comp_dynamic_param()
289*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(static_comp_dynamic_param.frame_count, 4)
290*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(static_comp_dynamic_param.op_count, 4)
291*da0073e9SAndroid Build Coastguard Worker
292*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
293*da0073e9SAndroid Build Coastguard Worker        dynamic_comp_dynamic_param = run_dynamic_comp_dynamic_param()
294*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dynamic_comp_dynamic_param.frame_count, 2)
295*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dynamic_comp_dynamic_param.op_count, 2)
296*da0073e9SAndroid Build Coastguard Worker
297*da0073e9SAndroid Build Coastguard Worker    def test_simple_module_recompile(self):
298*da0073e9SAndroid Build Coastguard Worker        class SimpleDropout(torch.nn.Module):
299*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
300*da0073e9SAndroid Build Coastguard Worker                super().__init__()
301*da0073e9SAndroid Build Coastguard Worker                self.dropout = torch.nn.Dropout(0.5)
302*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(10, 1)
303*da0073e9SAndroid Build Coastguard Worker
304*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
305*da0073e9SAndroid Build Coastguard Worker                return self.dropout(self.linear(x))
306*da0073e9SAndroid Build Coastguard Worker
307*da0073e9SAndroid Build Coastguard Worker        model = SimpleDropout()
308*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10)
309*da0073e9SAndroid Build Coastguard Worker        counter = torch._dynamo.testing.CompileCounter()
310*da0073e9SAndroid Build Coastguard Worker        model = torch.compile(model, backend=counter, fullgraph=True)
311*da0073e9SAndroid Build Coastguard Worker        for _ in range(20):
312*da0073e9SAndroid Build Coastguard Worker            model.eval()
313*da0073e9SAndroid Build Coastguard Worker            model(x)
314*da0073e9SAndroid Build Coastguard Worker            model.train()
315*da0073e9SAndroid Build Coastguard Worker            model(x)
316*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 2)
317*da0073e9SAndroid Build Coastguard Worker
318*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "cache_size_limit", 2)
319*da0073e9SAndroid Build Coastguard Worker    def test_no_recursive_compile_after_cache_limit_hit(self):
320*da0073e9SAndroid Build Coastguard Worker        def f(x, n):
321*da0073e9SAndroid Build Coastguard Worker            x = x + n
322*da0073e9SAndroid Build Coastguard Worker            return g(x, n)
323*da0073e9SAndroid Build Coastguard Worker
324*da0073e9SAndroid Build Coastguard Worker        def g(x, n):
325*da0073e9SAndroid Build Coastguard Worker            x = x + n
326*da0073e9SAndroid Build Coastguard Worker            return h(x, n)
327*da0073e9SAndroid Build Coastguard Worker
328*da0073e9SAndroid Build Coastguard Worker        def h(x, n):
329*da0073e9SAndroid Build Coastguard Worker            return x + n
330*da0073e9SAndroid Build Coastguard Worker
331*da0073e9SAndroid Build Coastguard Worker        counter = torch._dynamo.testing.CompileCounter()
332*da0073e9SAndroid Build Coastguard Worker        opt_f = torch.compile(f, backend=counter, dynamic=False)
333*da0073e9SAndroid Build Coastguard Worker        for i in range(10):
334*da0073e9SAndroid Build Coastguard Worker            opt_f(torch.ones(3), i)
335*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 2)
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker
338*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
339*da0073e9SAndroid Build Coastguard Worker    from torch._dynamo.test_case import run_tests
340*da0073e9SAndroid Build Coastguard Worker
341*da0073e9SAndroid Build Coastguard Worker    run_tests()
342