# Owner(s): ["module: dynamo"]
import math
import random
import unittest

import numpy as np

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.nn.functional as F
from torch._dynamo.comptime import comptime
from torch._dynamo.testing import CompileCounter, same
from torch.testing._internal.common_utils import skipIfWindows
from torch.testing._internal.logging_utils import logs_to_string


# The intention of this test file is that you should put test cases here
# specifically for assume_static_by_default=False, i.e., when you want to YOLO
# make everything as dynamic as possible.  If you want to test the more normal
# situation where you assume static by default, put it in a regular test file and
# test_dynamic_shapes will cover both the YOLO and non-YOLO cases.


@torch._dynamo.config.patch(assume_static_by_default=False)
class UnspecTests(torch._dynamo.test_case.TestCase):
    def test_numpy_correctness(self):
        def fn(x, y, z):
            xy = [x + y, y, False]
            np_x = x.numpy()
            np_y = y.numpy()
            return {
                "x": x,
                "z": z,
                "a": np_y.sum(),
                "b": xy,
                "c": np_y[0][0] / 68,
                "d": np_x.sum(),
                "e": np_x + np_y,
            }, x + np_y.sum() + z

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
        y = torch.ones([2, 2], dtype=torch.int64)
        z = np.int64(12)
        res1 = fn(x, y, z)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res2 = opt_fn(x, y, z)
        self.assertEqual(res1, res2)

    def test_no_recompilations(self):
        # no recompilations when passing in different numpy int values
        def fn(x, y):
            return {"a": x + 1, "b": y / 2}

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        for i in range(10):
            opt_fn(x, np.int64(i))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    @unittest.expectedFailure  # array scalars decay to 0D arrays
    def test_builtin_max_min(self):
        # test unspecialized primitive max/min
        def fn(x, y, z):
            return z + 1, max(x, y), min(x - 4, y)

        x = np.int64(12)
        y = 10
        z = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
        res1 = fn(x, y, z)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res2 = opt_fn(x, y, z)
        self.assertTrue(same(res1, res2, relax_numpy_equality=True))

    def test_feed_random_values_into_graph_only(self):
        def fn(shape):
            torch.manual_seed(123)
            x = torch.randn(shape, device="cpu") * random.randint(30, 100)
            return x

        shape = [2, 3]
        random.seed(1)
        res1 = fn(shape)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        random.seed(1)
        res2 = opt_fn(shape)

        self.assertTrue(same(res1, res2))

    def test_random_values_with_graph_break(self):
        def fn(x):
            r1 = random.random()
            y = x + random.uniform(10, 20)
            y.sum().item()
            r2 = random.randint(2, 18)  # no graph output in this frame
            y.sum().item()
            return y + r1, r2

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        random.seed(1)
        res1 = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        random.seed(1)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

    # Really annoying intersection of specialization and RandomValueSource.
    # If we get a RandomValueSource with a single-element tensor, we should return a ConstantVariable like other
    # unspecs... but if we do, we break the bytecode assumptions and guards will not work, as we will be referring
    # to a name from a source that is not there. If we call .item() and take the wrapped_value out (i.e. we do
    # wrapped_value = wrapped_value.item() when we send unspec down to wrap_fx_proxy), this test passes, but then
    # some models fail on a missing codegen.tx.output.random_values_var. If we let the tensor value go into wrap
    # as is, this test fails.
    # The real solution here is to rewrite RandomValueSource and all the codegen it does from the ground up.
    def test_multiple_consecutive_random_calls_before_graph(self):
        def fn(x):
            dim1 = random.randrange(start=0, stop=5)
            dim2 = random.randrange(start=0, stop=5)
            dim3 = random.randrange(start=0, stop=5)
            y = torch.rand(dim1, dim2, dim3)
            return x + 2, y

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        random.seed(1)
        res1 = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        random.seed(1)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

    def test_compiled_random_calls_are_random(self):
        # For compiled functions with random calls,
        # it should return different values for every iteration.
        # https://github.com/pytorch/pytorch/issues/95425
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return (x + 1) * random.uniform(0, 1)

        res = []
        for _ in range(5):
            res.append(fn(torch.ones(2)))
        for i in range(1, 5):
            self.assertFalse(same(res[i - 1], res[i]))

    def test_random_call_with_while_loop(self):
        def fn(x):
            dim1 = random.randrange(start=0, stop=3)
            dim2 = dim1
            while dim1 == dim2:
                dim2 = random.randrange(start=0, stop=3)
            return x * 2

        x = torch.randn(4)
        random.seed(1)
        res1 = fn(x)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        random.seed(1)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

        random.seed(10)
        res1 = fn(x)
        random.seed(10)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

    def test_random_object(self):
        # test argument passing, mutation, reconstruction, state correctness
        def fn(x, rand2):
            r1 = random.randint(1, 9)
            r2 = rand2.randint(1, 9)
            rand3 = random.Random(42)
            r3 = rand3.randint(1, 9)

            y = x + r1 + r2 + r3
            return y, rand2, rand3

        inp = torch.randn(3, 3)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        random.seed(0)
        y_1, rand2_1, rand3_1 = fn(inp, random.Random(12))
        state_1 = random.getstate()
        random.seed(0)
        y_2, rand2_2, rand3_2 = opt_fn(inp, random.Random(12))
        state_2 = random.getstate()
        self.assertEqual(y_1, y_2)
        self.assertEqual(state_1, state_2)
        self.assertEqual(rand2_1.getstate(), rand2_2.getstate())
        self.assertEqual(rand3_1.getstate(), rand3_2.getstate())

    def test_random_object_methods(self):
        def fn(x, rand1, rand2, rand3):
            rand1.seed(42)
            rand4 = random.Random(9002)
            rand2.setstate(rand4.getstate())
            r1 = rand1.random()
            r2 = rand2.randint(1, 10)
            r3 = rand3.randrange(10)
            r4 = rand4.uniform(0, 1)
            return x + r1 + r2 + r3 + r4

        inp = torch.randn(3, 3)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        rand1_1 = random.Random(1)
        rand2_1 = random.Random(2)
        rand3_1 = random.Random(3)
        rand1_2 = random.Random(1)
        rand2_2 = random.Random(2)
        rand3_2 = random.Random(3)
        y1 = fn(inp, rand1_1, rand2_1, rand3_1)
        y2 = opt_fn(inp, rand1_2, rand2_2, rand3_2)
        self.assertEqual(y1, y2)
        self.assertEqual(rand1_1.getstate(), rand1_2.getstate())
        self.assertEqual(rand2_1.getstate(), rand2_2.getstate())
        self.assertEqual(rand3_1.getstate(), rand3_2.getstate())

    def test_random_object_overridden_methods(self):
        # these will result in graph breaks, but we shouldn't crash
        def get_rng():
            rand1 = random.Random(1)
            rand2 = random.Random(2)

            orig_random = rand1.random

            def custom_random():
                return orig_random()

            orig_getstate = rand2.getstate

            def custom_getstate():
                return orig_getstate()

            rand1.random = custom_random
            rand2.getstate = custom_getstate
            return rand1, rand2

        def fn(x, rand1, rand2):
            r1 = rand1.random()
            rand3 = random.Random()
            rand3.setstate(rand2.getstate())
            r2 = rand3.random()
            return x + r1 + r2

        inp = torch.randn(3, 3)
        opt_fn = torch.compile(fn, backend="eager")
        y1 = fn(inp, *get_rng())
        y2 = opt_fn(inp, *get_rng())
        self.assertEqual(y1, y2)

    def test_builtin_getitem(self):
        # builtin getitem: args[0] is a python list and args[1] is unspec
        def fn(x, idx):
            return (torch.zeros(idx), x[idx], x[idx:])

        x = list(range(50))
        ref = fn(x, 48)  # 48 is unspecialized
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x, 48)
        self.assertTrue(same(ref, res))

    def test_use_and_specialize(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def fn(x, y):
            x = x + y
            if y == 2:
                return x - 1
            else:
                return x + 1

        self.assertTrue(same(fn(torch.tensor([5]), 2), 6))
        self.assertTrue(same(fn(torch.tensor([6]), 2), 7))
        self.assertTrue(same(fn(torch.tensor([5]), 3), 9))
        self.assertTrue(same(fn(torch.tensor([4]), 3), 8))
        self.assertEqual(cnt.frame_count, 2)

    def test_no_recompiles(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def fn(x, y):
            return x + y

        self.assertTrue(same(fn(torch.tensor([5]), 100), 105))
        self.assertTrue(same(fn(torch.tensor([4]), 200), 204))
        self.assertTrue(same(fn(torch.tensor([3]), 300), 303))
        self.assertTrue(same(fn(torch.tensor([2]), 400), 402))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    def test_no_recompiles_prod_backward(self):
        # https://github.com/pytorch/pytorch/issues/120608
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def fn(t):
            return torch.prod(t, 3, keepdim=True)

        input_shapes = [(8, 10, 3, 2), (8, 3, 5, 2), (8, 4, 8, 2)]
        for s in input_shapes:
            t1 = torch.randn(s, requires_grad=True)
            h_result = fn(t1)
            grad = torch.ones_like(h_result)
            h_result.backward(grad)

        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_builtin_functions_on_cuda(self):
        def fn(x, scaler):
            m = torch.nn.ReLU()
            y = m(x) * scaler
            return y

        x = torch.randn([3, 6], device="cuda")
        scaler = 0.23  # 0.23 is unspecialized
        ref = fn(x, scaler)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x, scaler)
        self.assertTrue(same(ref, res))
        self.assertEqual(ref.device, res.device)

    def test_unspec_float_precision(self):
        def fn(image, scale_factor):
            image = torch.nn.functional.interpolate(
                image[None],
                size=None,
                scale_factor=scale_factor,
                mode="bilinear",
                recompute_scale_factor=True,
                align_corners=False,
            )[0]

            return image.shape

        x = torch.rand([3, 427, 640])
        scale_factor = 1.873536229133606
        ref = fn(x, scale_factor)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x, scale_factor)
        self.assertTrue(same(ref, res))

    @unittest.expectedFailure  # fails as long as numpy scalars are 0D arrays
    def test_specializing_numpy_float_in_control_flow(self):
        # np.float64 is unspecialized by default,
        # but it should be specialized when used in control flow.
        def fn(x, y):
            if y > 1.0:
                return x + 1
            else:
                return x - 1

        x = torch.rand(4)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        for t in [np.float16, np.float32, np.float64]:
            y = t(1.23)
            ref = fn(x, y)
            res = opt_fn(x, y)
            self.assertTrue(same(ref, res))

    def test_mark_static_inside(self):
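        # mark_static applied inside the compiled region should make dim 0 static
        # even under dynamic=True; comptime.assert_static below checks that.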
        def fn(x):
            torch._dynamo.mark_static(x, 0)
            comptime.assert_static(x.size(0))
            return x + 1

        opt_fn = torch.compile(fn, dynamic=True, fullgraph=True)
        opt_fn(torch.randn(12, 23))

    def test_shape_graph_break(self):
        from torch._dynamo.comptime import comptime

        def fn(x):
            x_shape = x.size()
            comptime.graph_break()
            return x + torch.randn(x_shape)

        x = torch.randn(20)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_fn(x)

    def test_isinstance_symint(self):
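        # isinstance() on a (possibly symbolic) size should still report int,
        # both for the default input and for one marked dynamic.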
        def fn(x):
            assert isinstance(x.size(0), int)
            return x * 2

        x = torch.randn(20)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_fn(x)
        y = torch.randn(30)
        torch._dynamo.mark_dynamic(y, 0)
        opt_fn(y)

    def test_mark_01_dynamic(self):
        def fn(x):
            return x * 2

        x = torch.randn(1)
        torch._dynamo.mark_dynamic(x, 0)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        # This will fail to compile a generic kernel, but we should not
        # complain about it (mark dynamic will try its best but 0/1
        # specialization is allowed)
        opt_fn(x)

    def test_conv1d_symint_padding(self):
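        # padding derived from symbolic shapes via math.ceil should be accepted
        # by F.conv1d; the second call exercises a different input length.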
        kernel = torch.randn(1, 1, 4)

        def func(x):
            padding = math.ceil((kernel.shape[-1] + x.shape[-1] % 2) / 2) - 1
            out = F.conv1d(x, kernel, padding=padding, stride=2)
            return out

        opt_func = torch.compile(func)

        x = torch.randn(1, 1, 175)
        opt_func(x)  # passes
        x = torch.randn(1, 1, 249)
        opt_func(x)  # crashes

    @torch._dynamo.config.patch("assume_static_by_default", True)
    def test_propagate_dynamic_dim(self):
        x = torch.randn(20)
        torch._dynamo.mark_dynamic(x, 0)

        @torch.compile()
        def fn(x):
            y = x * 2
            comptime.graph_break()
            z = y * 2
            return z

        z = fn(x)
        self.assertEqual(z._dynamo_weak_dynamic_indices, {0})

    def test_rshift_dynamic(self):
        def shift_right(tensor: torch.Tensor) -> torch.Tensor:
            return (tensor >> 2).to(torch.long)

        opt_fn = torch.compile(shift_right, fullgraph=True, dynamic=True)
        sample_input = torch.tensor([4, 4, 16, 32], dtype=torch.uint8)
        opt_fn(sample_input)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_symfloat_to_tensor(self):
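        # Constructing tensors from .item() results (captured as symbolic scalars)
        # in several nesting patterns; compiled results should match eager.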
        def f1(v):
            return torch.tensor([v.item()])

        def f2(v):
            return torch.tensor([[v.item()], [2.0]])

        def f3(v):
            return torch.tensor(v.item())

        def f4(v):
            return torch.tensor((v.item(),))

        optimize = torch.compile(backend="aot_eager", fullgraph=True)

        r = torch.randn(1)

        self.assertEqual(f1(r), optimize(f1)(r))
        self.assertEqual(f2(r), optimize(f2)(r))
        self.assertEqual(f3(r), optimize(f3)(r))
        self.assertEqual(f4(r), optimize(f4)(r))

    @skipIfWindows(
        msg="AssertionError: The values for attribute 'dtype' do not match: torch.int32 != torch.int64."
    )
    def test_to_tensor(self):
        def f1():
            a = np.random.uniform(low=-1, high=1, size=(20, 1))
            return torch.tensor([a, a, a, a], dtype=torch.float64, device="cpu")

        def f2():
            a = torch.tensor([[[123]]])
            return torch.tensor([a, a])

        def f3():
            a = torch.tensor(123)
            return torch.tensor([a, a])

        def f4():
            a = torch.tensor(123)
            b = torch.tensor([[[456]]])
            return torch.tensor([a, b])

        def f5():
            a = np.array([1, 2])
            return torch.tensor([a, a])

        optimize = torch.compile(backend="aot_eager", fullgraph=True)

        self.assertEqual(f1().shape, optimize(f1)().shape)
        self.assertEqual(f2(), optimize(f2)())
        self.assertEqual(f3(), optimize(f3)())
        self.assertEqual(f4(), optimize(f4)())
        self.assertEqual(f5(), optimize(f5)())

    def test_sym_int_conversion(self):
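        # int() of a symbolic boolean (size comparison) used as a multiplier;
        # should compile with fullgraph under the eager backend.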
        def f(x):
            y = x.size(0)
            return x * int(y == 0)

        opt_fn = torch.compile(f, backend="eager", fullgraph=True)
        x = torch.randn(2, 3)
        opt_fn(x)

    def test_sum_dimlist_spec(self):
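        # dim passed as a tuple exercises the dim-list argument spec under
        # dynamic shapes; compiled result should match eager.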
        def fn(inputs, dim):
            return torch.sum(inputs, dim)

        inputs = torch.randn(128, 5, 24, 24)
        dim = (-1, 1, 0, 2)
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(inputs, dim), fn(inputs, dim))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_max(self):
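        # builtin max() over a data-dependent .item() value and a constant;
        # exercised both when the constant wins (1000) and when it loses (2000).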
        def fn(x):
            return torch.ones(max(x.item(), 1024))

        x = torch.tensor([1000])
        y = torch.tensor([2000])
        compl_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), compl_fn(x))
        self.assertEqual(fn(y), compl_fn(y))

    # https://github.com/pytorch/pytorch/issues/104812
    def test_argmin_coerces_symint_to_intlist_spec(self):
        def fn(x, dim):
            # the python arg parser coerces dim into a vector<int>
            return torch.amin(x, dim=dim, keepdim=True)

        x = torch.randn(4, 4, 4)
        dim = 2
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(x, dim), fn(x, dim))

    def test_exponential(self):
        def fn(inputs, op_inputs_dict):
            res = inputs.exponential_(**op_inputs_dict)
            return res

        inputs = torch.randn(2, 3, 4)
        op_inputs_dict = {"lambd": 10, "generator": None}
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(inputs, op_inputs_dict), fn(inputs, op_inputs_dict))

    def test_symbol_guard_limit_before_specialize(self):
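        # Without a limit, the torch._check guards accumulate on one dynamic symbol
        # and a single frame suffices; with symbol_guard_limit_before_specialize=3,
        # each size is expected to specialize and recompile (three frames).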
        cnts = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnts, dynamic=True)
        def fn(x):
            torch._check(x.size(0) != 3)
            torch._check(x.size(0) != 4)
            torch._check(x.size(0) != 5)
            torch._check(x.size(0) != 6)
            return x + 2

        # Control test
        fn(torch.randn(12))
        fn(torch.randn(13))
        fn(torch.randn(14))

        self.assertExpectedInline(cnts.frame_count, """1""")
        cnts.frame_count = 0

        torch._dynamo.reset()

        with torch.fx.experimental._config.patch(
            symbol_guard_limit_before_specialize=3
        ):
            fn(torch.randn(12))
            fn(torch.randn(13))
            fn(torch.randn(14))

            self.assertExpectedInline(cnts.frame_count, """3""")

    def test_defaults(self):
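        # Default int arguments are expected to stay specialized (static) even
        # with dynamic=True; comptime.assert_static below checks that.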
        def g(x, i=8):
            comptime.assert_static(i)
            return x * i

        def fn(x):
            return g(x)

        inputs = torch.randn(2, 3, 4)
        compl_fn = torch.compile(fn, dynamic=True, backend="eager")
        self.assertEqual(compl_fn(inputs), fn(inputs))

    @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=True)
    def test_unspec_float_input(self):
        cnts = torch._dynamo.testing.CompileCounter()

        def f(x, y):
            if y == 5.0:
                return x + 2
            else:
                return x + y

        cf = torch.compile(backend=cnts, fullgraph=True)(f)

        x = torch.randn(3)
        self.assertEqual(f(x, 3.0), cf(x, 3.0))
        self.assertEqual(f(x, 4.0), cf(x, 4.0))
        self.assertExpectedInline(cnts.frame_count, """1""")  # no recompile
        self.assertEqual(f(x, 5.0), cf(x, 5.0))
        self.assertExpectedInline(cnts.frame_count, """2""")  # guard worked
        self.assertEqual(f(x, math.nan), cf(x, math.nan))
        self.assertExpectedInline(cnts.frame_count, """3""")  # nan always recompiles

    @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=True)
    def test_unspec_float_output(self):
        cnts = torch._dynamo.testing.CompileCounter()

        def f(x, y):
            return x + 1, y * 2

        cf = torch.compile(backend=cnts, fullgraph=True)(f)
        x = torch.randn(3)

        self.assertEqual(f(x, 3.0), cf(x, 3.0))
        self.assertEqual(f(x, 4.0), cf(x, 4.0))
        self.assertEqual(f(x, 5.0), cf(x, 5.0))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_data_dependent_evaluate_expr_graph_break(self):
        cnts = torch._dynamo.testing.CompileCounter()

        # To ensure that the continuation frame is compiled, we have to write
        # the test function in this funny way.
        # See https://github.com/pytorch/pytorch/issues/111918
        def test(y):
            if y > 2:
                return True
            else:
                return False

        @torch._dynamo.optimize(cnts)
        def fn(x):
            x = x + 1
            y = x.item()
            if test(y):
                return x * 2
            else:
                return x * 3

        x = torch.tensor([3.0])
        fn(x)

        self.assertExpectedInline(cnts.frame_count, """2""")
        self.assertExpectedInline(cnts.op_count, """4""")

    def test_prune_torch_check(self):
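        # torch._check calls whose results are otherwise unused should be pruned,
        # leaving an empty graph (checked via the logged graph code below).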
        log_stream, ctx = logs_to_string("torch._dynamo.output_graph", "graph_code")

        @torch.compile(fullgraph=True, dynamic=True, backend="eager")
        def f(x, y):
            torch._check(y + 5 == 85)
            torch._check(x.size(0) == 80)

        with ctx():
            f(torch.randn(80, 100), 80)

        out = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip()
        self.assertExpectedInline(
            out,
            """\
def forward(self):
        return ()""",
        )

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_split_aot_autograd(self):
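        # Data-dependent split sizes coming from .tolist() should work through
        # the aot_eager backend (the sizes become data-dependent symbolic ints).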
        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(x, i):
            y, z = i.tolist()
            return torch.split(x, [y, z])

        print(f(torch.randn(10, requires_grad=True), torch.tensor([7, 3])))

    def test_bool_tensor_ctor(self):
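        # torch.tensor() of a python bool derived from a symbolic numel check;
        # sizes 8 and 13 hit the empty and non-empty cases respectively.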
        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, dynamic=True, fullgraph=True)
        def f(x):
            y = torch.empty((x.size(0) // 13) * 13)
            return torch.tensor(y.numel() == 0)

        self.assertTrue(f(torch.empty(8)).item())
        self.assertFalse(f(torch.empty(13)).item())

    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked(self):
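        # With error_on_recompile, a batch dim marked unbacked should serve both
        # the size-3 and the size-1 inputs from a single compilation.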
        class TestModel(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            def forward(self, x: torch.Tensor, val: int) -> torch.Tensor:
                return x * 2

        main_model = TestModel()
        opt_model = torch.compile(main_model, mode="max-autotune", dynamic=True)

        x1 = torch.rand(3, 5, 4, 8)
        x2 = torch.rand(1, 5, 4, 8)

        torch._dynamo.decorators.mark_unbacked(x1, 0)

        o1_ref = main_model(x1, 2)
        o1 = opt_model(x1, 2)
        self.assertEqual(o1_ref, o1)

        o1_2_ref = main_model(x2, 2)
        o1_2 = opt_model(x2, 2)
        self.assertEqual(o1_2_ref, o1_2)

    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked_hint_consistency(self):
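        # A size-1 input with dim 0 marked unbacked: size-oblivious reasoning is
        # expected to take the size != 1 branch despite the runtime hint of 1.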
        from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

        x = torch.randn(1)
        torch._dynamo.decorators.mark_unbacked(x, 0)

        @torch.compile()
        def f(x):
            if guard_size_oblivious(x.size(0) != 1):
                return x + 3
            else:
                return x + 4

        self.assertEqual(f(x), x + 3)

    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked_channels_last(self):
        class TestModel(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            def forward(self, x: torch.Tensor, val: int) -> torch.Tensor:
                return x * 2

        main_model = TestModel()
        opt_model = torch.compile(main_model, mode="max-autotune", dynamic=True)

        x1 = torch.rand(3, 5, 4, 8).to(memory_format=torch.channels_last)
        x2 = torch.rand(1, 5, 4, 8).to(memory_format=torch.channels_last)

        torch._dynamo.decorators.mark_unbacked(x1, 0)

        o1_ref = main_model(x1, 2)
        o1 = opt_model(x1, 2)
        self.assertEqual(o1_ref, o1)

        o1_2_ref = main_model(x2, 2)
        o1_2 = opt_model(x2, 2)
        self.assertEqual(o1_2_ref, o1_2)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()