xref: /aosp_15_r20/external/pytorch/test/dynamo/test_subgraphs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
6*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing
7*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import unsupported
8*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.utils import ifdynstaticdefault
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Workerglobalmod = torch.nn.ReLU()
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Workerdef indirectly_unsupported(a, b):
15*da0073e9SAndroid Build Coastguard Worker    c = a + b
16*da0073e9SAndroid Build Coastguard Worker    return unsupported(a, c)
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Workerclass SubGraphTests(torch._dynamo.test_case.TestCase):
20*da0073e9SAndroid Build Coastguard Worker    def _common(self, fn, frame_count, op_count):
21*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
22*da0073e9SAndroid Build Coastguard Worker        v1 = torch.ones(10)
23*da0073e9SAndroid Build Coastguard Worker        v2 = torch.ones(10) * -2.0
24*da0073e9SAndroid Build Coastguard Worker        correct1 = fn(v1, v2)
25*da0073e9SAndroid Build Coastguard Worker        correct2 = fn(v2, v1)
26*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
27*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnt)(fn)
28*da0073e9SAndroid Build Coastguard Worker        r1 = opt_fn(v1, v2)
29*da0073e9SAndroid Build Coastguard Worker        r2 = opt_fn(v2, v1)
30*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._dynamo.testing.same(r1, correct1))
31*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._dynamo.testing.same(r2, correct2))
32*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
33*da0073e9SAndroid Build Coastguard Worker            cnt.frame_count,
34*da0073e9SAndroid Build Coastguard Worker            frame_count,
35*da0073e9SAndroid Build Coastguard Worker            f"actual {cnt.frame_count} != expected {frame_count}",
36*da0073e9SAndroid Build Coastguard Worker        )
37*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.op_count, op_count)
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker    def test_control_flow1(self):
40*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
41*da0073e9SAndroid Build Coastguard Worker            c1 = a - b
42*da0073e9SAndroid Build Coastguard Worker            c2 = b - a
43*da0073e9SAndroid Build Coastguard Worker            if c1.sum() > c2.sum():
44*da0073e9SAndroid Build Coastguard Worker                return c1
45*da0073e9SAndroid Build Coastguard Worker            else:
46*da0073e9SAndroid Build Coastguard Worker                return c2
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 1, 5)
49*da0073e9SAndroid Build Coastguard Worker
50*da0073e9SAndroid Build Coastguard Worker    def test_control_flow2(self):
51*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
52*da0073e9SAndroid Build Coastguard Worker            if a.sum() > b.sum():
53*da0073e9SAndroid Build Coastguard Worker                return 1
54*da0073e9SAndroid Build Coastguard Worker            else:
55*da0073e9SAndroid Build Coastguard Worker                return 2
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 1, 3)
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker    def test_control_flow3(self):
60*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
61*da0073e9SAndroid Build Coastguard Worker            c1 = a - b
62*da0073e9SAndroid Build Coastguard Worker            c2 = b - a
63*da0073e9SAndroid Build Coastguard Worker            m = globalmod
64*da0073e9SAndroid Build Coastguard Worker            if c1.sum() > c2.sum():
65*da0073e9SAndroid Build Coastguard Worker                return m(c1)
66*da0073e9SAndroid Build Coastguard Worker            else:
67*da0073e9SAndroid Build Coastguard Worker                return m(c2)
68*da0073e9SAndroid Build Coastguard Worker
69*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 3, 7)
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker    def test_control_flow4(self):
72*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
73*da0073e9SAndroid Build Coastguard Worker            tmp1 = a.sum() > b.sum() and a.sum() > 0
74*da0073e9SAndroid Build Coastguard Worker            if tmp1:
75*da0073e9SAndroid Build Coastguard Worker                return 1
76*da0073e9SAndroid Build Coastguard Worker            else:
77*da0073e9SAndroid Build Coastguard Worker                return 2
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 3, 5)
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker    def test_control_flow5(self):
82*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
83*da0073e9SAndroid Build Coastguard Worker            tmp1 = a.sum() > b.sum() and a.sum() > 0
84*da0073e9SAndroid Build Coastguard Worker            tmp2 = a.sum() < b.sum() or b.sum() > 0
85*da0073e9SAndroid Build Coastguard Worker            if tmp1 and tmp2:
86*da0073e9SAndroid Build Coastguard Worker                return 1, tmp1, tmp2
87*da0073e9SAndroid Build Coastguard Worker            else:
88*da0073e9SAndroid Build Coastguard Worker                return 2, tmp1, tmp2
89*da0073e9SAndroid Build Coastguard Worker
90*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 6, 13)
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Worker    def test_capi_call1(self):
93*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
94*da0073e9SAndroid Build Coastguard Worker            c1 = a - b
95*da0073e9SAndroid Build Coastguard Worker            c2 = b - a
96*da0073e9SAndroid Build Coastguard Worker            return unsupported(c1, c2)
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 1, 2)
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker    def test_capi_call2(self):
101*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
102*da0073e9SAndroid Build Coastguard Worker            c1 = a - b
103*da0073e9SAndroid Build Coastguard Worker            c2 = b - a
104*da0073e9SAndroid Build Coastguard Worker            return a - (b - unsupported(c1, c2))
105*da0073e9SAndroid Build Coastguard Worker
106*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 2, 4)
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker    def test_capi_call3(self):
109*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
110*da0073e9SAndroid Build Coastguard Worker            c1 = a - b
111*da0073e9SAndroid Build Coastguard Worker            c2 = b - a
112*da0073e9SAndroid Build Coastguard Worker            return torch._dynamo.testing.unsupported(c1, c2)
113*da0073e9SAndroid Build Coastguard Worker
114*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 1, 2)
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker    def test_indirect_unsupported1(self):
117*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
118*da0073e9SAndroid Build Coastguard Worker            c1 = a - b
119*da0073e9SAndroid Build Coastguard Worker            c2 = b - a
120*da0073e9SAndroid Build Coastguard Worker            return indirectly_unsupported(c1, c2)
121*da0073e9SAndroid Build Coastguard Worker
122*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 2, 3)
123*da0073e9SAndroid Build Coastguard Worker
124*da0073e9SAndroid Build Coastguard Worker    def test_indirect_unsupported2(self):
125*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
126*da0073e9SAndroid Build Coastguard Worker            local_const1 = 7
127*da0073e9SAndroid Build Coastguard Worker            local_const2 = 22
128*da0073e9SAndroid Build Coastguard Worker            c1 = a - b
129*da0073e9SAndroid Build Coastguard Worker            c2 = b - a
130*da0073e9SAndroid Build Coastguard Worker            return local_const1 / (local_const2 - indirectly_unsupported(c1, c2))
131*da0073e9SAndroid Build Coastguard Worker
132*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 3, 5)
133*da0073e9SAndroid Build Coastguard Worker
134*da0073e9SAndroid Build Coastguard Worker    def test_indirect_unsupported3(self):
135*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
136*da0073e9SAndroid Build Coastguard Worker            args = [a - b, b - a]
137*da0073e9SAndroid Build Coastguard Worker            return indirectly_unsupported(*args)
138*da0073e9SAndroid Build Coastguard Worker
139*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 2, 3)
140*da0073e9SAndroid Build Coastguard Worker
141*da0073e9SAndroid Build Coastguard Worker    def test_stack_state1(self):
142*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
143*da0073e9SAndroid Build Coastguard Worker            t1 = 1.23 * a
144*da0073e9SAndroid Build Coastguard Worker            t2 = 4.56 * a
145*da0073e9SAndroid Build Coastguard Worker            c1 = a - b
146*da0073e9SAndroid Build Coastguard Worker            c2 = b - a
147*da0073e9SAndroid Build Coastguard Worker            return t1 / (t2 - unsupported(c1, c2))
148*da0073e9SAndroid Build Coastguard Worker
149*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 2, 6)
150*da0073e9SAndroid Build Coastguard Worker
151*da0073e9SAndroid Build Coastguard Worker    def test_stack_state2(self):
152*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
153*da0073e9SAndroid Build Coastguard Worker            t1 = 1.23 * a
154*da0073e9SAndroid Build Coastguard Worker            t2 = 4.56 * a
155*da0073e9SAndroid Build Coastguard Worker            c1 = a - b
156*da0073e9SAndroid Build Coastguard Worker            c2 = b - a
157*da0073e9SAndroid Build Coastguard Worker            return t1 / (t2 - indirectly_unsupported(c1, c2))
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 3, 7)
160*da0073e9SAndroid Build Coastguard Worker
161*da0073e9SAndroid Build Coastguard Worker    def test_multigraph(self):
162*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
163*da0073e9SAndroid Build Coastguard Worker            x = a + b
164*da0073e9SAndroid Build Coastguard Worker            x = x / 2.0
165*da0073e9SAndroid Build Coastguard Worker            if x.sum() < 0:
166*da0073e9SAndroid Build Coastguard Worker                return x * -1.0
167*da0073e9SAndroid Build Coastguard Worker            return x
168*da0073e9SAndroid Build Coastguard Worker
169*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 2, 5)
170*da0073e9SAndroid Build Coastguard Worker
171*da0073e9SAndroid Build Coastguard Worker    def test_extended_args(self):
172*da0073e9SAndroid Build Coastguard Worker        too_many_adds = "+".join(["a", "b"] * 256)
173*da0073e9SAndroid Build Coastguard Worker        source = (
174*da0073e9SAndroid Build Coastguard Worker            f"lambda a, b: ({too_many_adds}+a if a.sum() > 0 else {too_many_adds} - b)"
175*da0073e9SAndroid Build Coastguard Worker        )
176*da0073e9SAndroid Build Coastguard Worker        self._common(eval(source), 3, 1026)
177*da0073e9SAndroid Build Coastguard Worker
178*da0073e9SAndroid Build Coastguard Worker    def test_resume1(self):
179*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
180*da0073e9SAndroid Build Coastguard Worker            x = a + b
181*da0073e9SAndroid Build Coastguard Worker            x = x / 2.0
182*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
183*da0073e9SAndroid Build Coastguard Worker            x = unsupported(x, a)
184*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
185*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
186*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
187*da0073e9SAndroid Build Coastguard Worker            return x
188*da0073e9SAndroid Build Coastguard Worker
189*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 2, 6)
190*da0073e9SAndroid Build Coastguard Worker
191*da0073e9SAndroid Build Coastguard Worker    def test_resume2(self):
192*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
193*da0073e9SAndroid Build Coastguard Worker            x = a + b
194*da0073e9SAndroid Build Coastguard Worker            x = x / 2.0
195*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
196*da0073e9SAndroid Build Coastguard Worker            x = indirectly_unsupported(x, a)
197*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
198*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
199*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
200*da0073e9SAndroid Build Coastguard Worker            return x
201*da0073e9SAndroid Build Coastguard Worker
202*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 3, 7)
203*da0073e9SAndroid Build Coastguard Worker
204*da0073e9SAndroid Build Coastguard Worker    def test_resume3(self):
205*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
206*da0073e9SAndroid Build Coastguard Worker            x = a + b
207*da0073e9SAndroid Build Coastguard Worker            x = x / 2.0
208*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
209*da0073e9SAndroid Build Coastguard Worker            x = indirectly_unsupported(x, b=a)
210*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
211*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
212*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
213*da0073e9SAndroid Build Coastguard Worker            return x
214*da0073e9SAndroid Build Coastguard Worker
215*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 3, 7)
216*da0073e9SAndroid Build Coastguard Worker
217*da0073e9SAndroid Build Coastguard Worker    def test_resume4(self):
218*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
219*da0073e9SAndroid Build Coastguard Worker            x = a + b
220*da0073e9SAndroid Build Coastguard Worker            x = x / 2.0
221*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
222*da0073e9SAndroid Build Coastguard Worker            x = indirectly_unsupported(a=x, b=a)
223*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
224*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
225*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
226*da0073e9SAndroid Build Coastguard Worker            return x
227*da0073e9SAndroid Build Coastguard Worker
228*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 3, 7)
229*da0073e9SAndroid Build Coastguard Worker
230*da0073e9SAndroid Build Coastguard Worker    def test_resume5(self):
231*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
232*da0073e9SAndroid Build Coastguard Worker            x = a + b
233*da0073e9SAndroid Build Coastguard Worker            x = x / 2.0
234*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
235*da0073e9SAndroid Build Coastguard Worker            print(x)
236*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
237*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
238*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
239*da0073e9SAndroid Build Coastguard Worker            return x
240*da0073e9SAndroid Build Coastguard Worker
241*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 2, 6)
242*da0073e9SAndroid Build Coastguard Worker
243*da0073e9SAndroid Build Coastguard Worker    def test_start1(self):
244*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
245*da0073e9SAndroid Build Coastguard Worker            print(a)
246*da0073e9SAndroid Build Coastguard Worker            x = a + b
247*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
248*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
249*da0073e9SAndroid Build Coastguard Worker            return x
250*da0073e9SAndroid Build Coastguard Worker
251*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 1, 3)
252*da0073e9SAndroid Build Coastguard Worker
253*da0073e9SAndroid Build Coastguard Worker    def test_start2(self):
254*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
255*da0073e9SAndroid Build Coastguard Worker            x = indirectly_unsupported(a, b)
256*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
257*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
258*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
259*da0073e9SAndroid Build Coastguard Worker            return x
260*da0073e9SAndroid Build Coastguard Worker
261*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 2, 4)
262*da0073e9SAndroid Build Coastguard Worker
263*da0073e9SAndroid Build Coastguard Worker    def test_start3(self):
264*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
265*da0073e9SAndroid Build Coastguard Worker            x = unsupported(a, b)
266*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
267*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
268*da0073e9SAndroid Build Coastguard Worker            x = x + 2.0
269*da0073e9SAndroid Build Coastguard Worker            return x
270*da0073e9SAndroid Build Coastguard Worker
271*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 1, 3)
272*da0073e9SAndroid Build Coastguard Worker
273*da0073e9SAndroid Build Coastguard Worker    def test_start4(self):
274*da0073e9SAndroid Build Coastguard Worker        def fn(a, b, check):
275*da0073e9SAndroid Build Coastguard Worker            if check:
276*da0073e9SAndroid Build Coastguard Worker                return a + b + 10
277*da0073e9SAndroid Build Coastguard Worker            else:
278*da0073e9SAndroid Build Coastguard Worker                return a + b - 10
279*da0073e9SAndroid Build Coastguard Worker
280*da0073e9SAndroid Build Coastguard Worker        v1 = torch.randn(10)
281*da0073e9SAndroid Build Coastguard Worker        v2 = torch.randn(10)
282*da0073e9SAndroid Build Coastguard Worker        f = torch.zeros(1, dtype=torch.int32)
283*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(1, dtype=torch.int32)
284*da0073e9SAndroid Build Coastguard Worker        correct1 = fn(v1, v2, t)
285*da0073e9SAndroid Build Coastguard Worker        correct2 = fn(v1, v2, f)
286*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
287*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnt)(fn)
288*da0073e9SAndroid Build Coastguard Worker        r1 = opt_fn(v1, v2, t)
289*da0073e9SAndroid Build Coastguard Worker        r2 = opt_fn(v1, v2, f)
290*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._dynamo.testing.same(r1, correct1))
291*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._dynamo.testing.same(r2, correct2))
292*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 3)
293*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.op_count, 4)
294*da0073e9SAndroid Build Coastguard Worker
295*da0073e9SAndroid Build Coastguard Worker    def test_resume_freevars(self):
296*da0073e9SAndroid Build Coastguard Worker        c1 = torch.randn(10)
297*da0073e9SAndroid Build Coastguard Worker        c2 = torch.randn(10)
298*da0073e9SAndroid Build Coastguard Worker
299*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
300*da0073e9SAndroid Build Coastguard Worker            x = a + b + (c1 - c2)
301*da0073e9SAndroid Build Coastguard Worker            x = unsupported(x, x)
302*da0073e9SAndroid Build Coastguard Worker            return x + (c1 - c2)
303*da0073e9SAndroid Build Coastguard Worker
304*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 2, 5)
305*da0073e9SAndroid Build Coastguard Worker
306*da0073e9SAndroid Build Coastguard Worker    def test_restore_state(self):
307*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
308*da0073e9SAndroid Build Coastguard Worker            len_ = len
309*da0073e9SAndroid Build Coastguard Worker            x = a + b
310*da0073e9SAndroid Build Coastguard Worker            x = torch.add(unsupported(x, x), 1)
311*da0073e9SAndroid Build Coastguard Worker            return a * x + len_(b)
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 2, 4)
314*da0073e9SAndroid Build Coastguard Worker
315*da0073e9SAndroid Build Coastguard Worker    def test_restore_range(self):
316*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
317*da0073e9SAndroid Build Coastguard Worker            x = a + b
318*da0073e9SAndroid Build Coastguard Worker            rng = range(3, 8, 2)
319*da0073e9SAndroid Build Coastguard Worker            x = unsupported(x, x)
320*da0073e9SAndroid Build Coastguard Worker            for i in rng:
321*da0073e9SAndroid Build Coastguard Worker                x = x + i
322*da0073e9SAndroid Build Coastguard Worker            return x
323*da0073e9SAndroid Build Coastguard Worker
324*da0073e9SAndroid Build Coastguard Worker        # We don't specialize on range with dynamic shapes, which
325*da0073e9SAndroid Build Coastguard Worker        # means we fail to unroll the loop.
326*da0073e9SAndroid Build Coastguard Worker        # TODO: Consider forcing specialization when we iterate over
327*da0073e9SAndroid Build Coastguard Worker        # the loop
328*da0073e9SAndroid Build Coastguard Worker        self._common(fn, ifdynstaticdefault(2, 1), ifdynstaticdefault(4, 1))
329*da0073e9SAndroid Build Coastguard Worker
330*da0073e9SAndroid Build Coastguard Worker    def test_restore_range_iter(self):
331*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
332*da0073e9SAndroid Build Coastguard Worker            x = a + b
333*da0073e9SAndroid Build Coastguard Worker            rng = iter(range(3, 8, 2))
334*da0073e9SAndroid Build Coastguard Worker            x = unsupported(x, x)
335*da0073e9SAndroid Build Coastguard Worker            x += next(rng)
336*da0073e9SAndroid Build Coastguard Worker            return x, list(rng)
337*da0073e9SAndroid Build Coastguard Worker
338*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 2, 2)
339*da0073e9SAndroid Build Coastguard Worker
340*da0073e9SAndroid Build Coastguard Worker    def test_pop_after_resume(self):
341*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
342*da0073e9SAndroid Build Coastguard Worker            tmp = [a + 1, b + 2, a + b]
343*da0073e9SAndroid Build Coastguard Worker            x = a
344*da0073e9SAndroid Build Coastguard Worker            x = unsupported(x, x)
345*da0073e9SAndroid Build Coastguard Worker            for i in range(3):
346*da0073e9SAndroid Build Coastguard Worker                x += tmp.pop(-1)
347*da0073e9SAndroid Build Coastguard Worker            return x
348*da0073e9SAndroid Build Coastguard Worker
349*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 2, 6)
350*da0073e9SAndroid Build Coastguard Worker
351*da0073e9SAndroid Build Coastguard Worker    @patch("torch._dynamo.config.assume_static_by_default", False)
352*da0073e9SAndroid Build Coastguard Worker    def test_dynamic_getitem(self):
353*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
354*da0073e9SAndroid Build Coastguard Worker            return a[b.size(0) - 1]
355*da0073e9SAndroid Build Coastguard Worker
356*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
357*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnt)(fn)
358*da0073e9SAndroid Build Coastguard Worker        for i in range(3, 12):
359*da0073e9SAndroid Build Coastguard Worker            opt_fn(torch.randn(i), torch.randn(i))
360*da0073e9SAndroid Build Coastguard Worker        # just one graph
361*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
362*da0073e9SAndroid Build Coastguard Worker
363*da0073e9SAndroid Build Coastguard Worker    def test_dynamic_kwarg(self):
364*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
365*da0073e9SAndroid Build Coastguard Worker            return a - b * 10
366*da0073e9SAndroid Build Coastguard Worker
367*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
368*da0073e9SAndroid Build Coastguard Worker        cnt_dynamic = torch._dynamo.testing.CompileCounter()
369*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
370*da0073e9SAndroid Build Coastguard Worker        start = 2
371*da0073e9SAndroid Build Coastguard Worker        end = 12
372*da0073e9SAndroid Build Coastguard Worker        steps = end - start
373*da0073e9SAndroid Build Coastguard Worker        for i in range(start, end):
374*da0073e9SAndroid Build Coastguard Worker            opt_fn(torch.randn(i), torch.randn(i))
375*da0073e9SAndroid Build Coastguard Worker
376*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt_dynamic.frame_count, 1)
377*da0073e9SAndroid Build Coastguard Worker
378*da0073e9SAndroid Build Coastguard Worker    def test_dynamic_duck_size(self):
379*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
380*da0073e9SAndroid Build Coastguard Worker            if a.size(0) == b.size(0):
381*da0073e9SAndroid Build Coastguard Worker                return a + b
382*da0073e9SAndroid Build Coastguard Worker            else:
383*da0073e9SAndroid Build Coastguard Worker                return a.sum() + b.sum()
384*da0073e9SAndroid Build Coastguard Worker
385*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
386*da0073e9SAndroid Build Coastguard Worker        cnt_dynamic = torch._dynamo.testing.CompileCounter()
387*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
388*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2)
389*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(3)
390*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x, x), fn(x, x))
391*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x, y), fn(x, y))
392*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt_dynamic.frame_count, 2)
393*da0073e9SAndroid Build Coastguard Worker
394*da0073e9SAndroid Build Coastguard Worker    def test_dynamic_order_dependence(self):
395*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
396*da0073e9SAndroid Build Coastguard Worker            return a.sum() + b.sum()
397*da0073e9SAndroid Build Coastguard Worker
398*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
399*da0073e9SAndroid Build Coastguard Worker        cnt_dynamic = torch._dynamo.testing.CompileCounter()
400*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
401*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2)
402*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(3)
403*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x, y), fn(x, y))
404*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x, x), fn(x, x))
405*da0073e9SAndroid Build Coastguard Worker        # NB: This COULD validly be 2, but we don't test disjointness in the
406*da0073e9SAndroid Build Coastguard Worker        # guards for when x and y didn't duck size together, so we end up
407*da0073e9SAndroid Build Coastguard Worker        # with a generic graph that also works when x and y happen to duck
408*da0073e9SAndroid Build Coastguard Worker        # size together.
409*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt_dynamic.frame_count, 2)
410*da0073e9SAndroid Build Coastguard Worker
411*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
412*da0073e9SAndroid Build Coastguard Worker        cnt_dynamic.frame_count = 0
413*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x, x), fn(x, x))  # this overspecializes!
414*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x, y), fn(x, y))
415*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt_dynamic.frame_count, 2)
416*da0073e9SAndroid Build Coastguard Worker
417*da0073e9SAndroid Build Coastguard Worker    def test_dynamic_zero_inference(self):
418*da0073e9SAndroid Build Coastguard Worker        def fn(a):
419*da0073e9SAndroid Build Coastguard Worker            if a.size(0) != 0:
420*da0073e9SAndroid Build Coastguard Worker                return a * 2
421*da0073e9SAndroid Build Coastguard Worker            else:
422*da0073e9SAndroid Build Coastguard Worker                return a + 1
423*da0073e9SAndroid Build Coastguard Worker
424*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
425*da0073e9SAndroid Build Coastguard Worker        cnt_dynamic = torch._dynamo.testing.CompileCounter()
426*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
427*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(0)
428*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(2)
429*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(y), fn(y))
430*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x), fn(x))
431*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt_dynamic.frame_count, 2)
432*da0073e9SAndroid Build Coastguard Worker
433*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
434*da0073e9SAndroid Build Coastguard Worker    def test_no_graph_break_on_item(self):
435*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
436*da0073e9SAndroid Build Coastguard Worker            x = a + b - 1.5
437*da0073e9SAndroid Build Coastguard Worker            x = x.sum()
438*da0073e9SAndroid Build Coastguard Worker            x.item()
439*da0073e9SAndroid Build Coastguard Worker            x = x / (a + b)
440*da0073e9SAndroid Build Coastguard Worker            return x
441*da0073e9SAndroid Build Coastguard Worker
442*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 1, 5)  # item gets DCE'd
443*da0073e9SAndroid Build Coastguard Worker
444*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
445*da0073e9SAndroid Build Coastguard Worker    def test_graph_break_on_item(self):
446*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
447*da0073e9SAndroid Build Coastguard Worker            x = a + b - 1.5
448*da0073e9SAndroid Build Coastguard Worker            x = x.sum()
449*da0073e9SAndroid Build Coastguard Worker            x.item()
450*da0073e9SAndroid Build Coastguard Worker            x = x / (a + b)
451*da0073e9SAndroid Build Coastguard Worker            return x
452*da0073e9SAndroid Build Coastguard Worker
453*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 2, 5)
454*da0073e9SAndroid Build Coastguard Worker
455*da0073e9SAndroid Build Coastguard Worker    def test_resume_paths_join(self):
456*da0073e9SAndroid Build Coastguard Worker        def fn(x, c1, c2, c3):
457*da0073e9SAndroid Build Coastguard Worker            x = x + 1
458*da0073e9SAndroid Build Coastguard Worker            if c1:
459*da0073e9SAndroid Build Coastguard Worker                x = x + 2
460*da0073e9SAndroid Build Coastguard Worker            x = x + 3
461*da0073e9SAndroid Build Coastguard Worker            if c2:
462*da0073e9SAndroid Build Coastguard Worker                x = x + 4
463*da0073e9SAndroid Build Coastguard Worker            x = x + 5
464*da0073e9SAndroid Build Coastguard Worker            if c3:
465*da0073e9SAndroid Build Coastguard Worker                x = x + 6
466*da0073e9SAndroid Build Coastguard Worker            return x + 7
467*da0073e9SAndroid Build Coastguard Worker
468*da0073e9SAndroid Build Coastguard Worker        v1 = torch.randn(10)
469*da0073e9SAndroid Build Coastguard Worker        t = torch.Tensor([True])
470*da0073e9SAndroid Build Coastguard Worker        f = torch.Tensor([False])
471*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
472*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnt)(fn)
473*da0073e9SAndroid Build Coastguard Worker        for a in (t, f):
474*da0073e9SAndroid Build Coastguard Worker            for b in (t, f):
475*da0073e9SAndroid Build Coastguard Worker                for c in (t, f):
476*da0073e9SAndroid Build Coastguard Worker                    opt_fn(v1, a, b, c)
477*da0073e9SAndroid Build Coastguard Worker
478*da0073e9SAndroid Build Coastguard Worker        # checking here we don't create 2^n graphs
479*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 7)
480*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.op_count, 10)
481*da0073e9SAndroid Build Coastguard Worker
482*da0073e9SAndroid Build Coastguard Worker    def test_resume_with_no_grad1(self):
483*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
484*da0073e9SAndroid Build Coastguard Worker            x = a + b
485*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
486*da0073e9SAndroid Build Coastguard Worker                x = x + 1
487*da0073e9SAndroid Build Coastguard Worker                x.sum().tolist()  # graph break
488*da0073e9SAndroid Build Coastguard Worker                x = x + 2
489*da0073e9SAndroid Build Coastguard Worker            x = x + 3
490*da0073e9SAndroid Build Coastguard Worker            return x
491*da0073e9SAndroid Build Coastguard Worker
492*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 2, 9)
493*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
494*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
495*da0073e9SAndroid Build Coastguard Worker            self._common(fn, 2, 5)
496*da0073e9SAndroid Build Coastguard Worker
497*da0073e9SAndroid Build Coastguard Worker    def test_resume_with_no_grad2(self):
498*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
499*da0073e9SAndroid Build Coastguard Worker            x = a + b
500*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
501*da0073e9SAndroid Build Coastguard Worker                x = x + 1
502*da0073e9SAndroid Build Coastguard Worker                x.sum().tolist()  # graph break
503*da0073e9SAndroid Build Coastguard Worker                x = x + 2
504*da0073e9SAndroid Build Coastguard Worker                x.sum().tolist()  # graph break
505*da0073e9SAndroid Build Coastguard Worker                x = x + 3
506*da0073e9SAndroid Build Coastguard Worker            x = x + 4
507*da0073e9SAndroid Build Coastguard Worker            return x
508*da0073e9SAndroid Build Coastguard Worker
509*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 3, 13)
510*da0073e9SAndroid Build Coastguard Worker
511*da0073e9SAndroid Build Coastguard Worker    def test_resume_with_no_grad3(self):
512*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
513*da0073e9SAndroid Build Coastguard Worker            x = a + b
514*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
515*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
516*da0073e9SAndroid Build Coastguard Worker                    x = x + 1
517*da0073e9SAndroid Build Coastguard Worker                    with torch.enable_grad():
518*da0073e9SAndroid Build Coastguard Worker                        x.sum().tolist()  # graph break
519*da0073e9SAndroid Build Coastguard Worker                        x = x[0] + 2
520*da0073e9SAndroid Build Coastguard Worker                    x = x + 3
521*da0073e9SAndroid Build Coastguard Worker            x = x + 4
522*da0073e9SAndroid Build Coastguard Worker            return x
523*da0073e9SAndroid Build Coastguard Worker
524*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 2, 11)
525*da0073e9SAndroid Build Coastguard Worker
526*da0073e9SAndroid Build Coastguard Worker    def test_resume_tuple_iterator(self):
527*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
528*da0073e9SAndroid Build Coastguard Worker            x = a + b
529*da0073e9SAndroid Build Coastguard Worker            it = iter(tuple(range(10)))
530*da0073e9SAndroid Build Coastguard Worker            x = x + next(it)
531*da0073e9SAndroid Build Coastguard Worker            x = x + next(it)
532*da0073e9SAndroid Build Coastguard Worker            x = x + next(it)
533*da0073e9SAndroid Build Coastguard Worker            x = unsupported(x, x)
534*da0073e9SAndroid Build Coastguard Worker            x = x + next(it)
535*da0073e9SAndroid Build Coastguard Worker            x = x + next(it)
536*da0073e9SAndroid Build Coastguard Worker            x = x + next(it)
537*da0073e9SAndroid Build Coastguard Worker            x = x + next(it)
538*da0073e9SAndroid Build Coastguard Worker            return x
539*da0073e9SAndroid Build Coastguard Worker
540*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 2, 8)
541*da0073e9SAndroid Build Coastguard Worker
542*da0073e9SAndroid Build Coastguard Worker    def test_tuple_iterator_return(self):
543*da0073e9SAndroid Build Coastguard Worker        def fn(x):
544*da0073e9SAndroid Build Coastguard Worker            it = iter(tuple(range(10)))
545*da0073e9SAndroid Build Coastguard Worker            x = x + next(it)
546*da0073e9SAndroid Build Coastguard Worker            x = x + next(it)
547*da0073e9SAndroid Build Coastguard Worker            x = unsupported(x, x)
548*da0073e9SAndroid Build Coastguard Worker            x = x + next(it)
549*da0073e9SAndroid Build Coastguard Worker            x = x + next(it)
550*da0073e9SAndroid Build Coastguard Worker            x = unsupported(x, x)
551*da0073e9SAndroid Build Coastguard Worker            x = x + next(it)
552*da0073e9SAndroid Build Coastguard Worker            x = x + next(it)
553*da0073e9SAndroid Build Coastguard Worker            return x, it
554*da0073e9SAndroid Build Coastguard Worker
555*da0073e9SAndroid Build Coastguard Worker        v1 = torch.randn(10)
556*da0073e9SAndroid Build Coastguard Worker        v2, it2 = fn(v1)
557*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
558*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnt)(fn)
559*da0073e9SAndroid Build Coastguard Worker        v3, it3 = opt_fn(v1)
560*da0073e9SAndroid Build Coastguard Worker        v4, it4 = opt_fn(v1)
561*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v2.tolist(), v3.tolist())
562*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v2.tolist(), v4.tolist())
563*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(it2), list(it3))
564*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 3)
565*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.op_count, 6)
566*da0073e9SAndroid Build Coastguard Worker
567*da0073e9SAndroid Build Coastguard Worker    def test_tuple_iterator_mutate(self):
568*da0073e9SAndroid Build Coastguard Worker        def fn(x, it):
569*da0073e9SAndroid Build Coastguard Worker            x = x + next(it)
570*da0073e9SAndroid Build Coastguard Worker            x = x + next(it)
571*da0073e9SAndroid Build Coastguard Worker            x = x + next(it)
572*da0073e9SAndroid Build Coastguard Worker            x = x + next(it)
573*da0073e9SAndroid Build Coastguard Worker            return x
574*da0073e9SAndroid Build Coastguard Worker
575*da0073e9SAndroid Build Coastguard Worker        v1 = torch.randn(10)
576*da0073e9SAndroid Build Coastguard Worker        it1 = iter(tuple(range(10)))
577*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
578*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnt)(fn)
579*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(v1, it1).tolist(), (v1 + 1 + 2 + 3).tolist())
580*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(it1), [4, 5, 6, 7, 8, 9])
581*da0073e9SAndroid Build Coastguard Worker
582*da0073e9SAndroid Build Coastguard Worker    def test_enumerate_not_break_graph(self):
583*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
584*da0073e9SAndroid Build Coastguard Worker            for i, x in enumerate(a.shape):
585*da0073e9SAndroid Build Coastguard Worker                b = b + x
586*da0073e9SAndroid Build Coastguard Worker            for i, x in enumerate(b.shape, 8):
587*da0073e9SAndroid Build Coastguard Worker                b = b + x * i
588*da0073e9SAndroid Build Coastguard Worker            return b
589*da0073e9SAndroid Build Coastguard Worker
590*da0073e9SAndroid Build Coastguard Worker        self._common(fn, 1, ifdynstaticdefault(2, 3))
591*da0073e9SAndroid Build Coastguard Worker
592*da0073e9SAndroid Build Coastguard Worker
593*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
594*da0073e9SAndroid Build Coastguard Worker    from torch._dynamo.test_case import run_tests
595*da0073e9SAndroid Build Coastguard Worker
596*da0073e9SAndroid Build Coastguard Worker    run_tests()
597