xref: /aosp_15_r20/external/pytorch/test/test_proxy_tensor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: ProxyTensor"]
2
3from torch.testing._internal.common_utils import TestCase, run_tests
4import torch
5import torch._dynamo
6import unittest
7import warnings
8import operator
9from collections.abc import Iterable
10from torch.nn.utils import stateless
11from torch.testing._internal.common_device_type import instantiate_device_type_tests
12from torch.testing._internal.common_methods_invocations import op_db, skip, xfail, skipOps
13from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException, FakeTensorMode
14from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
15from torch._decomp import decomposition_table
16from torch.fx.experimental.symbolic_shapes import (
17    eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets,
18    guard_int, GuardOnDataDependentSymNode
19)
20from torch.testing._internal.custom_op_db import custom_op_db
21from torch.testing._internal.hop_db import hop_db
22from torch.testing._internal.common_device_type import ops
23import torch.testing._internal.optests as optests
24from torch._C import _disabled_torch_function_impl
25from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
26from torch.utils._pytree import tree_map
27from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
28from torch import nn
29import torch._functorch.config
30import re
31
32import functools
33import itertools
34
35aten = torch.ops.aten
36
37HAS_CUDA = torch.cuda.is_available()
38
39
40def strip_end(s, suffix):
41    if suffix and s.endswith(suffix):
42        return s[:-len(suffix)]
43    else:
44        return s
45
46
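# make_fx names its placeholders like "x_1" (see the expected graphs below), so
# strip the "_1" suffix to report guards in terms of the original input names.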
47def show_guards(gm):
48    names = [strip_end(n, "_1") for n in fx_placeholder_targets(gm)]
49    return "\n".join(
50        gm.shape_env.produce_guards(fx_placeholder_vals(gm), names, _simplified=True, input_contexts=None)
51    )
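

# Illustrative sketch (helper not collected by the test runner): show_guards is
# intended for graphs traced with tracing_mode="symbolic".  Mirroring test_unary
# below, the simplified guard for this trace is expected to be
# "L['x'].size()[0] <= 19".
def _example_show_guards():
    def f(x):
        assert x.shape[0] < 20
        return x.cos()
    gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3, 4))
    return show_guards(gm)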
52
53
54def process_failures():
55    """
56    Takes a file containing failures like
57
58    FAILED test/test_proxy_tensor.py::TestProxyTensorOpInfoCPU::test_make_fx_symbolic_exhaustive___getitem___cpu_float32 - RuntimeError: aten.size.default - couldn't find symbolic meta function/decomposition  # noqa: B950
59
60    and processes them into a list of opinfo xfails
61    """
62    f = open('pytest_failures')
63    failures = f.readlines()
64    failures = [i.strip() for i in failures]
65
66    def process_failure_string(s, matcher):
67        out = re.search(matcher, s)
68        return out.groups()
69
70    SYMBOLIC_TRACE_MATCH = r'exhaustive_(.*)_cpu.*: (.*)'
71    failures = [process_failure_string(s, SYMBOLIC_TRACE_MATCH) for s in failures]
72
73    def create_normalized_name(op):
74        if op.variant_test_name == '':
75            s = op.name
76        else:
77            s = f"{op.name}.{op.variant_test_name}"
78        return s.replace('.', '_')
79
80    remap_opinfo = {create_normalized_name(op): (op.name, op.variant_test_name) for op in op_db}
81
82    print("symbolic_tensor_failures = {")
83    for failure, reason in failures:
84        print(f"    xfail{remap_opinfo[failure]},  # {reason}")
85    print("}")
86
87
88USE_TORCHVISION = False
89try:
90    import torchvision
91    USE_TORCHVISION = True
92except ImportError:
93    warnings.warn("Couldn't import torchvision. Some of our tests use it, try "
94                  "to install it with commands from pytorch.org, post-fixed with "
95                  "`--no-deps` to avoid overwriting the pytorch installation",
96                  UserWarning)
97
98
99def _create_new_input(x):
100    if not isinstance(x, torch.Tensor):
101        return x
102    if x.dtype != torch.float:
103        return x + 1
104    if x.is_leaf:
105        return torch.rand_like(x, requires_grad=x.requires_grad)
106    else:
107        return torch.rand_like(x)
108
109"""
110Delays a cos being executed on the unwraptensor until its used. Simulates a CommTensor used
111"""
112class UnwrapTensor(torch.Tensor):
113    @staticmethod
114    def __new__(cls, tensor: torch.Tensor):
115        r = torch.Tensor._make_wrapper_subclass(
116            cls,
117            tensor.size(),
118            dtype=tensor.dtype,
119            device=tensor.device,
120            layout=tensor.layout,
121            requires_grad=tensor.requires_grad,
122        )
123        r._tensor = tensor
124        return r
125
126    def __repr__(self):
127        # TODO: consider all_gather the local tensors for better debugging
128        return f"UnwrapTensor({self._tensor})"
129
130    __torch_function__ = _disabled_torch_function_impl
131
132    @classmethod
133    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
134        def unwrap(e):
135            ret = e
136            if isinstance(e, UnwrapTensor):
137                ret = e._tensor.cos()
138
139            return ret
140
141        args = tree_map(unwrap, args)
142        kwargs = tree_map(unwrap, kwargs)
143        return func(*args, **kwargs)
144
145class TestGenericProxyTensor(TestCase):
146    # WARNING: if any of your inputs are index tensors, DO NOT use this
147    # function
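    # (Index tensors would be perturbed: _create_new_input adds 1 to every
    # non-float tensor, so regenerated index values could go out of range.)
    # _test traces f under self.tracing_mode, then checks that the traced graph
    # and the original function agree on freshly generated inputs.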
148    def _test(self, f, inps):
149        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
150        new_inps = tree_map(_create_new_input, inps)
151        r1 = fx_f(*new_inps)
152        r2 = f(*new_inps)
153        self.assertEqual(r1, r2)
154
155    def test_pre_dispatch_mode_stack(self):
156        def f(a):
157            b = torch.ones(4, 4)
158            return torch.matmul(a, b)
159        # We expect to see matmul in the trace - it should NOT be decomposed into mm.
160        # Note that torch.ones() still appears in the traced graph below even though
161        # ones() never dispatches to the Autograd dispatch key - it goes directly to
162        # the BackendSelect key, so the pre-dispatch Autograd mode itself never sees it.
163        inp = torch.ones(4, 4)
164        # Test that make_fx(pre_dispatch=True) clears caches properly.
165        from torch._dispatch.python import enable_python_dispatcher
166        with enable_python_dispatcher():
167            out1 = f(inp)
168        fx_g = make_fx(f, pre_dispatch=True)(inp)
169        self.assertExpectedInline(fx_g.code.strip(), """\
170def forward(self, a_1):
171    ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False)
172    matmul = torch.ops.aten.matmul.default(a_1, ones);  a_1 = ones = None
173    return matmul""")
174
175    def test_pre_dispatch_linear(self):
176        def f(a, b, c):
177            return torch.nn.functional.linear(a, b, c)
178        a = torch.ones(4, 4)
179        b = torch.ones(4, 4)
180        c = torch.ones(4)
181        fx_g = make_fx(f, pre_dispatch=True)(a, b, c)
182        out1 = f(a, b, c)
183        out2 = fx_g(a, b, c)
184        self.assertEqual(out1, out2)
185
186    def test_pre_dispatch_no_grad(self):
187        def f(a):
188            b = a.sin()
189            torch.set_grad_enabled(False)
190            c = b.cos()
191            torch.set_grad_enabled(True)
192            return b + c.sin()
193        a1 = torch.randn(4, requires_grad=True)
194        a2 = a1.clone().detach().requires_grad_(True)
195        a_tmp = a1.clone().detach().requires_grad_(True)
196        fx_g = make_fx(f, pre_dispatch=True)(a_tmp)
197        out1 = f(a1)
198        out2 = fx_g(a2)
199        self.assertEqual(out1, out2)
200        out1.sum().backward()
201        out2.sum().backward()
202        self.assertEqual(a1.grad, a2.grad)
203
204    def test_make_fx_simple(self):
205        def f(x):
206            return torch.sin(x)
207        self._test(f, (torch.randn(3),))
208
209    def test_scalar_device(self, device='cpu'):
210        def f(a, b):
211            return a + b
212        self._test(f, [torch.randn(3, device=device), torch.tensor(5)])
213
214    def test_isolated_graphmodule(self):
215        def is_any_sum(gm):
216            return any(node.target == torch.ops.aten.sum.default for node in gm.graph.nodes)
217
218        def is_any_digamma(gm):
219            return any(node.target == torch.ops.aten.digamma.default for node in gm.graph.nodes)
220
221        def is_any_sigmoid(gm):
222            return any(node.target == torch.ops.aten.sigmoid.default for node in gm.graph.nodes)
223
224        def inner(x):
225            return torch.sum(x)
226
227        def f(x):
228            gm = get_isolated_graphmodule(inner, (x,), {})
229            self.assertTrue(is_any_sum(gm))
230            return x + torch.randn(x.shape)
231
232        # get_isolated_graphmodule uses make_fx internally; that inner trace
233        # shouldn't be captured by the outer make_fx call
234        traced = make_fx(f)(torch.randn(3))
235        self.assertFalse(is_any_sum(traced))
236
237        # When factory functions are used, they should not be traced
238        # by the outer make_fx call
239        def inner_with_factory():
240            val = torch.tensor(float(1))
241            val.add_(2)
242            return torch.full((10, 10), val).sum()
243
244        def f1(x):
245            gm = get_isolated_graphmodule(inner_with_factory, (), {})
246            self.assertTrue(is_any_sum(gm))
247            return torch.sigmoid(x)
248
249        def f2(x):
250            gm = get_isolated_graphmodule(f1, (x,), {})
251            self.assertFalse(is_any_sum(gm))
252            self.assertTrue(is_any_sigmoid(gm))
253            return torch.digamma(x)
254
255        traced = make_fx(f2)(torch.randn(3))
256        self.assertFalse(is_any_sum(traced))
257        self.assertFalse(is_any_sigmoid(traced))
258        self.assertTrue(is_any_digamma(traced))
259
260        # Verify that nested make_fx calls don't cause factory functions to be leaked
261        # into the outer graph, and that `make_fx` itself does not leak its execution.
262        def f2(x):
263            gm = make_fx(f1)(x)
264            self.assertFalse(is_any_sum(gm))
265            self.assertTrue(is_any_sigmoid(gm))
266            return torch.digamma(x)
267
268        traced = make_fx(f2)(torch.randn(3))
269        self.assertFalse(is_any_sum(traced))
270        self.assertFalse(is_any_sigmoid(traced))
271        self.assertTrue(is_any_digamma(traced))
272
273        # Verify that the `forward` function of a graph module produced as a
274        # side effect of an interior `make_fx` is still traced
275        def f3(x):
276            gm = make_fx(f1)(x)
277            self.assertFalse(is_any_sum(gm))
278            self.assertTrue(is_any_sigmoid(gm))
279            # `gm.forward` is still traced
280            return torch.digamma(gm(x))
281
282        traced = make_fx(f3)(torch.randn(3))
283        self.assertFalse(is_any_sum(traced))
284        self.assertTrue(is_any_sigmoid(traced))
285        self.assertTrue(is_any_digamma(traced))
286
287        # Verify interaction with non-ProxyTensor modes
288        from torch.testing._internal.logging_tensor import LoggingTensorMode
289
290        def f1_logging(x):
291            with LoggingTensorMode():
292                gm = get_isolated_graphmodule(inner_with_factory, (), {})
293            self.assertTrue(is_any_sum(gm))
294            return torch.sigmoid(x)
295
296        def f2_logging(x):
297            with LoggingTensorMode(), LoggingTensorMode():
298                gm = get_isolated_graphmodule(f1_logging, (x,), {})
299            self.assertFalse(is_any_sum(gm))
300            self.assertTrue(is_any_sigmoid(gm))
301            return torch.digamma(x)
302
303        traced = make_fx(f2_logging)(torch.randn(3))
304        self.assertFalse(is_any_sum(traced))
305        self.assertFalse(is_any_sigmoid(traced))
306        self.assertTrue(is_any_digamma(traced))
307
308        # Verify interaction with another tensor subclass
309        # This case currently doesn't work and should raise an error
310        # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068
311        from torch.testing._internal.logging_tensor import LoggingTensor
312
313        def f1_logging_tensor(x):
314            gm = get_isolated_graphmodule(inner_with_factory, (), {})
315            self.assertTrue(is_any_sum(gm))
316            return torch.sigmoid(x)
317
318        def f2_logging_tensor(x):
319            x = LoggingTensor(x)
320            gm = get_isolated_graphmodule(f1_logging_tensor, (x,), {})
321            self.assertFalse(is_any_sum(gm))
322            self.assertTrue(is_any_sigmoid(gm))
323            return torch.digamma(x)
324
325        traced = make_fx(f2_logging_tensor)(torch.randn(3))
326        self.assertFalse(is_any_sum(traced))
327        self.assertFalse(is_any_sigmoid(traced))  # this fails, sigmoid is traced with LoggingTensor
328        self.assertTrue(is_any_digamma(traced))
329
330    # See https://github.com/pytorch/pytorch/issues/97541
331    def test_empty_like_doesnt_burn_in_defaults(self):
332        def f(x):
333            return torch.empty_like(x)
334        out = make_fx(f)(torch.randn(3))
335        self.assertExpectedInline(out.code.strip(), """\
336def forward(self, x_1):
337    empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False);  x_1 = None
338    return empty_like""")
339
340    def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self):
341        def f(x):
342            y = x.new_zeros(x.size())
343            y.copy_(x)
344            return y
345
346        def _new_zeros_decomp(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
347            return torch.zeros(size, dtype=inp.dtype, device=inp.device)
348
349        factory_func_decomp = {torch.ops.aten.new_zeros.default: _new_zeros_decomp}
350
351        # When new_zeros() decomposes into torch.zeros(), we expect ProxyTensorMode
352        # to still be (re-entrantly) enabled, so that the `torch.zeros()` call
353        # returns a ProxyTensor.
354        out = make_fx(f, decomposition_table=factory_func_decomp)(torch.ones(2))
355        self.assertExpectedInline(out.code, """\
356
357
358
359def forward(self, x_1):
360    zeros = torch.ops.aten.zeros.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
361    copy_ = torch.ops.aten.copy_.default(zeros, x_1);  zeros = x_1 = None
362    return copy_
363    """)
364
365    def test_make_fx_reentrant_dispatch(self):
366        def f(x):
367            return torch.ops.aten.norm.Scalar(x, 2.0)
368
369        def norm_decomp(x, p=2.0):
370            if p != 2.0:
371                raise RuntimeError("can't handle with p != 2")
372            return torch.sqrt(torch.sum(torch.square(x)))
373
374        decomp = {torch.ops.aten.norm.Scalar: norm_decomp}
375
376        traced = make_fx(f, decomposition_table=decomp, tracing_mode=self.tracing_mode)(torch.rand(3))
377
378        for n in traced.graph.nodes:
379            self.assertTrue("square" not in str(n.target))
380            self.assertTrue("norm" not in str(n.target))
381
382    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
383    def test_resnet18_backward_trace(self):
384        mod = torchvision.models.resnet18()
385
386        # An old version of this test called the module directly.  This works
387        # for tracing_mode == "real", but for fake tensors, we also have to
388        # ensure that the parameters and buffers get wrapped in fake tensors
389        # because free fake tensors are not supported.  Fortunately functional_call
390        # does precisely this for us.
391        def f(x, params, buffers):
392            for p in params.values():
393                p.grad = None
394            loss = torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()
395            # I could have done this with the functional API, but there is
396            # plenty of coverage exercising that; I want to show that the
397            # mutating API still works
398            loss.backward()
399            return [p.grad for p in params.values()]
400
401        inp = torch.randn(3, 3, 250, 250)
402        self._test(f, [inp, dict(mod.named_parameters()), dict(mod.named_buffers())])
403
404    def test_varargs(self):
405        def f(*args):
406            return sum(args)
407
408        self._test(f, [torch.randn(2), torch.randn(2)])
409
410    def test_proxy_tensor(self):
411        def f_grad(x):
412            val = x.cos().cos().sum()
413            return torch.autograd.grad(val, x)
414
415        def f_backward(x):
416            val = x.cos().cos().sum()
417            val.backward()
418            return x.grad
419
420        for f in [f_grad, f_backward]:
421            self._test(f, [torch.randn(3, requires_grad=True)])
422
423    def test_pickle_issue89626(self):
424        import pickle
425        x = torch.randn(2)
426        make_fx(lambda x: x * 2, tracing_mode=self.tracing_mode)(x)
427        pickle.dumps(x)
428
429    def test_inplace_metadata(self):
430        def f(x):
431            x = x.clone()
432            x.unsqueeze_(-1)
433            assert x.shape[-1] == 1
434            return x
435
436        self._test(f, [torch.randn(5)])
437
438    def test_mode_tracing_factory_function(self):
439        def f(x):
440            return x + torch.randn(x.shape)
441
442        # default behavior should trace factory functions
443        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
444        self.assertTrue(
445            any(
446                node.target == aten.randn.default
447                for node in traced.graph.nodes
448            )
449        )
450
451    def test_pre_dispatch_functionalization(self):
452        def f(x):
453            a = FunctionalTensorMode(pre_dispatch=True)
454            with a:
455                x_unwrapped = FunctionalTensor.to_functional(x)
456                y = torch.matmul(x_unwrapped, x_unwrapped)
457                y = y + x_unwrapped
458                y.mul_(5)
459                y_unwrapped = torch._from_functional_tensor(y.elem)
460                return y_unwrapped
461
462        from torch._dispatch.python import enable_python_dispatcher
463
464        with enable_python_dispatcher():
465            inp = torch.randn(4, 4)
466            gm = make_fx(f, pre_dispatch=True)(inp)
467
468        # TODO actually not decompose
469        self.assertExpectedInline(gm.code.strip(), """\
470def forward(self, x_1):
471    matmul = torch.ops.aten.matmul.default(x_1, x_1)
472    add = torch.ops.aten.add.Tensor(matmul, x_1);  matmul = x_1 = None
473    mul = torch.ops.aten.mul.Tensor(add, 5);  add = None
474    return mul""")
475
476    def test_pre_dispatch_functionalization_view_op(self):
477        def f(x):
478            a = FunctionalTensorMode(pre_dispatch=True)
479            with a:
480                x_unwrapped = FunctionalTensor.to_functional(x)
481                y = torch.matmul(x_unwrapped, x_unwrapped)
482                x_unwrapped = x_unwrapped.transpose(1, 0)
483                y = y + x_unwrapped
484                y = y.view(2, 8)
485                y_unwrapped = torch._from_functional_tensor(y.elem)
486                return y_unwrapped
487
488        from torch._dispatch.python import enable_python_dispatcher
489
490        with enable_python_dispatcher():
491            inp = torch.randn(4, 4)
492            gm = make_fx(f, pre_dispatch=True)(inp)
493
494        # TODO actually not decompose
495        self.assertExpectedInline(gm.code.strip(), """\
496def forward(self, x_1):
497    matmul = torch.ops.aten.matmul.default(x_1, x_1)
498    transpose = torch.ops.aten.transpose.int(x_1, 1, 0);  x_1 = None
499    add = torch.ops.aten.add.Tensor(matmul, transpose);  matmul = transpose = None
500    view = torch.ops.aten.view.default(add, [2, 8]);  add = None
501    return view""")
502
503    def test_val_metadata_mutation(self):
504        def f(x):
505            y = x.clone()
506            y.unsqueeze_(0)
507            return y
508
509        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3, requires_grad=True))
510        self.assertEqual([
511            tuple(node.meta['val'].shape)
512            for node in traced.graph.nodes
513            if 'val' in node.meta
514        ], [(3,), (3,), (1, 3)])
515
516    def test_make_fx_overloads(self):
517        def f(x):
518            return x.cos() + torch.randn(x.shape)
519
520        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
521
522        self.assertTrue(all(isinstance(node.target, torch._ops.OpOverload)
523                            for node in traced.graph.nodes if node.op == 'call_function'))
524
525    def test_tensor_constants(self):
526        def f():
527            val = torch.tensor(float('inf'))
528            return torch.full((100, 100), val)
529
530        self._test(f, [])
531
532    def test_allclose(self):
533        def f(a, b):
534            return torch.allclose(a, b)
535
536        def test_f():
537            make_fx(f, tracing_mode=self.tracing_mode)(
538                torch.zeros(3), torch.zeros(3)
539            )
540
541        if self.tracing_mode != "real":
542            self.assertRaises(DataDependentOutputException, test_f)
543        else:
544            self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)
545
546    def test_constant_proxy_tensor_mut(self):
547        def f():
548            val = torch.tensor(float(1))
549            val.add_(2)
550            return torch.full((100, 100), val)
551
552        g = make_fx(f, tracing_mode=self.tracing_mode)()
553        self.assertEqual(g(), f())
554        # In case we mutated shared state in the g graph!
555        self.assertEqual(g(), f())
556
557    def test_constant_unbind(self):
558        def f():
559            val = torch.tensor([2])
560            r, = torch.unbind(val, 0)
561            return r.item()
562
563        g = make_fx(f, tracing_mode=self.tracing_mode)()
564        self.assertEqual(g(), f())
565
566    def test_constant_blowup(self):
567        def f():
568            val = torch.tensor([2])
569            blowup = val.repeat(1000)
570            return bool(blowup.sum().item() == 2)
571
572        def test_f():
573            make_fx(f, tracing_mode=self.tracing_mode)()
574
575        self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)
576
577    def test_constant_random(self):
578        def f():
579            val = torch.tensor([2.0])
580            val.normal_()
581            return bool(val.item() == 2.1)
582
583        def test_f():
584            make_fx(f, tracing_mode=self.tracing_mode)()
585
586        self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)
587
588    def test_decomposition_interpreter(self):
589        def fn(x):
590            return torch.nn.functional.silu(x)
591
592        x = torch.rand((4, 4))
593        fx_module = make_fx(fn, tracing_mode=self.tracing_mode, decomposition_table=None)(x)
594
595        found_silu = False
596        for n in fx_module.graph.nodes:
597            if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default:
598                found_silu = True
599
600        self.assertTrue(found_silu)
601
602        new_graph = torch.fx.Graph()
603        silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
604        DecompositionInterpreter(
605            fx_module,
606            new_graph=new_graph,
607            decomposition_table=silu_decomp_table,
608        ).run(x)
609
610        decomposed_module = torch.fx.GraphModule(fx_module, new_graph)
611
612        for n in decomposed_module.graph.nodes:
613            self.assertTrue(n.target != torch.ops.aten.silu)
614            self.assertTrue(n.target != torch.ops.aten.silu.default)
615
616        self.assertEqual(fx_module(x), decomposed_module(x))
617
618    def test_make_fx_model_fwd_bwd(self):
619        class Foo(torch.nn.Module):
620            def __init__(self) -> None:
621                super().__init__()
622                self.linear = torch.nn.Linear(5, 5)
623
624            def forward(self, x):
625                return self.linear(x).relu()
626
627        model = Foo()
628
629        def f(x, params):
630            out = torch.func.functional_call(model, params, x).sum()
631            out.backward()
632            return list(params.values())
633        input = torch.randn(3, 5, requires_grad=True)
634        params = dict(model.named_parameters())
635        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params)
636        # fx may reorder the returned parameters, so compare each output against both reference outputs
637        self.assertTrue(
638            torch.allclose(fx_f(input, params)[0], f(input, params)[0])
639            or
640            torch.allclose(fx_f(input, params)[0], f(input, params)[1])
641        )
642        self.assertTrue(
643            torch.allclose(fx_f(input, params)[1], f(input, params)[0])
644            or
645            torch.allclose(fx_f(input, params)[1], f(input, params)[1])
646        )
647
648    def test_make_fx_model_double_param(self):
649        class Emformer(torch.nn.Module):
650            def __init__(
651                self,
652                input_dim: int = 256,
653            ) -> None:
654                super().__init__()
655
656                self.layer_norm = torch.nn.LayerNorm(input_dim)
657
658            def forward(mod_self, x):  # noqa: B902
659                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
660                y = mod_self.layer_norm(x)
661                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
662                z = mod_self.layer_norm(y)
663                return z
664
665
666        gm = make_fx(Emformer())(torch.randn(16, 1, 256))
667        ops = {n.target for n in gm.graph.nodes if n.op == 'call_function'}
668        self.assertEqual(len(ops), 2)
669
670
671    def test_make_fx_model_fwd_bwd_wgtupdate(self):
672        class Foo(torch.nn.Module):
673            def __init__(self) -> None:
674                super().__init__()
675                self.linear = torch.nn.Linear(5, 5)
676
677            def forward(self, x):
678                return self.linear(x).relu()
679
680        model = Foo()
681
682        def f(args, params, buffers):
683            for p in params.values():
684                p.grad = None
685            if not isinstance(args, Iterable):
686                args = [args]
687            params_and_buffers = {**params, **buffers}
688            out = torch.func.functional_call(model, params_and_buffers, args)
689            out.sum().backward()
690            return [p - 1e-4 * p.grad for p in params.values()]
691
692        input = torch.randn(3, 5, requires_grad=True)
693        params = dict(model.named_parameters())
694        buffers = dict(model.named_buffers())
695        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params, buffers)
696        # fx may reorder the returned parameters, so compare each output against both reference outputs;
697        # there is also a small numerical difference in results, so atol is relaxed from 1e-08 to 1e-03
698        self.assertTrue(
699            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03)
700            or
701            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03)
702        )
703        self.assertTrue(
704            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03)
705            or
706            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
707        )
708
709    def test_trace_subclasses(self):
710        def f1(x):
711            x = UnwrapTensor(x)
712            y = x * 2
713            return y
714
715        def f2(x):
716            wrapped = UnwrapTensor(x)
717            y = x * wrapped
718            return y
719
720        inp = [torch.randn(5)]
721        self._test(f1, inp)
722        self._test(f2, inp)
723
724    def test_partial_decomp(self):
725        def f(a, b, c):
726            x = torch.addmm(a, b, c)
727            y = torch.addmm(a, b, c, beta=2, alpha=1)
728            return x + y
729        inps = [torch.randn(5, 5), torch.randn(5, 5), torch.randn(5, 5)]
730        fx_g = make_fx(f)(*inps)
731
732        def addmm(a, b, c, beta=1, alpha=1):
733            if beta == 1 and alpha == 1:
734                return NotImplemented
735            return beta * a + alpha * (b @ c)
736
737        decomposed_fx = make_fx(f, decomposition_table={aten.addmm.default: addmm})(*inps)
738
739        self.assertEqual(fx_g(*inps), decomposed_fx(*inps))
740        self.assertEqual(len([n for n in fx_g.graph.nodes if n.target == aten.addmm.default]), 2)
741        self.assertEqual(len([n for n in decomposed_fx.graph.nodes if n.target == aten.addmm.default]), 1)
742
743    def test_decomp_of_capture(self):
744        val = torch.randn(5)
745
746        def f(x):
747            return x.t() + val.t()
748
749        def nop(x):
750            return x.cos()
751
752        traced = make_fx(f, decomposition_table={torch.ops.aten.t.default: nop})(torch.randn(5))
753        self.assertEqual(len([n for n in traced.graph.nodes if n.target == torch.ops.aten.t.default]), 0)
754
755
756    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
757    def test_amp_cache(self):
758        layer = torch.nn.Conv2d(3, 3, 3).cuda()
759
760        def f(x, w):
761            return torch.nn.functional.conv2d(x, w, stride=layer.stride)
762
763        inp = torch.randn(4, 3, 10, 10, device='cuda')
764        with torch.autocast('cuda'):
765            out_graph = make_fx(f)(inp, layer.weight).graph
766            out_graph2 = make_fx(f)(inp, layer.weight).graph
767
768        self.assertEqual(len(out_graph.nodes), len(out_graph2.nodes))
769        for a, b in zip(out_graph.nodes, out_graph2.nodes):
770            self.assertEqual(a.op, b.op)
771
772    def test_strides(self):
773        def f(x):
774            self.assertTrue(x.is_contiguous())
775            self.assertFalse(x.is_contiguous(memory_format=torch.channels_last))
776            x = x.permute(0, 3, 1, 2)
777            self.assertFalse(x.is_contiguous())
778            self.assertTrue(x.is_contiguous(memory_format=torch.channels_last))
779            return x
780        make_fx(f)(torch.randn(2, 3, 4, 5))
781
782        def f(x):
783            self.assertTrue(x.is_contiguous())
784            y = x[:, 1]
785            self.assertFalse(y.is_contiguous())
786            y = x[:, ::2]
787            self.assertFalse(y.is_contiguous())
788            return x.cos()
789
790        make_fx(f)(torch.randn(2, 3, 4, 5))
791
792    def test_pr_86917(self):
793        # Tests the issue brought up here https://github.com/pytorch/pytorch/pull/86917#issuecomment-1283155344
794        def f(a, b):
795            return torch.ops.aten.nll_loss_forward(a, b, None, 1, 10)
796
797        self._test(f, [torch.randn(1, 10), torch.zeros(1, dtype=torch.long)])
798
799class TestGenericProxyTensorReal(TestGenericProxyTensor):
800    tracing_mode = "real"
801
802
803class TestGenericProxyTensorFake(TestGenericProxyTensor):
804    tracing_mode = "fake"
805
806
807class TestGenericProxyTensorSymbolic(TestGenericProxyTensor):
808    tracing_mode = "symbolic"
809
810
811del TestGenericProxyTensor
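# The parametrized base class is deleted so the test runner only collects the
# Real/Fake/Symbolic subclasses above, each of which defines tracing_mode.


# Illustrative sketch (hypothetical helper, not collected as a test): the same
# function traced under the three tracing modes exercised by those subclasses.
def _example_tracing_modes():
    def f(x):
        return x.cos() + 1
    inp = torch.randn(3)
    return {
        mode: make_fx(f, tracing_mode=mode)(inp)
        for mode in ("real", "fake", "symbolic")
    }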
812
813
814class TestRealProxyTensor(TestCase):
815    def test_error_on_data_dependent_ops(self):
816        def f():
817            x = torch.randn([])
818            y = torch.randn([])
819            assert torch.allclose(x * y, y * x)
820            z = float(x)
821            z2 = float(y)
822
823        # Smoke tests: _error_on_data_dependent_ops=False lets the data-dependent calls above trace without raising
824        make_fx(f, _error_on_data_dependent_ops=False)()
825        make_fx(f, pre_dispatch=True, _error_on_data_dependent_ops=False)()
826
827class TestFakeProxyTensor(TestCase):
828    def test_issue82547(self):
829        x = nn.Parameter(torch.randn(3, 3))
830
831        def f():
832            return torch.ops.aten.t.default(x)
833        self.assertRaisesRegex(Exception, "Please convert all Tensors", lambda: make_fx(f, tracing_mode="fake")())
834
835        class A(torch.Tensor):
836            pass
837
838        x = A(torch.randn(3, 3))
839        self.assertRaisesRegex(TypeError, "Multiple dispatch failed", lambda: make_fx(f, tracing_mode="fake")())
840
841    def test_use_fake_and_tensor(self):
842        def f(x, y):
843            z = torch.tensor([2.0, 3.0])
844            return x + y + z
845
846        g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2))
847        x, y = torch.randn(2), torch.randn(2)
848        self.assertEqual(g(x, y), f(x, y))
849
850    def test_free_fake(self):
851        def f(x):
852            return torch.add(x, y)
853
854        with FakeTensorMode() as fake_mode:
855            y = torch.randn(2)
856            make_fx(f, tracing_mode="real")(torch.randn(2))
857
858    def test_fused_adam(self):
859        # See https://github.com/pytorch/pytorch/issues/99356
860        params = [torch.randn(10, 10) for _ in range(10)]
861        grads = [torch.randn(10, 10) for _ in range(10)]
862        exp_avgs = [torch.randn(10, 10) for _ in range(10)]
863        exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
864        max_exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
865        state_steps = [torch.tensor(0) for _ in range(10)]
866
867        def fused_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps):
868            (new_params, _, _, _, _) = aten._fused_adam.default(
869                params,
870                grads,
871                exp_avgs,
872                exp_avg_sqs,
873                max_exp_avg_sqs,
874                state_steps,
875                lr=0.1,
876                beta1=0.9,
877                beta2=0.999,
878                weight_decay=0.01,
879                eps=1e-8,
880                amsgrad=False,
881                maximize=False,
882            )
883
884            for p, new_p in zip(params, new_params):
885                p.copy_(new_p)
886
887            return params
888
889        gm = make_fx(fused_adam, tracing_mode='fake')(
890            params,
891            grads,
892            exp_avgs,
893            exp_avg_sqs,
894            max_exp_avg_sqs,
895            state_steps,
896        )
897        ensure_ops_have_val = [aten._fused_adam.default, operator.getitem]
898        for n in gm.graph.nodes:
899            if n.op == "call_function" and n.target in ensure_ops_have_val:
900                self.assertIn('val', n.meta)
901
902    def test_alias(self):
903        def f(x):
904            return torch.ops.aten.alias(x)
905
906        r = str(make_fx(f, tracing_mode="fake")(torch.randn(2)).code).strip()
907        # NB: this should not have a detach call
908        self.assertExpectedInline(r, """\
909def forward(self, x_1):
910    alias = torch.ops.aten.alias.default(x_1);  x_1 = None
911    return alias""")
912
913    def test_meta(self):
914        def f(x):
915            a = x.cos()
916            b = torch.var_mean(a, dim=0)
917            c = b * 2
918            return c
919
920        out = make_fx(f, tracing_mode="fake")(torch.randn(5, 5))
921        for n in out.graph.nodes:
922            if n.op == 'output':
923                continue
924            self.assertTrue('val' in n.meta)
925
926def _get_node(fx_g, cond):
927    for n in fx_g.graph.nodes:
928        if cond(n):
929            return n
930    raise AssertionError
931
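# Count the symbols the ShapeEnv knows a value for that have not been eliminated
# via replacements, i.e. the symbols that are still free.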
932def _get_free_symbols(shape_env):
933    vars = tuple(shape_env.var_to_val.keys())
934    return len([var for var in vars if var not in shape_env.replacements])
935
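# Trace f symbolically, materializing a randn tensor for each shape passed in args.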
936def _trace(f, *args):
937    inps = [torch.randn(arg) for arg in args]
938    return make_fx(f, tracing_mode="symbolic")(*inps)
939
940# TODO: Need to test the guards themselves specifically as well
941class TestSymbolicTracing(TestCase):
942    def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True):
943        """
944        Tests fn traced with trace_inputs against test_inputs
945        Returns the traced GraphModule (callers read its shape_env)
946        """
947        trace_inputs = [torch.randn(shape) for shape in trace_inputs]
948        traced_f = make_fx(fn, tracing_mode="symbolic")(*trace_inputs)
949        for input in test_inputs:
950            input = [torch.randn(shape) for shape in input]
951            rx, ry = traced_f(*input), fn(*input)
952            if assert_eq:
953                self.assertEqual(rx, ry)
954        return traced_f
955
956
957    def test_debug_interpreter(self):
958        import torch.library
959        from torch.library import Library
960
961        foo = Library("foo", "DEF")  # noqa: TOR901
962        foo.define("foo(Tensor self) -> Tensor")
963
964        # Operator where meta and cpu disagree on strides
965        @torch.library.impl(foo, "foo", "CPU")
966        def foo_cpu(x):
967            return x.clone().T
968
969        @torch.library.impl(foo, "foo", "Meta")
970        def foo_meta(x):
971            return x.clone()
972
973        def f(x):
974            return torch.ops.foo.foo.default(x)
975
976        gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 2))
977        from torch._functorch.compilers import DebugInterpreter
978
979        interp = DebugInterpreter(gm)
980
981        # input mismatch is caught (indicates guard problem)
982        self.assertRaisesRegex(
983            AssertionError, r"3 != 1",
984            lambda: interp.run(torch.randn(3, 3).T),
985        )
986
987        # Catch the incorrect meta
988        self.assertRaisesRegex(
989            AssertionError, r"\(3, 1\) != \(1, 3\)",
990            lambda: interp.run(torch.randn(3, 3))
991        )
992
993    def test_int_input(self):
994        def f(x, y):
995            return x.view(y)
996
997        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 4), 12).code).strip()
998        self.assertExpectedInline(r, """\
999def forward(self, x_1, y_1):
1000    view = torch.ops.aten.view.default(x_1, [y_1]);  x_1 = y_1 = None
1001    return view""")
1002
1003    def test_resize_from_zero(self):
1004        def f(x, y):
1005            x.resize_(y.size(0))
1006
1007        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip()
1008        self.assertExpectedInline(r, """\
1009def forward(self, x_1, y_1):
1010    sym_size_int = torch.ops.aten.sym_size.int(y_1, 0);  y_1 = None
1011    resize_ = torch.ops.aten.resize_.default(x_1, [sym_size_int]);  x_1 = sym_size_int = resize_ = None
1012    return None""")
1013
1014    def test_broadcast_shapes(self):
1015        def f(x, y):
1016            return torch.functional.broadcast_shapes(x.size(), y.size()[0])
1017
1018        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(3, 1), torch.empty(5)).code).strip()
1019        self.assertExpectedInline(r, """\
1020def forward(self, x_1, y_1):
1021    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0);  x_1 = None
1022    sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0);  y_1 = None
1023    return (sym_size_int, sym_size_int_1)""")
1024
1025    def test_deduped_shape(self):
1026        def f(s0, s1, x, y):
1027            return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0])
1028
1029        x = torch.empty(3, 1)
1030        y = torch.empty(5)
1031        from torch.fx.experimental.symbolic_shapes import ShapeEnv
1032        shape_env = ShapeEnv()
1033
1034        with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode:
1035            x = fake_mode.from_tensor(x)
1036            y = fake_mode.from_tensor(y)
1037            r = str(make_fx(f, tracing_mode="real")(x.shape[0], y.shape[0], x, y).code).strip()
1038            self.assertExpectedInline(r, """\
1039def forward(self, s0_1, s1_1, x_1, y_1):
1040    empty = torch.ops.aten.empty.memory_format([s0_1], device = device(type='cpu'), pin_memory = False)
1041    return ((s0_1, s1_1), empty)""")
1042
1043    def test_non_deduped_shape(self):
1044        def f(x, y):
1045            return torch.functional.broadcast_shapes(x.size(), y.size()[0]), torch.empty(x.shape[0])
1046
1047        x = torch.empty(3, 1)
1048        y = torch.empty(5)
1049        from torch.fx.experimental.symbolic_shapes import ShapeEnv
1050        shape_env = ShapeEnv()
1051
1052        with FakeTensorMode(shape_env=shape_env, static_shapes=False) as fake_mode:
1053            x = fake_mode.from_tensor(x)
1054            y = fake_mode.from_tensor(y)
1055            r = str(make_fx(f, tracing_mode="real")(x, y).code).strip()
1056            self.assertExpectedInline(r, """\
1057def forward(self, x_1, y_1):
1058    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0);  x_1 = None
1059    sym_size_int_1 = torch.ops.aten.sym_size.int(y_1, 0);  y_1 = None
1060    empty = torch.ops.aten.empty.memory_format([sym_size_int], device = device(type='cpu'), pin_memory = False)
1061    return ((sym_size_int, sym_size_int_1), empty)""")
1062
1063    def test_unary(self):
1064        def f(x):
1065            assert x.shape[0] < 20
1066            return x.cos()
1067        test_inputs = []
1068        test_inputs.append([(2, 5)])
1069        test_inputs.append([(6, 8)])
1070        gm = self._test_dynamic(f, [(3, 4)], test_inputs)
1071        self.assertTrue(eval_guards(gm, torch.randn(4, 5)))
1072        self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s0: 4, s1: 5}")
1073        self.assertFalse(eval_guards(gm, torch.randn(25, 5)))
1074        self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] <= 19""")
1075
1076    def test_repeat_interleave(self):
1077        def f(src_tokens, beam_size_src):
1078            return src_tokens.repeat_interleave(beam_size_src.size(0), 0)
1079
1080        prompt_size = 64
1081        vocab_size = 64
1082        batch_size = 4
1083        src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
1084        gm = make_fx(f, tracing_mode="symbolic")(src_tokens, torch.randn(5))
1085        self.assertEqual(len(gm.shape_env.guards), 0)
1086
1087    def test_non_symint_size_spec(self):
1088        # this isn't really a proxy tensor test, but it's the most convenient
1089        # way to get a fake tensor with symbolic sizes
1090        def f(x):
1091            torch._C._non_sym_sizes(x)
1092            return x + 1
1093
1094        x = torch.randn(2, 3)
1095        make_fx(f, tracing_mode="symbolic")(x)
1096
1097    # https://github.com/pytorch/pytorch/issues/108195
1098    def test_symbolic_repeat_interleave(self):
1099        def f(y, x):
1100            return y.repeat_interleave(x, dim=1)
1101
1102        y = torch.tensor([[1, 2], [3, 4]])
1103        x = torch.tensor([2, 3])
1104        r = str(make_fx(f, tracing_mode="symbolic")(y, x).code).strip()
1105        self.assertExpectedInline(r, """\
1106def forward(self, y_1, x_1):
1107    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1);  x_1 = None
1108    index_select = torch.ops.aten.index_select.default(y_1, 1, repeat_interleave);  y_1 = repeat_interleave = None
1109    return index_select""")
1110
1111    def test_mod_gcd_unbacked(self):
1112        def f(_a, _b, _stride):
1113            a = _a.item()
1114            b = _b.item()
1115            stride = _stride.item()
1116            torch._check_is_size(a)
1117            torch._check_is_size(b)
1118            torch._check_is_size(stride)
1119            ta = torch.randn(a * stride)
1120            tb = torch.randn(b * stride)
1121            r = torch.cat([ta, tb])
1122            return r.view(a + b, stride)
1123
1124        _a = torch.tensor(30)
1125        _b = torch.tensor(20)
1126        _stride = torch.tensor(10)
1127        r = str(make_fx(f, tracing_mode="symbolic")(_a, _b, _stride).code).strip()
1128        self.assertExpectedInline(r, """\
1129def forward(self, _a_1, _b_1, _stride_1):
1130    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(_a_1);  _a_1 = None
1131    _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(_b_1);  _b_1 = None
1132    _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(_stride_1);  _stride_1 = None
1133    mul = _local_scalar_dense * _local_scalar_dense_2
1134    randn = torch.ops.aten.randn.default([mul], device = device(type='cpu'), pin_memory = False);  mul = None
1135    mul_1 = _local_scalar_dense_1 * _local_scalar_dense_2
1136    randn_1 = torch.ops.aten.randn.default([mul_1], device = device(type='cpu'), pin_memory = False);  mul_1 = None
1137    cat = torch.ops.aten.cat.default([randn, randn_1]);  randn = randn_1 = None
1138    add = _local_scalar_dense + _local_scalar_dense_1;  _local_scalar_dense = _local_scalar_dense_1 = None
1139    view = torch.ops.aten.view.default(cat, [add, _local_scalar_dense_2]);  cat = add = _local_scalar_dense_2 = None
1140    return view""")
1141
1142    def test_cumsum_unbacked(self):
1143        def f(x):
1144            y = x.item()
1145            z = torch.randn((3, y, 3))
1146            return z.cumsum(0)
1147
1148        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor([5])).code).strip()
1149        self.assertExpectedInline(
1150            r, """\
1151def forward(self, x_1):
1152    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
1153    randn = torch.ops.aten.randn.default([3, _local_scalar_dense, 3], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
1154    cumsum = torch.ops.aten.cumsum.default(randn, 0);  randn = None
1155    return cumsum"""  # noqa: B950
1156        )
1157
1158
1159    def test_repeat_interleave_unbacked_output_size(self):
1160        def f(x, y):
1161            s = x.sum().item()
1162            return y.repeat_interleave(x, dim=0, output_size=s)
1163
1164        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor([2, 3]), torch.randn(2)).code).strip()
1165        self.assertExpectedInline(
1166            r, """\
1167def forward(self, x_1, y_1):
1168    sum_1 = torch.ops.aten.sum.default(x_1)
1169    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(sum_1);  sum_1 = None
1170    repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1, output_size = _local_scalar_dense);  x_1 = _local_scalar_dense = None
1171    index_select = torch.ops.aten.index_select.default(y_1, 0, repeat_interleave);  y_1 = repeat_interleave = None
1172    return index_select"""  # noqa: B950
1173        )
1174
1175    def test_arange_unbacked_output_size(self):
1176        def f(x):
1177            return torch.arange(0, x)
1178
1179        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10)).code).strip()
1180        self.assertExpectedInline(
1181            r, """\
1182def forward(self, x_1):
1183    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
1184    arange = torch.ops.aten.arange.start(0, _local_scalar_dense, device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
1185    return arange"""  # noqa: B950
1186        )
1187
1188    def test_adv_index_batch(self):
1189        def f(src_tokens):
1190            bsz, src_len = src_tokens.size()[:2]
1191            start_step = src_tokens.shape[1]
1192            beam_size = 1
1193            generate_size = 64
1194            max_len = src_len + generate_size
1195            tokens = torch.zeros(bsz * beam_size, max_len).to(src_tokens).long().fill_(0)
1196            tokens[:, :start_step] = src_tokens.repeat_interleave(beam_size, 0)
1197            return tokens
1198
1199        prompt_size = 64
1200        vocab_size = 64
1201        batch_size = 4
1202        src_tokens = torch.randint(1, vocab_size, (batch_size, prompt_size))
1203        gm = make_fx(f, tracing_mode="symbolic")(src_tokens)
1204        # A guard rules out batch_size == sys.maxsize (the guard count wobbling
1205        # between 2 and 1 is ok)
1206        self.assertEqual(len(gm.shape_env.guards), 1)
1207
1208    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
1209    def test_cpu_scalar_cuda(self):
1210        # Extracted from wave2vec2
1211        def f(a, b):
1212            return (a * b) @ b
1213
1214        r = str(
1215            make_fx(f, tracing_mode="symbolic")(
1216                torch.tensor(1.0), torch.randn(2, 2, device='cuda')
1217            ).code
1218        ).strip()
1219        self.assertExpectedInline(r, """\
1220def forward(self, a_1, b_1):
1221    mul = torch.ops.aten.mul.Tensor(a_1, b_1);  a_1 = None
1222    mm = torch.ops.aten.mm.default(mul, b_1);  mul = b_1 = None
1223    return mm""")
1224
1225    def test_binary_broadcast(self):
1226        def f(a, b):
1227            c = a * b
1228            return c
1229
1230        test_inputs = []
1231        test_inputs.append([(1, 5), (3, 1)])
1232        test_inputs.append([(1, 4), (4, 1)])
1233        shape_env = self._test_dynamic(f, [(1, 2), (3, 1)], test_inputs).shape_env
1234        assert len(shape_env.guards) == 0
1235
1236    def test_multiply_shape(self):
1237        def f(a):
1238            return torch.empty(a.shape[0] * 2)
1239
1240        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
1241        self.assertExpectedInline(r, """\
1242def forward(self, a_1):
1243    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0);  a_1 = None
1244    mul = sym_size_int * 2;  sym_size_int = None
1245    empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False);  mul = None
1246    return empty""")
1247
1248    def test_item(self):
1249        def f(a):
1250            r = a.item()
1251            return r * a
1252
1253        r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(1)).code).strip()
1254        self.assertExpectedInline(r, """\
1255def forward(self, a_1):
1256    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1)
1257    mul = torch.ops.aten.mul.Tensor(a_1, _local_scalar_dense);  a_1 = _local_scalar_dense = None
1258    return mul""")
1259
1260    def test_tensor_symfloat(self):
1261        def f(a):
1262            r = torch.tensor(a.size(0) ** 2.0)
1263            assert r.dtype is torch.float
1264            return r
1265
1266        gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2))
1267        r = str(gm.code).strip()
1268        # NB: this specializes, which is fine, the point is to make sure the
1269        # dtype inference is correct
1270        self.assertExpectedInline(r, """\
1271def forward(self, a_1):
1272    _tensor_constant0 = self._tensor_constant0
1273    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
1274    return lift_fresh_copy""")
1275        self.assertEqual(gm._tensor_constant0, torch.tensor(4.0))
1276
1277    def test_item_to_constructor(self):
1278        def f(a):
1279            r = a.item()
1280            return torch.empty(r)
1281
1282        r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip()
1283        self.assertExpectedInline(
1284            r, """\
1285def forward(self, a_1):
1286    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1);  a_1 = None
1287    empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
1288    return empty"""  # noqa: B950
1289        )
1290
1291
1292    def test_setitem_symint(self):
1293        # from moco
1294        # https://github.com/pytorch/pytorch/issues/101939
1295        def f(x):
1296            x[0] = x.size(0)
1297            return x
1298
1299        r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(10)).code).strip()
1300        self.assertExpectedInline(
1301            r, """\
1302def forward(self, x_1):
1303    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
1304    scalar_tensor = torch.ops.aten.scalar_tensor.default(sym_size_int, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  sym_size_int = None
1305    select = torch.ops.aten.select.int(x_1, 0, 0)
1306    copy_ = torch.ops.aten.copy_.default(select, scalar_tensor);  select = scalar_tensor = copy_ = None
1307    return x_1"""  # noqa: B950
1308        )
1309
1310    def test_dynamic_pointwise_scalar(self):
1311        def f(gravity, mask):
1312            gravity[mask, 0] = gravity[mask, 0] * -1
1313
1314        r = str(make_fx(f, tracing_mode="symbolic")(
1315            torch.randn((12, 4)),
1316            torch.randint(0, 2, (12,), dtype=torch.bool)
1317        ).code).strip()
1318        self.assertExpectedInline(r, """\
1319def forward(self, gravity_1, mask_1):
1320    select = torch.ops.aten.select.int(gravity_1, 1, 0)
1321    index = torch.ops.aten.index.Tensor(select, [mask_1]);  select = None
1322    mul = torch.ops.aten.mul.Tensor(index, -1);  index = None
1323    select_1 = torch.ops.aten.select.int(gravity_1, 1, 0);  gravity_1 = None
1324    index_put_ = torch.ops.aten.index_put_.default(select_1, [mask_1], mul);  select_1 = mask_1 = mul = index_put_ = None
1325    return None""")
1326
1327    def test_reflect_r_over_x(self):
1328        def reflect_R_over_x(R):
1329            reflect = torch.eye(3, device=R.device)
1330            reflect[0, 0] = -1
1331            return reflect @ R @ reflect
1332
1333        def f(crop_camera, mask):
1334            crop_camera[mask] = reflect_R_over_x(crop_camera[mask])
1335
1336        r = str(make_fx(f, tracing_mode="symbolic")(
1337            torch.randn((12, 3, 3)),
1338            torch.randint(0, 2, (12,), dtype=torch.bool)
1339        ).code).strip()
1340        self.assertExpectedInline(r, """\
1341def forward(self, crop_camera_1, mask_1):
1342    index = torch.ops.aten.index.Tensor(crop_camera_1, [mask_1])
1343    eye = torch.ops.aten.eye.default(3, device = device(type='cpu'), pin_memory = False)
1344    _tensor_constant0 = self._tensor_constant0
1345    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
1346    select = torch.ops.aten.select.int(eye, 0, 0)
1347    select_1 = torch.ops.aten.select.int(select, 0, 0);  select = None
1348    copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy);  select_1 = lift_fresh_copy = copy_ = None
1349    sym_size_int = torch.ops.aten.sym_size.int(index, 0)
1350    expand = torch.ops.aten.expand.default(eye, [sym_size_int, 3, 3])
1351    view = torch.ops.aten.view.default(expand, [sym_size_int, 3, 3]);  expand = None
1352    sym_size_int_1 = torch.ops.aten.sym_size.int(crop_camera_1, 1)
1353    sym_size_int_2 = torch.ops.aten.sym_size.int(crop_camera_1, 2)
1354    expand_1 = torch.ops.aten.expand.default(index, [sym_size_int, sym_size_int_1, sym_size_int_2]);  index = None
1355    view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]);  expand_1 = sym_size_int_1 = sym_size_int_2 = None
1356    bmm = torch.ops.aten.bmm.default(view, view_1);  view = view_1 = None
1357    view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]);  bmm = None
1358    mul_4 = sym_size_int * 3
1359    view_3 = torch.ops.aten.view.default(view_2, [mul_4, 3]);  view_2 = mul_4 = None
1360    mm = torch.ops.aten.mm.default(view_3, eye);  view_3 = eye = None
1361    view_4 = torch.ops.aten.view.default(mm, [sym_size_int, 3, 3]);  mm = sym_size_int = None
1362    index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_4);  crop_camera_1 = mask_1 = view_4 = index_put_ = None
1363    return None""")  # noqa: B950
1364
1365    def test_unbacked_slice(self):
1366        def f(x, m):
1367            x = x[m]
1368            return x[slice(None, None, None), slice(None, None, None), slice(None, 2, None)]
1369
1370        make_fx(f, tracing_mode="symbolic")(
1371            torch.randn((12, 3, 3)),
1372            torch.randint(0, 2, (12,), dtype=torch.bool)
1373        )
1374
1375    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
1376    def test_unbacked_batch_resnet(self):
1377        mod = torchvision.models.resnet18()
1378
1379        def f(x, mask, params, buffers):
1380            for p in itertools.chain([x, mask], params.values(), buffers.values()):
1381                for s in p.shape:
1382                    guard_int(s)
1383            x = x[mask]
1384            torch._check(x.shape[0] >= 1)
1385            for p in params.values():
1386                p.grad = None
1387            return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()
1388
1389        make_fx(f, tracing_mode="symbolic")(
1390            torch.randn(3, 3, 250, 250),
1391            torch.randint(0, 2, (3,), dtype=torch.bool),
1392            dict(mod.named_parameters()),
1393            dict(mod.named_buffers()),
1394        )
1395
1396    def test_boolean_index(self):
1397        def f(images, handedness, valid):
1398            images = images[valid]
1399            handedness = handedness[valid]
1400            right_hand_mask = handedness == 1
1401            images[right_hand_mask] = images[right_hand_mask].flip(-1)
1402
1403        r = str(make_fx(f, tracing_mode="symbolic")(
1404            torch.randint(0, 256, (512, 1, 96, 96)),
1405            torch.randint(0, 1, (512,)),
1406            torch.randint(0, 2, (512,), dtype=torch.bool)
1407        ).code).strip()
1408        self.assertExpectedInline(r, """\
1409def forward(self, images_1, handedness_1, valid_1):
1410    index = torch.ops.aten.index.Tensor(images_1, [valid_1]);  images_1 = None
1411    index_1 = torch.ops.aten.index.Tensor(handedness_1, [valid_1]);  handedness_1 = valid_1 = None
1412    eq = torch.ops.aten.eq.Scalar(index_1, 1);  index_1 = None
1413    index_2 = torch.ops.aten.index.Tensor(index, [eq])
1414    flip = torch.ops.aten.flip.default(index_2, [-1]);  index_2 = None
1415    index_put_ = torch.ops.aten.index_put_.default(index, [eq], flip);  index = eq = flip = index_put_ = None
1416    return None""")
1417
1418    def test_neg_shape(self):
1419        def f(a):
1420            return torch.empty(-a.shape[0] + 10)
1421
1422        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(2)).code).strip()
1423        self.assertExpectedInline(r, """\
1424def forward(self, a_1):
1425    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0);  a_1 = None
1426    neg = -sym_size_int;  sym_size_int = None
1427    add = neg + 10;  neg = None
1428    empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False);  add = None
1429    return empty""")
1430
1431    def test_unbacked_unification(self):
1432        def f(x, y):
1433            z = torch.zeros(x.item())
1434            return z + y
1435
1436        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.randn(10)).code).strip()
1437        self.assertExpectedInline(r, """\
1438def forward(self, x_1, y_1):
1439    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
1440    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
1441    add = torch.ops.aten.add.Tensor(zeros, y_1);  zeros = y_1 = None
1442    return add""")  # noqa: B950
1443
1444    def test_reshape_divisibility_unbacked(self):
1445        def f(x):
1446            i0 = x.item()
1447            r = torch.zeros(i0, 4, 20)
1448            r = r.transpose(2, 1)
1449            return r.reshape(-1, 80)
1450        make_fx(f, tracing_mode="symbolic")(torch.tensor(24))
1451
1452    def test_view_divisibility_unbacked(self):
1453        def f(x):
1454            i0 = x.item()
1455            r = torch.zeros(i0, 192)
1456            return r.view(12, -1, 192)
1457        make_fx(f, tracing_mode="symbolic")(torch.tensor(24))
1458
1459    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
1460    def test_view_divisibility_unbacked_relatively_prime(self):
1461        # See https://github.com/pytorch/pytorch/issues/123651
1462        def f(x):
1463            i0 = x.item()
1464            torch._check_is_size(i0)
1465            # To trigger the original issue, the max bound has to be chosen
1466            # such that 448 / 447 < 2 (which it is).
1467            torch._check(i0 <= 448)
1468            return torch.zeros(256 * i0).view(-1, 447)
1469        make_fx(f, tracing_mode="symbolic")(torch.tensor(256 * 447, device="cuda"))
1470
1471    def test_unbacked_unify_guard(self):
1472        def f(x, y):
1473            z = torch.zeros(x.item())
1474            torch._check(z.size(0) == y.size(0))  # refines i0 = s0
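            # After the refinement above, z.size(0) is identified with y.size(0)
            # (10 for these inputs), so tracing takes the `else` branch below.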
1475            if z.size(0) == 4:
1476                return y * 2
1477            else:
1478                return y + 2
1479
1480        r = str(make_fx(f, tracing_mode="symbolic")(torch.tensor(10), torch.randn(10)).code).strip()
1481        self.assertExpectedInline(r, """\
1482def forward(self, x_1, y_1):
1483    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
1484    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = zeros = None
1485    add = torch.ops.aten.add.Tensor(y_1, 2);  y_1 = None
1486    return add""")  # noqa: B950
1487
1488    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
1489    @unittest.expectedFailure
1490    def test_unbacked_unify_guard_transitivity(self):
1491        def f(x1, x2, y):
1492            z1 = torch.zeros(x1.item())
1493            z2 = torch.zeros(x2.item())
1494            torch._check(z1.size(0) == z2.size(0))  # refines i0 = i1
1495            torch._check(z2.size(0) == y.size(0))  # refines i0 = s0
1496            if z1.size(0) == 4:
1497                return y * 2
1498            else:
1499                return y + 2
1500
1501        gm = make_fx(f, tracing_mode="symbolic")(
1502            torch.tensor(10, device="cuda"),
1503            torch.tensor(10, device="cuda"),
1504            torch.randn(10, device="cuda")
1505        )
1506        insert_deferred_runtime_asserts(gm, gm.shape_env, "test")
1507        gm.recompile()
1508        r = str(gm.code).strip()
1509        # self.assertExpectedInline(
1510        #     r, """"""  # noqa: B950
1511        # )
1512
1513    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
1514    def test_unbacked_unify_dependency_violation(self):
1515        def f(x1, x2, x3, y):
1516            z1 = x1.item()
1517            torch._check(z1 // 9 == 1)
1518            z2 = x2.item()
1519            z3 = x3.item()
1520            torch._check(z1 == z2 + z3)
1521            return y * 2
1526
1527        # NB: inputs are on CUDA to ensure their values aren't queried, so the
1528        # resulting symbols stay unbacked
1529
1530        gm = make_fx(f, tracing_mode="symbolic")(
1531            torch.tensor(10, device="cuda"), torch.tensor(5, device="cuda"),
1532            torch.tensor(5, device="cuda"), torch.randn(1, device="cuda")
1533        )
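        # Turn the deferred torch._check constraints recorded in the shape env into
        # runtime asserts in the graph, so the violating call below raises.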
1534        insert_deferred_runtime_asserts(gm, gm.shape_env, "test")
1535        gm.recompile()
1536        self.assertEqual(gm(
1537            torch.tensor(12, device="cuda"), torch.tensor(6, device="cuda"),
1538            torch.tensor(6, device="cuda"), torch.tensor([1.0], device="cuda")),
1539            torch.tensor([2.0], device="cuda")
1540        )
1541        with self.assertRaises(RuntimeError):
1542            gm(
1543                torch.tensor(20, device="cuda"), torch.tensor(10, device="cuda"),
1544                torch.tensor(10, device="cuda"), torch.tensor([1.0], device="cuda")
1545            )
1546
1548    def test_split_unbacked_sizes(self):
1549        def f(lengths, values):
1550            # tolist not directly supported atm
1551            sizes = [lengths[i].item() for i in range(lengths.size(0))]
1552            for s in sizes:
1553                # TODO(avik): no assertion generated with torch._check_is_size?
1554                torch._constrain_as_size(s)
1555            return torch.split(values, sizes)
1556
1557        r = str(make_fx(f, tracing_mode="symbolic")(
1558            torch.tensor([2, 3, 4]),
1559            torch.randn(9)
1560        ).code).strip()
1561        self.assertExpectedInline(r, """\
1562def forward(self, lengths_1, values_1):
1563    select = torch.ops.aten.select.int(lengths_1, 0, 0)
1564    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(select);  select = None
1565    select_1 = torch.ops.aten.select.int(lengths_1, 0, 1)
1566    _local_scalar_dense_1 = torch.ops.aten._local_scalar_dense.default(select_1);  select_1 = None
1567    select_2 = torch.ops.aten.select.int(lengths_1, 0, 2);  lengths_1 = None
1568    _local_scalar_dense_2 = torch.ops.aten._local_scalar_dense.default(select_2);  select_2 = None
1569    sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense);  sym_constrain_range_for_size = None
1570    sym_constrain_range_for_size_1 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_1);  sym_constrain_range_for_size_1 = None
1571    sym_constrain_range_for_size_2 = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense_2);  sym_constrain_range_for_size_2 = None
1572    split_with_sizes = torch.ops.aten.split_with_sizes.default(values_1, [_local_scalar_dense, _local_scalar_dense_1, _local_scalar_dense_2]);  values_1 = _local_scalar_dense = _local_scalar_dense_1 = _local_scalar_dense_2 = None
1573    getitem = split_with_sizes[0]
1574    getitem_1 = split_with_sizes[1]
1575    getitem_2 = split_with_sizes[2];  split_with_sizes = None
1576    return (getitem, getitem_1, getitem_2)""")  # noqa: B950
1577
1578    def test_invalidate_nonzero(self):
1579        ok = False
1580
1581        def f(a):
1582            nonlocal ok
1583            b = a.clone()
1584            x = b.nonzero()
1585            x1 = b.nonzero()
1586            x2 = b.nonzero()
1587            assert x1.shape[0] == x2.shape[0]
1588            ok = True
1589            b.normal_()
1590            y = b.nonzero()
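            # The inplace normal_ invalidated the memoized nonzero count, so this
            # size comparison is data-dependent again and must raise.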
1591            try:
1592                bool(x1.shape[0] == y.shape[0])
1593                self.fail("didn't raise exception")
1594            except GuardOnDataDependentSymNode:
1595                pass
1596
1597        make_fx(f, tracing_mode="symbolic")(torch.randn(4))
1598
1599    @torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True)
1600    def test_invalidate_nonzero_propagate_real_tensors(self):
1601        def f(a):
1602            b = a.clone()
1603            x = b.nonzero()
1604            x1 = b.nonzero()
1605            x2 = b.nonzero()
1606            assert x1.shape[0] == x2.shape[0]
1607            b.normal_()
1608            y = b.nonzero()
1609            # normal_ will (almost surely) never produce an exact zero, so the
1610            # nonzero counts still match
1611            assert x1.shape[0] == y.shape[0]
1612
1613        make_fx(f, tracing_mode="symbolic")(torch.randn(4))
1614
1615    def test_sqrt_size(self):
1616        def f(a):
1617            return a / a.size(-1) ** 0.5
1618
1619        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
1620        self.assertExpectedInline(r, """\
1621def forward(self, a_1):
1622    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
1623    sym_float = torch.sym_float(sym_size_int);  sym_size_int = None
1624    pow_1 = sym_float ** 0.5;  sym_float = None
1625    div = torch.ops.aten.div.Tensor(a_1, pow_1);  a_1 = pow_1 = None
1626    return div""")
1627
1628    def test_make_fx_with_custom_tracer_preserving_nn_module_stack(self):
1629
1630        class Bar(torch.nn.Module):
1631            def __init__(self) -> None:
1632                super().__init__()
1633
1634            def forward(self, x):
1635                return x + 1
1636
1637        class Foo(torch.nn.Module):
1638            def __init__(self) -> None:
1639                super().__init__()
1640                self.bar = Bar()
1641
1642            def forward(self, x):
1643                return x + self.bar(x)
1644
1645        gm = make_fx(Foo())(torch.randn(4, 4))
1646        for node in gm.graph.nodes:
1647            self.assertTrue("nn_module_stack" not in node.meta)
1648
1649        foo = Foo()
1650
1651        def functional_call(*args, **kwargs):
1652            with stateless._reparametrize_module(foo, {}):
1653                return foo(*args, **kwargs)
1654
1655        functional_call._orig_mod = foo
1656
1657        gm_with_stack = make_fx(functional_call, record_module_stack=True)(torch.randn(4, 4))
1658        found = False
1659        for node in gm_with_stack.graph.nodes:
1660            if "nn_module_stack" in node.meta:
1661                if len(node.meta["nn_module_stack"]) == 1:
1662                    self.assertTrue("custom_tracer_preserving_nn_module_stack.<locals>.Foo" in str(node.meta["nn_module_stack"]))
1663                    found = True
1664                elif len(node.meta["nn_module_stack"]) == 2:
1665                    self.assertTrue("preserving_nn_module_stack.<locals>.Bar" in str(node.meta["nn_module_stack"]))
1666                    found = True
1667                else:
1668                    # there can be at most 2 levels
1669                    self.fail("nn_module_stack should have at most 2 levels")
1670
1671        self.assertTrue(found)
1672
1673        gm_without_stack = make_fx(functional_call)(torch.randn(4, 4))
1674        for node in gm_without_stack.graph.nodes:
1675            self.assertTrue("nn_module_stack" not in node.meta)
1676
1677    def test_symint_to_tensor(self):
1678        def f(a):
1679            return a / a.shape[0]
1680
1681        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
1682        self.assertExpectedInline(r, """\
1683def forward(self, a_1):
1684    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
1685    div = torch.ops.aten.div.Tensor(a_1, sym_size_int);  a_1 = sym_size_int = None
1686    return div""")
1687
1688        r = str(make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4)).code).strip()
1689        self.assertExpectedInline(r, """\
1690def forward(self, a_1):
1691    sym_size_int = torch.ops.aten.sym_size.int(a_1, 0)
1692    sym_float = torch.sym_float(sym_size_int);  sym_size_int = None
1693    div = torch.ops.prims.div.default(a_1, sym_float);  a_1 = sym_float = None
1694    return div""")
1695
1696    def test_cat(self):
1697        def f(a, b):
1698            val = torch.mul(a, b)
1699            out = torch.cat([val, val])
1700            if out.shape[0] * out.shape[1] > 20:
1701                out = out.cos()
1702            return out
1703
1704        test_inputs = []
1705        test_inputs.append([(1, 5), (6, 1)])
1706        test_inputs.append([(1, 4), (3, 1)])
1707        gm = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs)
1708        self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1)))
1709        self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1)))
1710        self.assertExpectedInline(show_guards(gm), """2*L['a'].size()[1]*L['b'].size()[0] > 20""")
1711
1712    def test_new_empty(self):
1713        def f(a, b):
1714            return a.new_empty(b.shape[0], b.shape[1] * 2)
1715
1716        self._test_dynamic(f, [(2, 4), (4, 5)], [[(2, 3), (5, 7)], [(3, 7), (9, 3)]], assert_eq=False).shape_env
1717
1718    def test_size_with_tensor(self):
1719        # I think I messed up writing this test case originally: I expected to
1720        # hit an error case, but the code here works in both eager mode and
1721        # tracing
1722        def f(tensor):
1723            max_size = torch.tensor([800, 1216], dtype=torch.int64)
1724            batch_shape = [2] + list(tensor.shape[:-2]) + list(max_size)
1725            return tensor.new_empty(batch_shape)
1726
1727        a = torch.randn(3, 800, 1199)
1728        f(a)
1729        make_fx(f, tracing_mode="symbolic")(a)
1730
1731    def test_fake_tensor_as_size(self):
1732        def f(x):
1733            r = torch.zeros([x])
1734            return r
1735
1736        fx_g = make_fx(f, tracing_mode="symbolic")(torch.tensor(4))
1737        self.assertExpectedInline(fx_g.code.strip(), """\
1738def forward(self, x_1):
1739    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1);  x_1 = None
1740    zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
1741    return zeros""")  # noqa: B950
1742
1743    def test_expand(self):
1744        def f(a):
1745            b = torch.mul(a, a)
1746            c = b.expand(a.shape)
1747            return c
1748
1749        self._test_dynamic(f, [(3,)], [[(3,)], [(4,)], [(2,)]])
1750        self._test_dynamic(f, [(5, 1)], [[(4, 1)], [(3, 1)], [(6, 1)]])
1751
1752    def test_metadata(self):
1753        def f(a, b):
1754            d = a.new_empty(a.shape[0] + b.shape[0])
1755            return d
1756        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
1757        meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
1758        meta_d = _get_node(fx_g, lambda x: x.target == operator.add)
1759        self.assertTrue(meta_c.meta['val'].shape[0].node.expr == meta_d.meta['val'].node.expr)
1760
1761    def test_metadata_fresh(self):
1762        def f(x):
1763            assert x.shape[0] == 3
1764            return x.cos()
1765
1766        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
1767        meta_cos = _get_node(fx_g, lambda x: x.target == aten.cos.default)
1768        meta_inp = _get_node(fx_g, lambda x: x.op == 'placeholder')
1769        self.assertTrue(meta_cos.meta['val'].shape[0] == 3)
1770        # Checks if the input expr has been updated even though the constraint
1771        # happened afterwards
1772        self.assertTrue(meta_inp.meta['val'].shape[0] == 3)
1773
1774    def test_elementwise_meta_with_sym_numbers(self):
1775        def f(x, offset, as_sym_float=False):
1776            x0 = x.size()[0]
1777            if as_sym_float:
1778                x0 = torch.sym_float(x0)
1779            return torch.add(x0, offset)
1780
1781        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2.0, False)
1782        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
1783        self.assertEqual(meta_add.meta['val'].shape, ())
1784        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)
1785
1786        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, False)
1787        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
1788        self.assertEqual(meta_add.meta['val'].shape, ())
1789        self.assertEqual(meta_add.meta['val'].dtype, torch.int64)
1790
1791        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, True)
1792        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
1793        self.assertEqual(meta_add.meta['val'].shape, ())
1794        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)
1795
1796    def test_return_symint(self):
1797        def f(x):
1798            return x.shape[0], x.cos(), x.shape[0] / 5
1799        self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])
1800
1801        def f(x):
1802            return x.shape
1803        self._test_dynamic(f, [(5, 3)], [[(4, 6)]])
1804
1805    def test_rmethod(self):
1806        def f(x):
1807            return x.size(0) + x
1808        self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])
1809
1810    def test_mega_guard(self):
1811        def f(a, b):
1812            assert a.shape[0] == b.shape[0] * 2
1813            return a.cos()
1814        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8))
1815        from torch._dynamo.source import LocalSource
1816        self.assertExpectedInline(
1817            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=False)),  # noqa: B950
1818            """["L['a'].size()[0] == 2*L['b'].size()[0]", "L['a'].stride()[0] == 1", "L['a'].storage_offset() == 0", "L['b'].stride()[0] == 1", "L['b'].storage_offset() == 0", "2 <= L['b'].size()[0]"]"""  # noqa: B950
1819        )
1820        self.assertExpectedInline(
1821            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=True)),  # noqa: B950
1822            """["L['a'].size()[0] == 2*L['b'].size()[0]", "2 <= L['b'].size()[0]"]"""  # noqa: B950
1823        )
1824
1825    def test_guard_upperbound_range_refinement(self):
1826        def f(a):
1827            assert a.shape[0] > 5 and a.shape[0] > 12
1828            return a.cos()
1829        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
1830        self.assertExpectedInline(show_guards(tensor), """13 <= L['a'].size()[0]""")
1831
1832    def test_guard_lowerbound_range_refinement(self):
1833        def f(a):
1834            assert a.shape[0] < 20 and a.shape[0] < 30
1835            return a.cos()
1836        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
1837        self.assertExpectedInline(show_guards(tensor), """L['a'].size()[0] <= 19""")
1838
1839    def test_guard_upperbound_range_refinement_multivariate(self):
1840        def f(a):
1841            assert a.shape[0] > 5 and a.shape[0] > 12
1842            assert a.shape[1] > 5 and a.shape[1] > a.shape[0]
1843            return a.cos()
1844        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 20)))
1845        self.assertExpectedInline(show_guards(tensor), """\
1846L['a'].size()[1] > L['a'].size()[0]
184713 <= L['a'].size()[0]
184814 <= L['a'].size()[1]""")
1849
1850    def test_guard_lowerbound_range_refinement_multivariate(self):
1851        def f(a):
1852            assert a.shape[0] < 20 and a.shape[0] < 30
1853            assert a.shape[1] < 30 and a.shape[1] < a.shape[0]
1854            return a.cos()
1855        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 5)))
1856        self.assertExpectedInline(
1857            show_guards(tensor),
1858            """\
1859L['a'].size()[1] < L['a'].size()[0]
1860L['a'].size()[0] <= 19
1861L['a'].size()[1] <= 18""")
1862
1863    def test_sym_storage_offset(self):
1864        def f(x, y):
1865            return x + y
1866
1867        inp = (torch.randn(8)[3:], torch.randn(5))
1868        fx_g = make_fx(f, tracing_mode="symbolic")(*inp)
1869        inp = (torch.randn(8)[3:], torch.randn(5))
1870        self.assertEqual(fx_g(*inp), f(*inp))
1871
1872    def _assert_no_guards(self, fx_g, free_symbols):
1873        assert _get_free_symbols(fx_g.shape_env) == free_symbols, fx_g.shape_env.var_to_val
1874        assert len(fx_g.shape_env.get_nontrivial_guards()) == 0, fx_g.shape_env.format_guards()
1875
1876    def test_guards_equal(self):
1877        def f(a, b):
1878            return a * b
1879
1880        # NB: Numbers are carefully chosen so that duck shaping does not apply
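        # (Duck shaping: dimensions that happen to have the same concrete size at
        # trace time may be assigned the same symbol, which would collapse the
        # free-symbol counts asserted below.)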
1881
1882        fx_g = _trace(f, (5, 6), (5, 6))
1883        self._assert_no_guards(fx_g, 2)
1884
1885        fx_g = _trace(f, (5, 6, 7), (5, 6, 7))
1886        self._assert_no_guards(fx_g, 3)
1887
1888        fx_g = _trace(f, (5, 1), (1, 6))
1889        self._assert_no_guards(fx_g, 2)
1890
1891        def f(a, b, c, d):
1892            a = a + b
1893            cat = torch.cat([c, d])
1894            return a + cat
1895
1896        fx_g = _trace(f, 7, 7, 4, 3)
1897        self._assert_no_guards(fx_g, 2)
1898
1899        def f(a, b, c, d, e):
1900            vals = [a, b, c, d, e]
1901            x = a
1902            for idx in range(len(vals) - 1):
1903                x = torch.cat([x, vals[idx]]) + vals[idx + 1]
1904            return x
1905
1906        fx_g = _trace(f, 2, 4, 8, 16, 32)
1907        self._assert_no_guards(fx_g, 1)
1908
1909        def f(a, b):
1910            a = a.view(b.shape[0])
1911            return a + b.sum()
1912
1913        fx_g = _trace(f, (4, 2), 8)
1914        self._assert_no_guards(fx_g, 2)
1915
1916        fx_g = _trace(f, (4, 2), (8, 5))
1917        self._assert_no_guards(fx_g, 3)
1918
1919        fx_g = _trace(f, (2, 3, 4), 24)
1920        self._assert_no_guards(fx_g, 3)
1921
1922    def test_nonidentity_transitive_guards(self):
1923        def f(a, b, c, d, e):
1924            vals = [a, b, c, d, e]
1925            cat_vals = []
1926            for idx in range(len(vals) - 1):
1927                cat_vals.append(torch.cat([vals[idx], vals[idx]]))
1928            final_vals = []
1929            for a, b in reversed(list(zip(cat_vals, vals[1:]))):
1930                final_vals.append(a + b)
1931            return final_vals
1932
1933        fx_g = _trace(f, 2, 4, 8, 16, 32)
1934        self.assertExpectedInline(show_guards(fx_g), """""")
1935
1936    @torch.fx.experimental._config.patch(translation_validation=True)
1937    def test_constant_specialization(self):
1938        def f(t):
1939            assert t.shape[0] == 10
1940            return t
1941
1942        tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(10))
1943        self.assertExpectedInline(show_guards(tensor), """""")
1944
1945
1946make_fx_failures = {
1947    # unknown
1948    xfail('allclose'),
1949    xfail('equal'),
1950    # empty
1951    skip('new_empty'),
1952    skip('empty_like'),
1953    skip('empty'),
1954    skip('empty_permuted'),
1955    # flaky
1956    skip('linalg.lstsq', 'grad_oriented'),
1957    skip('nn.functional.max_unpool1d', '', device_type='cpu'),
1958    skip('nn.functional.max_unpool2d', '', device_type='cpu'),
1959    skip('nn.functional.max_unpool3d', '', device_type='cpu'),
1960    skip('linalg.lstsq'),  # flaky, probably just a precision issue
1961
1962    # data-dependent control flow
1963    skip('item'),
1964    xfail('cov'),
1965    xfail('nn.functional.gaussian_nll_loss'),
1966    xfail('tensor_split'),
1967    xfail('corrcoef'),
1968    xfail('quantile'),
1969    xfail('nanquantile'),
1970
1971    # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
1972    xfail('sparse.sampled_addmm'),
1973    xfail('sparse.mm', 'reduce'),
1974
1975    # proxy tensor doesn't support sparse correctly right now
1976    skip('to_sparse'),
1977    # segfaults
1978    skip('block_diag'),
1979
1980    # AssertionError: Tensor-likes are not close!
1981    skip('empty_strided', '', device_type='cpu'),
1982}
1983
1984only_real_tensor_failures = {
1985    xfail('narrow'),
1986}
1987
1988only_fake_tensor_failures = {
1989    xfail('narrow'),
1990}
1991
1992fake_tensor_failures = {
1993    # ASAN failures due to divide by 0
1994    skip('nn.functional.nll_loss'),
1995}
1996
1997symbolic_tensor_failures = {
1998    xfail('combinations', ''),
1999    xfail('geqrf', ''),  # aten.geqrf.default - couldn't find symbolic meta function/decomposition
2000    xfail('histogram', ''),  # Could not run 'aten::histogram.bin_ct' with arguments from the 'Meta' backend. This c...
2001    xfail('histogramdd', ''),  # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition
2002    xfail('nanquantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
2003    xfail('nn.functional.binary_cross_entropy', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decom...
2004    xfail('nn.functional.cross_entropy', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
2005    xfail('nn.functional.ctc_loss'),  # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
2006    xfail('quantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
2007    xfail('unique_consecutive', ''),  # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
2008
2009    xfail('max_pool2d_with_indices_backward', ''),  # Expected a value of type 'List[int]' for argument 'kernel_size' but...
2010
2011    # many complex operators produce incorrect striding/metadata
2012    xfail('fft.fft', ''),
2013    xfail('fft.hfft2', ''),
2014    xfail('fft.hfft', ''),
2015    xfail('fft.hfftn', ''),
2016    xfail('fft.ifft', ''),
2017    xfail('fft.ihfft2', ''),
2018    xfail('fft.ihfft', ''),
2019    xfail('fft.ihfftn', ''),
2021    xfail('fft.irfft2', ''),
2022    xfail('fft.irfft', ''),
2023    xfail('fft.irfftn', ''),
2024    xfail('fft.rfft2', ''),
2025    xfail('fft.rfft', ''),
2026    xfail('fft.rfftn', ''),
2027    xfail('stft', '')
2028}
2029symbolic_tensor_segfaults = {
2030    skip('nn.functional.batch_norm')  # Segfault??
2031}
2032
2033symbolic_tensor_failures.update(symbolic_tensor_segfaults)
2034
2035inplace_symbolic_tensor_failures = {
2036    # bugs
2037    xfail('float_power', ''),  # base given to float_power_ has dtype Float but the operation's result requires dtype Double
2038}
2039
2040out_symbolic_tensor_failures = {
2041    # Cast error details: Unable to cast (...) to Tensor
2042    #
2043    # This happens because the test is set up to call the out variant using the `out` kwarg:
2044    #   torch._some_op(arg1, arg2, out=(out1, out2, out3))
2045    #
2046    # However, this only works on torch ops, not aten ops. For `_batch_norm_with_update`,
2047    # this fails because the op has no python bindings, so it doesn't support the `out` kwarg
2048    # way of calling its out variant.
2049    xfail('_batch_norm_with_update', ''),
2050    xfail('_native_batch_norm_legit', ''),
2051    xfail('angle', ''),
2052    xfail('argmax', ''),
2053    xfail('argmin', ''),
2054    xfail('fft.fft2', ''),
2055    xfail('fft.fftn', ''),
2056    xfail('fft.ifft2', ''),
2057    xfail('fft.ifftn', ''),
2058    xfail('gather', ''),
2059    xfail('linalg.pinv', ''),
2060    xfail('linalg.pinv', 'hermitian'),
2061    xfail('lu', ''),
2062    xfail('scatter_add', ''),
2063    xfail('scatter', ''),
2064    xfail('take_along_dim', ''),
2065    xfail('triangular_solve', ''),
2066
2067    # SymIntArrayRef expected to contain only concrete
2068    xfail('ones', ''),
2069    xfail('randn', ''),
2070    xfail('zeros', ''),
2071
2072    # RuntimeError: Cannot call numel() on tensor with symbolic sizes/strides
2073    xfail('index_reduce', 'prod'),
2074    xfail('index_reduce', 'mean'),
2075    xfail('index_reduce', 'amax'),
2076    xfail('index_reduce', 'amin'),
2077}
2078
2079out_symbolic_tensor_segfaults = {
2080    skip('nanmean', ''),
2081}
2082
2083out_symbolic_tensor_failures.update(out_symbolic_tensor_segfaults)
2084
2085# Copies inputs to inplace operations to avoid inplace modifications
2086#   to leaves requiring gradient
2087def _get_safe_inplace(inplace_variant):
2088    @functools.wraps(inplace_variant)
2089    def _fn(t, *args, **kwargs):
2090        return inplace_variant(t.clone(), *args, **kwargs)
2091
2092    return _fn
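
# A minimal usage sketch for _get_safe_inplace (illustrative only; torch.Tensor.add_
# stands in for whatever inplace variant the OpInfo machinery supplies):
#
#   safe_add_ = _get_safe_inplace(torch.Tensor.add_)
#   t = torch.ones(3, requires_grad=True)
#   safe_add_(t, 1)  # runs add_ on t.clone(), so the autograd leaf `t` is untouched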
2093
2094def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False, out=False):
2095    fn = _get_safe_inplace(op.get_inplace()) if inplace else op.op
2096    sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
2097
2098    # Limit ourselves to the first 100 samples so symbolic tracing tests don't take too long
2099    count = 100
2100    if out:
2101        count = 5
2102    for sample_input in itertools.islice(sample_inputs_itr, count):
2103        if inplace and sample_input.broadcasts_input:
2104            continue
2105        args = [sample_input.input] + list(sample_input.args)
2106        kwargs = sample_input.kwargs
2107        if out:
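            # Exercise the out= variant: precompute the expected result and hand it
            # back as the `out` buffer for the traced call.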
2108            expected = fn(*args, **kwargs)
2109            kwargs['out'] = expected
2110
2111        try:
2112            optests.make_fx_check(fn, args, kwargs, tracing_mode, self.assertEqual,
2113                                  randomize_data=True)
2114        except DynamicOutputShapeException:
2115            self.skipTest("Dynamic output shape operation in trace")
2116
2117
2118def skipIfNameMatches(pattern):
2119    """
2120    Decorator to skip a test if its name matches the given pattern.
2121    """
2122    def decorator(test_func):
2123        def wrapper(*args, **kwargs):
2124            if re.match(pattern, test_func.__name__):
2125                raise unittest.SkipTest(f"Test '{test_func.__name__}' skipped because its name matches the pattern '{pattern}'")
2126            return test_func(*args, **kwargs)
2127        return wrapper
2128    return decorator
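
# A hypothetical usage sketch for skipIfNameMatches (not a pattern used in this
# file; the test name below is made up):
#
#   @skipIfNameMatches(r"test_.*_exhaustive")
#   def test_everything_exhaustive(self, device, dtype, op):
#       ...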
2129
2130# Auto functionalize shouldn't work with make_fx directly, so filter it out of hop_db
2131filtered_hop_db = [op for op in hop_db if op.name != "auto_functionalize"]
2132
2133@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "Cond requires dynamo")
2134class TestProxyTensorOpInfo(TestCase):
2135    @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
2136    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures.union(only_real_tensor_failures))
2137    def test_make_fx_exhaustive(self, device, dtype, op):
2138        _test_make_fx_helper(self, device, dtype, op, "real")
2139
2140    @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
2141    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive',
2142             make_fx_failures.union(fake_tensor_failures, only_fake_tensor_failures))
2143    def test_make_fx_fake_exhaustive(self, device, dtype, op):
2144        _test_make_fx_helper(self, device, dtype, op, "fake")
2145
2146    @ops(op_db + filtered_hop_db + custom_op_db, allowed_dtypes=(torch.float,))
2147    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
2148             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures)
2149    def test_make_fx_symbolic_exhaustive(self, device, dtype, op):
2150        _test_make_fx_helper(self, device, dtype, op, "symbolic")
2151
2152    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
2153    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace',
2154             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | inplace_symbolic_tensor_failures)
2155    def test_make_fx_symbolic_exhaustive_inplace(self, device, dtype, op):
2156        if not op.get_inplace():
2157            self.skipTest("No inplace variant for this op")
2158        _test_make_fx_helper(self, device, dtype, op, "symbolic", inplace=True)
2159
2160    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
2161    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_out',
2162             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | out_symbolic_tensor_failures)
2163    def test_make_fx_symbolic_exhaustive_out(self, device, dtype, op):
2164        if not op.supports_out:
2165            self.skipTest("Op doesn't support out")
2166        _test_make_fx_helper(self, device, dtype, op, "symbolic", out=True)
2167
2168
2169only_for = ("cpu",)
2170instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for)
2171
2172
2173if __name__ == '__main__':
2174    run_tests()
2175