# Owner(s): ["module: dynamo"]
import math
import random
import unittest

import numpy as np

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.nn.functional as F
from torch._dynamo.comptime import comptime
from torch._dynamo.testing import CompileCounter, same
from torch.testing._internal.common_utils import skipIfWindows
from torch.testing._internal.logging_utils import logs_to_string


# The intention of this test file is that you should put test cases
# specifically for assume_static_by_default=False, i.e., you want to YOLO
# make everything as dynamic as possible.  If you want to test the more
# normal situation where you assume static by default, put it in a regular
# test file; test_dynamic_shapes will cover both the YOLO and non-YOLO cases.
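#
# Roughly speaking (a simplified description, not an exhaustive spec): with
# assume_static_by_default=False, Dynamo treats tensor sizes (and scalar
# inputs such as the numpy/Python ints exercised below) as symbolic from the
# first compile, so e.g. calling a compiled fn(x, n) with n=3 and then n=4
# should reuse the same compiled frame rather than baking n in as a constant;
# if the traced code branches on the value, a guard specializes it (see
# test_no_recompiles and test_use_and_specialize below).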


@torch._dynamo.config.patch(assume_static_by_default=False)
class UnspecTests(torch._dynamo.test_case.TestCase):
    def test_numpy_correctness(self):
        def fn(x, y, z):
            xy = [x + y, y, False]
            np_x = x.numpy()
            np_y = y.numpy()
            return {
                "x": x,
                "z": z,
                "a": np_y.sum(),
                "b": xy,
                "c": np_y[0][0] / 68,
                "d": np_x.sum(),
                "e": np_x + np_y,
            }, x + np_y.sum() + z

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
        y = torch.ones([2, 2], dtype=torch.int64)
        z = np.int64(12)
        res1 = fn(x, y, z)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res2 = opt_fn(x, y, z)
        self.assertEqual(res1, res2)

    def test_no_recompilations(self):
        # no recompilations when passing in different numpy int values
        def fn(x, y):
            return {"a": x + 1, "b": y / 2}

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        for i in range(10):
            opt_fn(x, np.int64(i))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    @unittest.expectedFailure  # array scalars decay to 0D arrays
    def test_builtin_max_min(self):
        # test unspecialized primitive max/min
        def fn(x, y, z):
            return z + 1, max(x, y), min(x - 4, y)

        x = np.int64(12)
        y = 10
        z = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
        res1 = fn(x, y, z)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res2 = opt_fn(x, y, z)
        self.assertTrue(same(res1, res2, relax_numpy_equality=True))

    def test_feed_random_values_into_graph_only(self):
        def fn(shape):
            torch.manual_seed(123)
            x = torch.randn(shape, device="cpu") * random.randint(30, 100)
            return x

        shape = [2, 3]
        random.seed(1)
        res1 = fn(shape)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        random.seed(1)
        res2 = opt_fn(shape)

        self.assertTrue(same(res1, res2))

    def test_random_values_with_graph_break(self):
        def fn(x):
            r1 = random.random()
            y = x + random.uniform(10, 20)
            y.sum().item()
            r2 = random.randint(2, 18)  # no graph output in this frame
            y.sum().item()
            return y + r1, r2

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        random.seed(1)
        res1 = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        random.seed(1)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

    # Really annoying intersection of specialization and RandomValueSource.
    # If we get a RandomValueSource with a single-element tensor, we should return a ConstantVariable like other
    # unspecs... but if we do, we break the bytecode assumptions and guards will not work, as we will be referring
    # to a name from a source that is not there. If we call .item() and take the wrapped_value out (i.e. we do
    # wrapped_value = wrapped_value.item() and send the unspec down to wrap_fx_proxy), this test passes, but then
    # some models fail on a missing codegen.tx.output.random_values_var. If we let the tensor value go into wrap
    # as it is, this test fails.
    # The real solution here is to rewrite RandomValueSource and all the codegen it does from the ground up.
    def test_multiple_consecutive_random_calls_before_graph(self):
        def fn(x):
            dim1 = random.randrange(start=0, stop=5)
            dim2 = random.randrange(start=0, stop=5)
            dim3 = random.randrange(start=0, stop=5)
            y = torch.rand(dim1, dim2, dim3)
            return x + 2, y

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        random.seed(1)
        res1 = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        random.seed(1)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

    def test_compiled_random_calls_are_random(self):
        # A compiled function with random calls should return different
        # values on every iteration.
        # https://github.com/pytorch/pytorch/issues/95425
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return (x + 1) * random.uniform(0, 1)

        res = []
        for _ in range(5):
            res.append(fn(torch.ones(2)))
        for i in range(1, 5):
            self.assertFalse(same(res[i - 1], res[i]))

    def test_random_call_with_while_loop(self):
        def fn(x):
            dim1 = random.randrange(start=0, stop=3)
            dim2 = dim1
            while dim1 == dim2:
                dim2 = random.randrange(start=0, stop=3)
            return x * 2

        x = torch.randn(4)
        random.seed(1)
        res1 = fn(x)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        random.seed(1)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

        random.seed(10)
        res1 = fn(x)
        random.seed(10)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

    def test_random_object(self):
        # test argument passing, mutation, reconstruction, state correctness
        def fn(x, rand2):
            r1 = random.randint(1, 9)
            r2 = rand2.randint(1, 9)
            rand3 = random.Random(42)
            r3 = rand3.randint(1, 9)

            y = x + r1 + r2 + r3
            return y, rand2, rand3

        inp = torch.randn(3, 3)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        random.seed(0)
        y_1, rand2_1, rand3_1 = fn(inp, random.Random(12))
        state_1 = random.getstate()
        random.seed(0)
        y_2, rand2_2, rand3_2 = opt_fn(inp, random.Random(12))
        state_2 = random.getstate()
        self.assertEqual(y_1, y_2)
        self.assertEqual(state_1, state_2)
        self.assertEqual(rand2_1.getstate(), rand2_2.getstate())
        self.assertEqual(rand3_1.getstate(), rand3_2.getstate())

    def test_random_object_methods(self):
        def fn(x, rand1, rand2, rand3):
            rand1.seed(42)
            rand4 = random.Random(9002)
            rand2.setstate(rand4.getstate())
            r1 = rand1.random()
            r2 = rand2.randint(1, 10)
            r3 = rand3.randrange(10)
            r4 = rand4.uniform(0, 1)
            return x + r1 + r2 + r3 + r4

        inp = torch.randn(3, 3)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        rand1_1 = random.Random(1)
        rand2_1 = random.Random(2)
        rand3_1 = random.Random(3)
        rand1_2 = random.Random(1)
        rand2_2 = random.Random(2)
        rand3_2 = random.Random(3)
        y1 = fn(inp, rand1_1, rand2_1, rand3_1)
        y2 = opt_fn(inp, rand1_2, rand2_2, rand3_2)
        self.assertEqual(y1, y2)
        self.assertEqual(rand1_1.getstate(), rand1_2.getstate())
        self.assertEqual(rand2_1.getstate(), rand2_2.getstate())
        self.assertEqual(rand3_1.getstate(), rand3_2.getstate())

    def test_random_object_overridden_methods(self):
        # these will result in graph breaks, but we shouldn't crash
        def get_rng():
            rand1 = random.Random(1)
            rand2 = random.Random(2)

            orig_random = rand1.random

            def custom_random():
                return orig_random()

            orig_getstate = rand2.getstate

            def custom_getstate():
                return orig_getstate()

            rand1.random = custom_random
            rand2.getstate = custom_getstate
            return rand1, rand2

        def fn(x, rand1, rand2):
            r1 = rand1.random()
            rand3 = random.Random()
            rand3.setstate(rand2.getstate())
            r2 = rand3.random()
            return x + r1 + r2

        inp = torch.randn(3, 3)
        opt_fn = torch.compile(fn, backend="eager")
        y1 = fn(inp, *get_rng())
        y2 = opt_fn(inp, *get_rng())
        self.assertEqual(y1, y2)

    def test_builtin_getitem(self):
        # builtin getitem where args[0] is a python list and args[1] is unspec
        def fn(x, idx):
            return (torch.zeros(idx), x[idx], x[idx:])

        x = list(range(50))
        ref = fn(x, 48)  # 48 is unspecialized
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x, 48)
        self.assertTrue(same(ref, res))

    def test_use_and_specialize(self):
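        # Even in fully dynamic mode, branching on the unspec int `y` forces a
        # guard on its value, so y=2 and y=3 each get their own compiled frame:
        # the four calls below should produce exactly two frames.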
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def fn(x, y):
            x = x + y
            if y == 2:
                return x - 1
            else:
                return x + 1

        self.assertTrue(same(fn(torch.tensor([5]), 2), 6))
        self.assertTrue(same(fn(torch.tensor([6]), 2), 7))
        self.assertTrue(same(fn(torch.tensor([5]), 3), 9))
        self.assertTrue(same(fn(torch.tensor([4]), 3), 8))
        self.assertEqual(cnt.frame_count, 2)

    def test_no_recompiles(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def fn(x, y):
            return x + y

        self.assertTrue(same(fn(torch.tensor([5]), 100), 105))
        self.assertTrue(same(fn(torch.tensor([4]), 200), 204))
        self.assertTrue(same(fn(torch.tensor([3]), 300), 303))
        self.assertTrue(same(fn(torch.tensor([2]), 400), 402))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    def test_no_recompiles_prod_backward(self):
        # https://github.com/pytorch/pytorch/issues/120608
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def fn(t):
            return torch.prod(t, 3, keepdim=True)

        input_shapes = [(8, 10, 3, 2), (8, 3, 5, 2), (8, 4, 8, 2)]
        for s in input_shapes:
            t1 = torch.randn(s, requires_grad=True)
            h_result = fn(t1)
            grad = torch.ones_like(h_result)
            h_result.backward(grad)

        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_builtin_functions_on_cuda(self):
        def fn(x, scaler):
            m = torch.nn.ReLU()
            y = m(x) * scaler
            return y

        x = torch.randn([3, 6], device="cuda")
        scaler = 0.23  # 0.23 is unspecialized
        ref = fn(x, scaler)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x, scaler)
        self.assertTrue(same(ref, res))
        self.assertEqual(ref.device, res.device)

    def test_unspec_float_precision(self):
        def fn(image, scale_factor):
            image = torch.nn.functional.interpolate(
                image[None],
                size=None,
                scale_factor=scale_factor,
                mode="bilinear",
                recompute_scale_factor=True,
                align_corners=False,
            )[0]

            return image.shape

        x = torch.rand([3, 427, 640])
        scale_factor = 1.873536229133606
        ref = fn(x, scale_factor)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x, scale_factor)
        self.assertTrue(same(ref, res))

    @unittest.expectedFailure  # fails as long as numpy scalars are 0D arrays
    def test_specializing_numpy_float_in_control_flow(self):
        # np.float64 is unspecialized by default,
        # but it should be specialized when used in control flow.
        def fn(x, y):
            if y > 1.0:
                return x + 1
            else:
                return x - 1

        x = torch.rand(4)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        for t in [np.float16, np.float32, np.float64]:
            y = t(1.23)
            ref = fn(x, y)
            res = opt_fn(x, y)
            self.assertTrue(same(ref, res))

    def test_mark_static_inside(self):
        def fn(x):
            torch._dynamo.mark_static(x, 0)
            comptime.assert_static(x.size(0))
            return x + 1

        opt_fn = torch.compile(fn, dynamic=True, fullgraph=True)
        opt_fn(torch.randn(12, 23))

    def test_shape_graph_break(self):
        from torch._dynamo.comptime import comptime

        def fn(x):
            x_shape = x.size()
            comptime.graph_break()
            return x + torch.randn(x_shape)

        x = torch.randn(20)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_fn(x)

    def test_isinstance_symint(self):
        def fn(x):
            assert isinstance(x.size(0), int)
            return x * 2

        x = torch.randn(20)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_fn(x)
        y = torch.randn(30)
        torch._dynamo.mark_dynamic(y, 0)
        opt_fn(y)

    def test_mark_01_dynamic(self):
        def fn(x):
            return x * 2

        x = torch.randn(1)
        torch._dynamo.mark_dynamic(x, 0)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        # This will fail to compile a generic kernel, but we should not
        # complain about it (mark_dynamic will try its best, but 0/1
        # specialization is allowed)
        opt_fn(x)

    def test_conv1d_symint_padding(self):
        kernel = torch.randn(1, 1, 4)

        def func(x):
            padding = math.ceil((kernel.shape[-1] + x.shape[-1] % 2) / 2) - 1
            out = F.conv1d(x, kernel, padding=padding, stride=2)
            return out

        opt_func = torch.compile(func)

        x = torch.randn(1, 1, 175)
        opt_func(x)  # passes
        x = torch.randn(1, 1, 249)
        opt_func(x)  # used to crash

    @torch._dynamo.config.patch("assume_static_by_default", True)
    def test_propagate_dynamic_dim(self):
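        # mark_dynamic on the input should survive the graph break below: the
        # output of the second frame is still tracked as (weakly) dynamic in
        # dim 0, which we observe via _dynamo_weak_dynamic_indices.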
        x = torch.randn(20)
        torch._dynamo.mark_dynamic(x, 0)

        @torch.compile()
        def fn(x):
            y = x * 2
            comptime.graph_break()
            z = y * 2
            return z

        z = fn(x)
        self.assertEqual(z._dynamo_weak_dynamic_indices, {0})

    def test_rshift_dynamic(self):
        def shift_right(tensor: torch.Tensor) -> torch.Tensor:
            return (tensor >> 2).to(torch.long)

        opt_fn = torch.compile(shift_right, fullgraph=True, dynamic=True)
        sample_input = torch.tensor([4, 4, 16, 32], dtype=torch.uint8)
        opt_fn(sample_input)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_symfloat_to_tensor(self):
        def f1(v):
            return torch.tensor([v.item()])

        def f2(v):
            return torch.tensor([[v.item()], [2.0]])

        def f3(v):
            return torch.tensor(v.item())

        def f4(v):
            return torch.tensor((v.item(),))

        optimize = torch.compile(backend="aot_eager", fullgraph=True)

        r = torch.randn(1)

        self.assertEqual(f1(r), optimize(f1)(r))
        self.assertEqual(f2(r), optimize(f2)(r))
        self.assertEqual(f3(r), optimize(f3)(r))
        self.assertEqual(f4(r), optimize(f4)(r))

    @skipIfWindows(
        msg="AssertionError: The values for attribute 'dtype' do not match: torch.int32 != torch.int64."
    )
    def test_to_tensor(self):
        def f1():
            a = np.random.uniform(low=-1, high=1, size=(20, 1))
            return torch.tensor([a, a, a, a], dtype=torch.float64, device="cpu")

        def f2():
            a = torch.tensor([[[123]]])
            return torch.tensor([a, a])

        def f3():
            a = torch.tensor(123)
            return torch.tensor([a, a])

        def f4():
            a = torch.tensor(123)
            b = torch.tensor([[[456]]])
            return torch.tensor([a, b])

        def f5():
            a = np.array([1, 2])
            return torch.tensor([a, a])

        optimize = torch.compile(backend="aot_eager", fullgraph=True)

        self.assertEqual(f1().shape, optimize(f1)().shape)
        self.assertEqual(f2(), optimize(f2)())
        self.assertEqual(f3(), optimize(f3)())
        self.assertEqual(f4(), optimize(f4)())
        self.assertEqual(f5(), optimize(f5)())

    def test_sym_int_conversion(self):
        def f(x):
            y = x.size(0)
            return x * int(y == 0)

        opt_fn = torch.compile(f, backend="eager", fullgraph=True)
        x = torch.randn(2, 3)
        opt_fn(x)

    def test_sum_dimlist_spec(self):
        def fn(inputs, dim):
            return torch.sum(inputs, dim)

        inputs = torch.randn(128, 5, 24, 24)
        dim = (-1, 1, 0, 2)
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(inputs, dim), fn(inputs, dim))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_max(self):
        def fn(x):
            return torch.ones(max(x.item(), 1024))

        x = torch.tensor([1000])
        y = torch.tensor([2000])
        compl_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), compl_fn(x))
        self.assertEqual(fn(y), compl_fn(y))

    # https://github.com/pytorch/pytorch/issues/104812
    def test_argmin_coerces_symint_to_intlist_spec(self):
        def fn(x, dim):
            # the python arg parser coerces dim into a vector<int>
            return torch.amin(x, dim=dim, keepdim=True)

        x = torch.randn(4, 4, 4)
        dim = 2
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(x, dim), fn(x, dim))

    def test_exponential(self):
        def fn(inputs, op_inputs_dict):
            res = inputs.exponential_(**op_inputs_dict)
            return res

        inputs = torch.randn(2, 3, 4)
        op_inputs_dict = {"lambd": 10, "generator": None}
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(inputs, op_inputs_dict), fn(inputs, op_inputs_dict))

    def test_symbol_guard_limit_before_specialize(self):
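        # Each torch._check below adds a guard on x.size(0).  With the default
        # config the size stays dynamic and all three calls share one frame;
        # with symbol_guard_limit_before_specialize=3 the symbol is specialized
        # once too many guards accumulate, so each distinct size recompiles.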
        cnts = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnts, dynamic=True)
        def fn(x):
            torch._check(x.size(0) != 3)
            torch._check(x.size(0) != 4)
            torch._check(x.size(0) != 5)
            torch._check(x.size(0) != 6)
            return x + 2

        # Control test
        fn(torch.randn(12))
        fn(torch.randn(13))
        fn(torch.randn(14))

        self.assertExpectedInline(cnts.frame_count, """1""")
        cnts.frame_count = 0

        torch._dynamo.reset()

        with torch.fx.experimental._config.patch(
            symbol_guard_limit_before_specialize=3
        ):
            fn(torch.randn(12))
            fn(torch.randn(13))
            fn(torch.randn(14))

            self.assertExpectedInline(cnts.frame_count, """3""")

    def test_defaults(self):
        def g(x, i=8):
            comptime.assert_static(i)
            return x * i

        def fn(x):
            return g(x)

        inputs = torch.randn(2, 3, 4)
        compl_fn = torch.compile(fn, dynamic=True, backend="eager")
        self.assertEqual(compl_fn(inputs), fn(inputs))

    @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=True)
    def test_unspec_float_input(self):
        cnts = torch._dynamo.testing.CompileCounter()

        def f(x, y):
            if y == 5.0:
                return x + 2
            else:
                return x + y

        cf = torch.compile(backend=cnts, fullgraph=True)(f)

        x = torch.randn(3)
        self.assertEqual(f(x, 3.0), cf(x, 3.0))
        self.assertEqual(f(x, 4.0), cf(x, 4.0))
        self.assertExpectedInline(cnts.frame_count, """1""")  # no recompile
        self.assertEqual(f(x, 5.0), cf(x, 5.0))
        self.assertExpectedInline(cnts.frame_count, """2""")  # guard worked
        self.assertEqual(f(x, math.nan), cf(x, math.nan))
        self.assertExpectedInline(cnts.frame_count, """3""")  # nan always recompiles

    @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=True)
    def test_unspec_float_output(self):
        cnts = torch._dynamo.testing.CompileCounter()

        def f(x, y):
            return x + 1, y * 2

        cf = torch.compile(backend=cnts, fullgraph=True)(f)
        x = torch.randn(3)

        self.assertEqual(f(x, 3.0), cf(x, 3.0))
        self.assertEqual(f(x, 4.0), cf(x, 4.0))
        self.assertEqual(f(x, 5.0), cf(x, 5.0))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_data_dependent_evaluate_expr_graph_break(self):
        cnts = torch._dynamo.testing.CompileCounter()

        # To ensure that the continuation frame is compiled,
        # we have to write the test function in this funny way.
        # See https://github.com/pytorch/pytorch/issues/111918
        def test(y):
            if y > 2:
                return True
            else:
                return False

        @torch._dynamo.optimize(cnts)
        def fn(x):
            x = x + 1
            y = x.item()
            if test(y):
                return x * 2
            else:
                return x * 3

        x = torch.tensor([3.0])
        fn(x)

        self.assertExpectedInline(cnts.frame_count, """2""")
        self.assertExpectedInline(cnts.op_count, """4""")

    def test_prune_torch_check(self):
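        # The torch._check calls below only produce guards that are satisfied
        # by the example inputs, so nothing should be left in the emitted
        # graph: the captured forward() is expected to return an empty tuple.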
        log_stream, ctx = logs_to_string("torch._dynamo.output_graph", "graph_code")

        @torch.compile(fullgraph=True, dynamic=True, backend="eager")
        def f(x, y):
            torch._check(y + 5 == 85)
            torch._check(x.size(0) == 80)

        with ctx():
            f(torch.randn(80, 100), 80)

        out = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip()
        self.assertExpectedInline(
            out,
            """\
def forward(self):
        return ()""",
        )

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_split_aot_autograd(self):
        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(x, i):
            y, z = i.tolist()
            return torch.split(x, [y, z])

        print(f(torch.randn(10, requires_grad=True), torch.tensor([7, 3])))

    def test_bool_tensor_ctor(self):
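        # (x.size(0) // 13) * 13 is 0 for sizes under 13, so y.numel() == 0 is
        # a shape-dependent boolean; wrapping it in torch.tensor() should still
        # compile with dynamic shapes and give the right answer in both cases.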
        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, dynamic=True, fullgraph=True)
        def f(x):
            y = torch.empty((x.size(0) // 13) * 13)
            return torch.tensor(y.numel() == 0)

        self.assertTrue(f(torch.empty(8)).item())
        self.assertFalse(f(torch.empty(13)).item())

    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked(self):
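        # mark_unbacked makes dim 0 of x1 an unbacked symbol (no example value
        # to specialize on), so the single compiled graph has to handle any
        # batch size, including 1; error_on_recompile guards against a retrace.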
        class TestModel(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            def forward(self, x: torch.Tensor, val: int) -> torch.Tensor:
                return x * 2

        main_model = TestModel()
        opt_model = torch.compile(main_model, mode="max-autotune", dynamic=True)

        x1 = torch.rand(3, 5, 4, 8)
        x2 = torch.rand(1, 5, 4, 8)

        torch._dynamo.decorators.mark_unbacked(x1, 0)

        o1_ref = main_model(x1, 2)
        o1 = opt_model(x1, 2)
        self.assertEqual(o1_ref, o1)

        o1_2_ref = main_model(x2, 2)
        o1_2 = opt_model(x2, 2)
        self.assertEqual(o1_2_ref, o1_2)

    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked_hint_consistency(self):
        from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

        x = torch.randn(1)
        torch._dynamo.decorators.mark_unbacked(x, 0)

        @torch.compile()
        def f(x):
            if guard_size_oblivious(x.size(0) != 1):
                return x + 3
            else:
                return x + 4

        self.assertEqual(f(x), x + 3)

    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked_channels_last(self):
        class TestModel(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            def forward(self, x: torch.Tensor, val: int) -> torch.Tensor:
                return x * 2

        main_model = TestModel()
        opt_model = torch.compile(main_model, mode="max-autotune", dynamic=True)

        x1 = torch.rand(3, 5, 4, 8).to(memory_format=torch.channels_last)
        x2 = torch.rand(1, 5, 4, 8).to(memory_format=torch.channels_last)

        torch._dynamo.decorators.mark_unbacked(x1, 0)

        o1_ref = main_model(x1, 2)
        o1 = opt_model(x1, 2)
        self.assertEqual(o1_ref, o1)

        o1_2_ref = main_model(x2, 2)
        o1_2 = opt_model(x2, 2)
        self.assertEqual(o1_2_ref, o1_2)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()