# Owner(s): ["module: dynamo"]
import abc
import collections
import copy
import dataclasses
import dis
import enum
import functools
import gc
import itertools
import logging
import math
import operator
import os
import random
import sys
import tempfile
import threading
import traceback
import typing
import unittest
import unittest.mock as mock
import warnings
import weakref
from unittest.mock import patch

import numpy as np

import torch
import torch._dynamo.testing

import torch._inductor.test_case
import torch.onnx.operators

import torch.utils._pytree as pytree
import torch.utils.cpp_extension
from torch import Tensor
from torch._C import FileCheck
from torch._dynamo import allow_in_graph
from torch._dynamo.eval_frame import _debug_get_cache_entry_list
from torch._dynamo.exc import Unsupported
from torch._dynamo.source import ConstantSource, GetItemSource, LocalSource
from torch._dynamo.testing import (
    CompileCounter,
    CompileCounterWithBackend,
    expectedFailureDynamic,
    same,
    skipIfNotPy311,
    unsupported,
    xfailIfPy312,
)
from torch._dynamo.utils import CompileProfiler, counters, ifdynstaticdefault
from torch._inductor.utils import run_and_get_code
from torch.ao.quantization import MinMaxObserver
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.quantize_fx import prepare_qat_fx
from torch.fx.experimental.recording import NotEqualError, replay_shape_env_events
from torch.fx.experimental.symbolic_shapes import (
    _constrain_range_for_size,
    constrain_range,
    constrain_unify,
    ConstraintViolationError,
    expect_true,
    guard_size_oblivious,
    ShapeEnv,
)
from torch.nn import functional as F
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    SM80OrLater,
    TEST_CUDA,
    TEST_MULTIGPU,
)
from torch.testing._internal.common_methods_invocations import (
    sample_inputs_take_along_dim,
)
from torch.testing._internal.common_utils import (
    freeze_rng_state,
    IS_FBCODE,
    set_default_dtype,
    wrapDeterministicFlagAPITest,
)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.logging_utils import logs_to_string

mytuple = collections.namedtuple("mytuple", ["a", "b", "ab"])
T = typing.TypeVar("T")


# Specializes a test to run only if translation validation is set.
def onlyIfTranslationValidation(fn: typing.Callable) -> typing.Callable:
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        import torch.fx.experimental.validator

        if torch.fx.experimental.validator.translation_validation_enabled():
            return fn(*args, **kwargs)
        raise unittest.SkipTest("only works when translation validation is enabled.")

    return wrapper


def cleanup_op(opname):
    ns, name = opname.split("::")
    if not hasattr(torch.ops, ns):
        return
    actual_ns = getattr(torch.ops, ns)
    if not hasattr(actual_ns, name):
        return
    delattr(actual_ns, name)


class MyPickledModule(torch.nn.Module):
    def __init__(self, z):
        super().__init__()
        self.z = z

    def forward(self, x, y):
        return x * x * x + y + self.z


# These are used for test_{cond/map}_with_quantization
default_symmetric_fake_quant = FakeQuantize.with_args(
    observer=MinMaxObserver, qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
)
default_weight_symmetric_fake_quant = FakeQuantize.with_args(
    observer=MinMaxObserver, qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
)
uniform_qconfig_8bit = QConfig(
    activation=default_symmetric_fake_quant,
    weight=default_weight_symmetric_fake_quant.with_args,
)
qconfig_dict = {"object_type": [(torch.nn.Linear, uniform_qconfig_8bit)]}


def closure_adder(val):
    def inner(x):
        return torch.sin(x + val)

    return inner


class UserDefineSetAttr:
    setup = False

    def __setattr__(self, key, value):
        assert torch.compiler.is_dynamo_compiling() or UserDefineSetAttr.setup
        super().__setattr__(f"pfx_{key}", value)

    def __getattr__(self, key, c=1):
        assert torch.compiler.is_dynamo_compiling() or UserDefineSetAttr.setup
        # c is added to force a guard on __defaults__ and to check the source for __getattr__
        if c:
            return self.__dict__[f"pfx_{key}"]
        else:
            return None


class MiscTests(torch._inductor.test_case.TestCase):
    def test_get_cache_entry(self):
        def f(x):
            return x + 1

        torch.compile(f)(torch.randn(5, 5, 5))
        entries = _debug_get_cache_entry_list(f)
        self.assertTrue(len(entries) > 0)

        def g(x):
            return x + 2

        entries = _debug_get_cache_entry_list(g)
        self.assertTrue(len(entries) == 0)

        try:
            _debug_get_cache_entry_list(1)
        except TypeError as e:
            self.assertIn("expected a code object!", str(e))

        # test get cache entry on skipped code object
        def h(x):
            x = x + 1
            torch._dynamo.graph_break()
            return x + 1

        torch.compile(h)(torch.randn(3, 3))

        entries = _debug_get_cache_entry_list(torch._dynamo.graph_break)
        self.assertEqual(len(entries), 0)

    def test_boolarg(self):
        def boolarg(aa, bb, flag):
            if flag:
                return aa - bb
            else:
                return bb - aa

        a = torch.randn(10, 10)
        b = torch.randn(10, 10)
        correct1 = boolarg(a, b, True)
        correct2 = boolarg(a, b, False)
        correct3 = boolarg(a, b, None)
        counter = CompileCounter()
        opt_boolarg = torch._dynamo.optimize_assert(counter)(boolarg)
        val1 = opt_boolarg(a, b, True)
        val2 = opt_boolarg(a, b, False)
        val3 = opt_boolarg(a, b, None)
        val4 = opt_boolarg(a, b, True)
        self.assertTrue(same(val1, correct1))
        self.assertTrue(same(val2, correct2))
        self.assertTrue(same(val3, correct3))
        self.assertTrue(same(val4, correct1))
        self.assertEqual(counter.frame_count, 3)

    def test_invalid_args_builtin(self):
        @torch.compile(backend="eager")
        def fn(x):
            x = x.sin()
            if isinstance(x, torch.Tensor, invalid=True):
                x = x.sin()
            return x

        with self.assertRaises(TypeError):
            fn(torch.randn(16))

    def test_cpp_extension_recommends_custom_ops(self):
        cpp_source = """
        #include <torch/extension.h>
        at::Tensor foobar(const at::Tensor& x) {
            return x.clone();
        }
        """
        module = torch.utils.cpp_extension.load_inline(
            name="mylib",
            cpp_sources=cpp_source,
            functions="foobar",
            verbose=True,
        )

        x = torch.ones(2, 2, requires_grad=True)
        counters.clear()

        @torch.compile(backend="eager")
        def f(x):
            return module.foobar(x)

        with self.assertWarnsOnceRegex(
            UserWarning,
            ".*https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html.*",
        ):
            f(x)
        self.assertEqual(len(counters["graph_break"]), 1)
        first_graph_break = list(counters["graph_break"].keys())[0]
        self.assertExpectedInline(
            first_graph_break,
            """Graph break due to unsupported builtin mylib.PyCapsule.foobar. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.""",
        )

        cpp_source = """
        #include <torch/extension.h>
        at::Tensor baz(const at::Tensor& x) {
            return x.clone();
        }
        """
        module2 = torch.utils.cpp_extension.load_inline(
            name="mylib2",
            cpp_sources=cpp_source,
            functions="baz",
            verbose=True,
        )

        torch._dynamo.reset()

        # Test that each warning only happens once
        @torch.compile(backend="eager")
        def f(x):
            module2.baz(x)
            module.foobar(x)
            module.foobar(x)
            module2.baz(x)
            module.foobar(x)
            module2.baz(x)
            return x.clone()

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")
            f(x)
            f(x)
        self.assertEqual(len(ws), 2)

    def test_callpacked(self):
        def call_packed(args):
            a, b, c = args
            return a - b * c

        counter = CompileCounter()
        a = torch.randn(10, 10)
        b = torch.randn(10, 10)
        c = torch.randn(10, 10)
        correct = call_packed([a, b, c])
        opt_call_packed = torch._dynamo.optimize_assert(counter)(call_packed)
        val1 = opt_call_packed([a, b, c])
        val2 = opt_call_packed((a, b, c))
        val3 = opt_call_packed([a, b, c])
        val4 = opt_call_packed((a, b, c))
        self.assertTrue(same(val1, correct))
        self.assertTrue(same(val2, correct))
        self.assertTrue(same(val3, correct))
        self.assertTrue(same(val4, correct))
        self.assertEqual(counter.frame_count, 2)

    def test_raises(self):
        def fn(a, b, c, cls):
            x = a + b - c * 10
            raise cls(str(x))

        counter = CompileCounter()
        a = torch.randn(10, 10)
        b = torch.randn(10, 10)
        c = torch.randn(10, 10)
        opt_fn = torch._dynamo.optimize(counter)(fn)
        self.assertRaises(AssertionError, lambda: opt_fn(a, b, c, AssertionError))
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 3)

    def test_module_not_callable(self):
        def fn(x):
            return torch.fft(x)

        counter = CompileCounter()
        a = torch.randn(10, 10)
        opt_fn = torch._dynamo.optimize(counter)(fn)
        self.assertRaisesRegex(
            TypeError, "'module' object is not callable", lambda: opt_fn(a)
        )

    def test_inplace(self):
        def inplace1(a, b):
            o = torch.empty((10, 10))
            o.copy_(a)
            o -= b
            return o

        torch._dynamo.testing.standard_test(self, inplace1, 2, expected_ops=3)

    def test_inplace_desugaring(self):
        def inplace_on_literals(y):
            x0 = 1
            x0 += y
            x1 = 1
            x1 -= y
            return x0, x1

        torch._dynamo.testing.standard_test(
            self, inplace_on_literals, 1, expected_ops=2
        )

    def test_unpack4(self):
        def unpack4(a, b):
            a = a[:5, :]
            b = b[:5, :]
            x, y = a.size()
            o = torch.empty((x, y))
            o.copy_(a / b)
            return o

        torch._dynamo.testing.standard_test(
            self,
            unpack4,
            2,
            expected_ops=5,
            expected_ops_dynamic=ifdynstaticdefault(5, 7),
        )

    def test_unpack5(self):
        def unpack5(a, b):
            a = a[:5, :]
            b = b[:5, :]
            x, y = a.shape
            o = torch.empty((x, y))
            o.copy_(a / b)
            return o

        torch._dynamo.testing.standard_test(
            self,
            unpack5,
            2,
            expected_ops=5,
            expected_ops_dynamic=ifdynstaticdefault(5, 7),
        )

    def test_matmul1(self):
        def matmul_op1(a, b):
            return a @ b

        # TODO(jansel): FX doesn't support this, should add upstream support
        torch._dynamo.testing.standard_test(self, matmul_op1, 2, expected_ops=1)

    def test_int_shape_binops(self):
        def fn(x):
            # Test reversal by putting int arg first.
            y = 15 - x.shape[0]
            y = 4 + y
            y = 5 * y
            y = 2 % y
            y = 3**y
            y = 10 // y
            y = pow(2, y)
            y = 10 / y
            return x + y

        torch._dynamo.testing.standard_test(
            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 11)
        )

    @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True)
    def test_pt2_compliant_ops_are_allowed(self):
        lib = torch.library.Library("mylib", "FRAGMENT")
        try:
            torch.library.define(
                "mylib::bar",
                "(Tensor x) -> Tensor",
                lib=lib,
                tags=(torch.Tag.pt2_compliant_tag,),
            )
            torch.library.impl(
                "mylib::bar", "CompositeImplicitAutograd", torch.sin, lib=lib
            )
            assert torch.Tag.pt2_compliant_tag in torch.ops.mylib.bar.default.tags

            def f(x):
                return torch.ops.mylib.bar(x)

            overload = torch.ops.mylib.bar.default

            def g(x):
                return overload(x)

            x = torch.randn(3)

            counts = torch._dynamo.testing.CompileCounter()
            optimized_f = torch._dynamo.optimize(counts, nopython=True)(f)
            _ = optimized_f(x)

            optimized_g = torch._dynamo.optimize(counts, nopython=True)(g)
            _ = optimized_g(x)
        finally:
            cleanup_op("mylib::bar")
            del lib

    @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True)
    def test_non_pt2_compliant_ops_graph_break(self):
        lib = torch.library.Library("mylib", "FRAGMENT")
        try:
            torch.library.define("mylib::bar2", "(Tensor x) -> Tensor", lib=lib)
            torch.library.impl(
                "mylib::bar2", "CompositeImplicitAutograd", torch.sin, lib=lib
            )
            assert torch.Tag.pt2_compliant_tag not in torch.ops.mylib.bar2.default.tags

            def f(x):
                return torch.ops.mylib.bar2(x)

            overload = torch.ops.mylib.bar2.default

            def g(x):
                return overload(x)

            x = torch.randn(3)

            counts = torch._dynamo.testing.CompileCounter()
            with self.assertRaisesRegex(
                torch._dynamo.exc.Unsupported, "not PT2 compliant"
            ):
                optimized_f = torch._dynamo.optimize(counts, nopython=True)(f)
                y = optimized_f(x)

            with self.assertRaisesRegex(
                torch._dynamo.exc.Unsupported, "not PT2 compliant"
            ):
                optimized_g = torch._dynamo.optimize(counts, nopython=True)(g)
                y = optimized_g(x)
        finally:
            cleanup_op("mylib::bar2")
            del lib

    @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True)
    def test_pt2_compliant_overload(self):
        lib = torch.library.Library("mylib", "FRAGMENT")
        try:
            torch.library.define(
                "mylib::bar3.tensor",
                "(Tensor x) -> Tensor",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )
            torch.library.define(
                "mylib::bar3.int", "(Tensor x, int dim) -> Tensor", lib=lib
            )

            torch.library.impl(
                "mylib::bar3.tensor",
                "CompositeImplicitAutograd",
                torch.sin,
                lib=lib,
            )
            torch.library.impl(
                "mylib::bar3.int", "CompositeImplicitAutograd", torch.sum, lib=lib
            )

            def f(x):
                return torch.ops.mylib.bar3(x)

            def g(x):
                return torch.ops.mylib.bar3(x, 1)

            def h(x):
                return torch.ops.mylib.bar3(x, x, x)

            x = torch.randn(3)

            counts = torch._dynamo.testing.CompileCounter()
            optimized_f = torch._dynamo.optimize(counts, nopython=True)(f)
            optimized_g = torch._dynamo.optimize(counts, nopython=True)(g)
            optimized_h = torch._dynamo.optimize(counts, nopython=True)(h)

            # No error: the overload is PT2 compliant
            optimized_f(x)

            with self.assertRaisesRegex(
                torch._dynamo.exc.Unsupported, "not PT2 compliant"
            ):
                y = optimized_g(x)

            # graph break on incorrect parsing
            with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "failed to"):
                y = optimized_h(x)

        finally:
            cleanup_op("mylib::bar3")
            del lib

    def test_auto_functionalize_can_with_default(self):
        lib = torch.library.Library("mylib", "FRAGMENT")
        torch.library.define(
            "mylib::foo",
            "(Tensor a, int b, Tensor(d!)? c=None, Tensor? d=None, int e=-1) -> ()",
            tags=torch.Tag.pt2_compliant_tag,
            lib=lib,
        )

        @torch.library.impl("mylib::foo", "cpu", lib=lib)
        def foo_impl(a, b, c=None, d=None, e=-1):
            a + b
            return

        def f(a, mode):
            return torch.ops.mylib.foo(
                a,
                0,
            )

        a = torch.tensor([10, 10, 10], dtype=torch.int64)

        torch.compile(f)(a, 0)

        cleanup_op("mylib::foo")
        del lib

    def test_user_defined_setattr1(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(obj):
            obj.y = obj.x + 1

        obj = UserDefineSetAttr()
        with patch.object(UserDefineSetAttr, "setup", True):
            obj.x = torch.randn(8)
        fn(obj)
        with patch.object(UserDefineSetAttr, "setup", True):
            self.assertEqual(obj.y, obj.x + 1)
        self.assertEqual(obj.__dict__.keys(), {"pfx_x", "pfx_y"})

    def test_user_defined_setattr2(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            obj = UserDefineSetAttr()
            obj.x = x
            obj.y = obj.x + 1
            return obj

        x = torch.randn(8)
        obj = fn(x)
        with patch.object(UserDefineSetAttr, "setup", True):
            self.assertIs(obj.x, x)
            self.assertEqual(obj.y, x + 1)
        self.assertEqual(obj.__dict__.keys(), {"pfx_x", "pfx_y"})

    def test_closure_recompiles(self):
        cnt = CompileCounter()

        def fn(x, other_fn):
            return other_fn(x + 1) - 1

        opt = torch.compile(fn, backend=cnt, fullgraph=True)

        x = torch.randn(8)
        for f in (
            closure_adder(5),
            closure_adder(5),
            closure_adder(torch.randn(8)),
            closure_adder(torch.randn(8)),
        ):
            self.assertEqual(opt(x, f), fn(x, f))

        self.assertEqual(cnt.frame_count, 2)

    def test_generate_trivial_abstract_impl(self):
        try:
            lib = torch.library.Library("mylib", "FRAGMENT")
            torch.library.define(
                "mylib::foo",
                "(Tensor x, Tensor[] y, Tensor(a!)? z, SymInt w) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w):
                x + y[0] + w
                return

            def f(x, y, z, w):
                return torch.ops.mylib.foo(x, y, z, 2)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            w = torch.randn(3)
            args = (x, y, z, w)

            output = torch.compile(f, backend="eager", fullgraph=True)(*args)
            self.assertEqual(output, None)
        finally:
            cleanup_op("mylib::foo")
            del lib

    def test_can_auto_functionalize(self):
        from torch._higher_order_ops.auto_functionalize import can_auto_functionalize

        expected_true = [
            "(Tensor(a!) x) -> ()",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
            "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor)",
        ]
        expected_false = [
            "(Tensor x) -> ()",
            "(Tensor(a) x) -> Tensor(a)",
            "(Tensor(a!) x) -> Tensor(a!)",
            "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor(a)",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
            "(Tensor(a) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor[])",
        ]
        for schema in expected_true:
            try:
                lib = torch.library.Library("mylib", "FRAGMENT")
                torch.library.define("mylib::a", schema, lib=lib)
                self.assertTrue(
                    can_auto_functionalize(torch.ops.mylib.a.default), msg=schema
                )
                self.assertFalse(can_auto_functionalize(torch.ops.mylib.a))
            finally:
                cleanup_op("mylib::a")
                del lib
        for schema in expected_false:
            try:
                lib = torch.library.Library("mylib", "FRAGMENT")
                torch.library.define("mylib::a", schema, lib=lib)
                self.assertFalse(
                    can_auto_functionalize(torch.ops.mylib.a.default), msg=schema
                )
                self.assertFalse(can_auto_functionalize(torch.ops.mylib.a))
            finally:
                cleanup_op("mylib::a")
                del lib

    def test_auto_functionalize(self):
        try:
            lib = torch.library.Library("mylib", "FRAGMENT")
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)

            def f(x, y, z, n):
                torch.ops.mylib.foo(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)

            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)

            post_grad_graphs = "\n".join(
                log_stream.getvalue().strip().split("\n")[3:]
            ).strip()

            # Check the graph under static shapes
            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
        return ()""",
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)
        finally:
            cleanup_op("mylib::foo")
            del lib

    def test_auto_functionalize_with_returns(self):
        try:
            lib = torch.library.Library("mylib", "FRAGMENT")
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)
                return y[0] + w, y[1] + n

            @torch.library.impl_abstract("mylib::foo", lib=lib)
            def foo_abstract(x, y, z, w, n):
                return y[0] + w, y[1] + n

            def f(x, y, z, n):
                return torch.ops.mylib.foo(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                compiled_out = torch.compile(f, backend="inductor", fullgraph=True)(
                    *compiled_args
                )

            if torch._dynamo.config.assume_static_by_default:
                post_grad_graphs = "\n".join(
                    log_stream.getvalue().strip().split("\n")[3:]
                ).strip()
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
        getitem_4: "f32[3][1]cpu" = foo_default[0]
        getitem_5: "f32[3][1]cpu" = foo_default[1];  foo_default = None
        return (getitem_4, getitem_5)""",
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            eager_out = f(*eager_args)
            self.assertEqual(compiled_args, eager_args)
            self.assertEqual(compiled_out, eager_out)
        finally:
            cleanup_op("mylib::foo")
            del lib

    def test_auto_functionalize_on_view(self):
        try:
            lib = torch.library.Library("mylib", "FRAGMENT")
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x):
                x_np = x.detach().numpy()  # view
                np.sin(x_np, out=x_np)
                return

            x = torch.randn(3)
            expected = x.sin()
            torch.ops.mylib.foo(x)
            assert torch.allclose(x, expected)

            @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True)
            def f(x):
                x = x.clone()
                y = x[:]
                torch.ops.mylib.foo(y)
                return x

            y = f(x)
            self.assertEqual(y, x.sin())
        finally:
            cleanup_op("mylib::foo")
            del lib

    def test_auto_functionalize_optional(self):
        try:
            lib = torch.library.Library("mylib", "FRAGMENT")
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                if x is not None:
                    x.add_(y[0] + w)
                if z is not None:
                    z.add_(y[1] + n)

            def f(x, y, z, n):
                torch.ops.mylib.foo(x, y, z, 2, n)

            x = None
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)

            if torch._dynamo.config.assume_static_by_default:
                post_grad_graphs = "\n".join(
                    log_stream.getvalue().strip().split("\n")[3:]
                ).strip()
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
        return ()""",
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)
        finally:
            cleanup_op("mylib::foo")
            del lib

    def test_shape_int_inplace_binops(self):
        def fn(x):
            p = x.shape[0]
            p += 2
            p -= 2
            p **= 2
            p /= 2
            p *= 2
            p //= 2
            p %= 2
            return x + p

        torch._dynamo.testing.standard_test(
            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 10)
        )

    def test_int_shape_inplace_binops(self):
        def fn(x):
            p = x.shape[0]
            # Test reversal by putting constant first
            y = 2
            y += p
            y = 2
            y -= p
            y = 2
            y **= p
            y = 2
            y /= p
            y = 2
            y *= p
            y = 2
            y //= p
            y = 2
            y %= p
            return x + y

        torch._dynamo.testing.standard_test(
            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 4)
        )

    def test_int_int_comparisons(self):
        def fn(x):
            if 2 != 2:
                out = 1
            elif 2 < 1:
                out = 1
            elif 1 > 2:
                out = 1
            elif 1 >= 2:
                out = 1
            elif 2 <= 1:
                out = 1
            elif 2 == 2:
                out = 2
            else:
                out = 1
            return x + out

        torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)

    def test_shape_int_comparisons(self):
        def fn(x):
            a = x.shape[0]
            # Ensure support for constant on right side
            if a != 10:
                out = 1
            elif a < 2:
                out = 1
            elif a > 12:
                out = 1
            elif a >= 12:
                out = 1
            elif a <= 2:
                out = 1
            elif a == 10:
                out = 2
            else:
                out = 1
            return x + out

        # TODO: Test the guards maybe?
        torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)

    def test_int_shape_comparisons(self):
        def fn(x):
            a = x.shape[0]
            # Ensure support for constant on left side
            if 10 != a:
                out = 1
            elif 12 < a:
                out = 1
            elif 2 > a:
                out = 1
            elif 2 >= a:
                out = 1
            elif 12 <= a:
                out = 1
            elif 10 == a:
                out = 2
            else:
                out = 1
            return x + out

        # TODO: Test the guards maybe?
        torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)

    def test_param_shape_binops(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.randn(15))

            def forward(self, x):
                # Test reversal by putting param shape arg first.
                p = self.param.shape[0]
                y = p - x.shape[0]
                y = p + y
                y = p * y
                y = p % y
                y = p**y
                y = p // y
                y = pow(p, y)
                y = p / y
                return x + y

        counts = torch._dynamo.testing.CompileCounter()
        mod = MyModule()
        optimized_mod = torch._dynamo.optimize(counts, nopython=True)(mod)

        x = torch.randn(3)
        ref = mod(x)
        res = optimized_mod(x)

        self.assertTrue(same(ref, res))
        self.assertEqual(counts.frame_count, 1)

        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(counts.op_count, """1""")
        else:
            self.assertExpectedInline(counts.op_count, """11""")

    def test_user_defined_binop(self):
        class MyClass:
            def __init__(self, value):
                self.value = value

            def __radd__(self, other):
                return self.value + other

        def fn(x, c):
            y = x.shape[0] + c
            return x + y

        counts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(counts)(fn)

        x = torch.randn(3)
        c = MyClass(4)
        ref = fn(x, c)
        res = opt_fn(x, c)

        self.assertTrue(same(ref, res))
        self.assertEqual(counts.frame_count, 1)
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(counts.op_count, """1""")
        else:
            self.assertExpectedInline(counts.op_count, """4""")

    def test_user_defined_iter(self):
        class Mod:
            def __init__(self):
                self.a = [torch.randn(2, 2), torch.randn(2, 2)]

            def __iter__(self):
                return iter(self.a)

        def f(mod):
            ret = []
            for x in mod:
                ret.append(x + 1)
            return ret

        mod = Mod()
        counts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(counts, nopython=True)(f)
        ref = f(mod)
        res = opt_fn(mod)
        res = opt_fn(mod)
        res = opt_fn(mod)
        res = opt_fn(mod)
        self.assertTrue(same(ref, res))
        self.assertEqual(counts.frame_count, 1)

        mod.a.append(torch.randn(2, 2))
        # `for x in mod` is inlined, and iter(mod.a) installs a guard on the length of the list mod.a.
        # Mutating the length of mod.a therefore causes a recompilation.
        ref2 = f(mod)
        res2 = opt_fn(mod)
        res2 = opt_fn(mod)
        res2 = opt_fn(mod)
        res2 = opt_fn(mod)
        self.assertTrue(same(ref2, res2))
        self.assertEqual(counts.frame_count, 2)

    def test_compare_shapes_eq(self):
        def compare_shapes(a, b, to_list):
            x = list(a.unsqueeze(-1).shape) if to_list else a.shape
            y = list(b.unsqueeze(-1).shape) if to_list else b.shape
            if x == y:
                return a + 1
            else:
                return a + 2

        # Test both ListVariable and ShapeVariable
        torch._dynamo.testing.standard_test(
            self, lambda a, b: compare_shapes(a, b, to_list=True), 2
        )
        torch._dynamo.testing.standard_test(
            self, lambda a, b: compare_shapes(a, b, to_list=False), 2
        )

    def test_compare_shapes_tuple_eq(self):
        def compare_shapes(a, b):
            x = tuple(a.unsqueeze(-1).shape)
            y = tuple(b.unsqueeze(-1).shape)
            if x == y:
                return a + 1
            else:
                return a + 2

        torch._dynamo.testing.standard_test(self, lambda a, b: compare_shapes(a, b), 2)

    def test_compare_shapes_tuple_neq(self):
        def compare_shapes(a, b):
            x = tuple(a.unsqueeze(-1).shape)
            y = tuple(b.unsqueeze(-1).shape)
            if x != y:
                return a + 1
            else:
                return a + 2

        torch._dynamo.testing.standard_test(self, lambda a, b: compare_shapes(a, b), 2)

    def test_compare_shapes_neq(self):
        def compare_shapes(a, b, to_list):
            x = list(a.unsqueeze(-1).shape) if to_list else a.shape
            y = list(b.unsqueeze(-1).shape) if to_list else b.shape
            if x != y:
                return a + 1
            else:
                return a + 2

        # Test both ListVariable and ShapeVariable
        torch._dynamo.testing.standard_test(
            self, lambda a, b: compare_shapes(a, b, to_list=True), 2
        )
        torch._dynamo.testing.standard_test(
            self, lambda a, b: compare_shapes(a, b, to_list=False), 2
        )

    def test_compare_shapes_with_constant(self):
        def compare_shapes(a):
            x = a.shape
            if x[0] != 3:
                return a * 4
            return a * 3

        guard_failure = None

        def guard_failures(failure):
            nonlocal guard_failure
            guard_failure = failure

        opt_fn = torch._dynamo.optimize(
            "eager", nopython=True, guard_fail_fn=guard_failures
        )(compare_shapes)
        opt_fn(torch.randn([3, 4]))
        opt_fn(torch.randn([4, 3]))
        self.assertIn(
            """tensor 'L['a']' size mismatch at index 0. expected 3, actual 4""",
            guard_failure.reason,
        )

    def test_builtin_abs(self):
        def fn(x, y):
            return abs(x) + abs(y)

        sample = torch.randn(10, 10)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)

        for sample in [
            (torch.randn(10, 10), torch.randn(10, 10)),
            (-10, make_tensor(10, dtype=torch.int64, device="cpu")),
            (-0.1, torch.randn(10)),
        ]:
            expect = fn(*sample)
            actual = opt_fn(*sample)
            self.assertEqual(expect, actual)

    def test_builtin_isinstance(self):
        def fn(x):
            t = torch.arange(1, 3)
            a = isinstance(x, torch.Tensor)
            b = isinstance(t, torch.Tensor)
            c = isinstance(x, int)
            d = isinstance(3, int)
            e = isinstance([1, 2, 3], list)
            f = isinstance({"foo": 1, "bar": 2}, dict)
            res = [a, b, c, d, e, f]
            # Can't run yet due to other unimplemented instructions
            # res += [isinstance(torch.nn.LazyLinear(2, 3), torch.nn.Linear)]
            return res

        torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)

    @unittest.skipIf(sys.version_info[:2] <= (3, 8), "Requires astunparse")
    def test_cse_dict_guards(self):
        def fn(x):
            ret = torch.zeros(3)
            for v in x.values():
                ret = ret + v
            return ret

        from torch._dynamo.guards import build_guard_function, CLOSURE_VARS

        x = {3: torch.randn(3), 2: torch.randn(3), 4: torch.randn(3)}
        _, guards = torch._dynamo.export(fn, x)

        code_lists = [c for g in guards for c in g.code_list or []]
        _, pycode = build_guard_function(code_lists, [])
        # Make sure we just call "list(dict.keys())" once
        self.assertEqual(pycode.count("keys"), 1)

    def test_sys_modules(self):
        def fn(x, y):
            mod_a = sys.modules.get("aaaaaaaa")
            assert mod_a is None
            assert "bbbbbbbb" not in sys.modules

            assert "operator" in sys.modules
            operator = sys.modules["operator"]
            builtins = sys.modules.get("builtins")
            operator2 = sys.modules.get("cccccccc", operator)

            return operator.add(x, y), operator2.neg(builtins.abs(x))

        torch._dynamo.testing.standard_test(self, fn, 2, expected_ops=3)

        x = torch.randn(10, 10)
        _, guards = torch._dynamo.export(fn, x, x)
        guard_code = []
        for guard in guards:
            if guard.code_list:
                guard_code += guard.code_list

        # Filter out id-matches that won't reproduce run to run
        guard_code = filter(
            lambda line: "id" not in line and "lookup_backend" not in line,
            sorted(guard_code),
        )
        guard_code_str = "\n".join(guard_code)

        for line in """\
2 <= L['x'].size()[0]
L['x'] is L['y']
L['x'].ndimension() == 2
L['x'].requires_grad == False
L['x'].size()[1] == L['x'].size()[0]
L['x'].storage_offset() == 0
___dict_contains('builtins', G['sys'].modules)
___dict_contains('operator', G['sys'].modules)
___dict_contains('operator', G['sys'].modules)
hasattr(L['x'], '_dynamo_dynamic_indices') == False
not ___dict_contains('aaaaaaaa', G['sys'].modules)
not ___dict_contains('bbbbbbbb', G['sys'].modules)
not ___dict_contains('cccccccc', G['sys'].modules)
str(L['x'].device) == 'cpu'
str(L['x'].dtype) == 'torch.float32'
utils_device.CURRENT_DEVICE == None""".split(
            "\n"
        ):
            self.assertIn(
                line,
                guard_code_str,
            )

    def test_fold(self):
        def fn(a):
            return a + math.sqrt(63)

        torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)

    def test_getattr_dict(self):
        def fn(x):
            from torch.masked.maskedtensor._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE

            return x * len(_MASKEDTENSOR_FUNCTION_TABLE)

        i = torch.randn(5)
        r1 = fn(i)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        r2 = opt_fn(i)
        self.assertEqual(r1, r2)

    def test_shape_unpack(self):
        def fn(x):
            a, b = x.size()
            return x * b

        i = torch.randn(5, 10)
        r1 = fn(i)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        r2 = opt_fn(i)
        self.assertTrue(same(r1, r2))

    def test_typing_dict(self):
        def fn(d):
            return d[T]

        d = {T: torch.randn(3)}
        r1 = fn(d)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        r2 = opt_fn(d)
        self.assertEqual(r1, r2)

    def test_tensor_iter(self):
        def fn(x):
            for y in x:
                y.add_(1.0)
            return y

        torch._dynamo.testing.standard_test(
            self,
            fn,
            1,
            expected_ops=20,
        )

    def test_empty_list(self):
        def fn(x, ll):
            if len(ll) == 0 and not ll and ll is not None:
                return x + 1

        i = torch.randn(5, 10)
        r1 = fn(i, [])
        opt_fn = torch._dynamo.optimize("eager")(fn)
        r2 = opt_fn(i, [])
        r3 = opt_fn(i, tuple())
        self.assertTrue(same(r1, r2))
        self.assertTrue(same(r1, r3))

    def test_min_max_over_iterable(self):
        def get_test_fn(func):
            def _fn(a, b, func=func):
                # try all of list, iterator, tuple, vararg.
                lst = [a.shape[0] + 1, 8, a.shape[0]]
                x = func(lst)
                y = func(iter(lst))
                z = func(tuple(lst))
                w = func(*lst)
                return a + (x + y + z + w)

            return _fn

        torch._dynamo.testing.standard_test(
            self,
            get_test_fn(func=min),
            2,
            expected_ops=1,
            expected_ops_dynamic=ifdynstaticdefault(1, 14),
        )
        torch._dynamo.testing.standard_test(
            self,
            get_test_fn(func=max),
            2,
            expected_ops=1,
            expected_ops_dynamic=ifdynstaticdefault(1, 17),
        )

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_torch_check(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True)
        def f(x):
            y = x.item()
            torch._check(y >= 0)
            return torch.arange(0, y)

        f(torch.tensor([3]))
        f(torch.tensor([4]))
        self.assertEqual(cnts.frame_count, 1)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_torch_check_symbolic_shape_rel(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True)
        def f(x):
            y = x.item()
            torch._check(x.shape[0] == 1)
            torch._check(x.shape[0] != 2)
            torch._check(x.shape[0] >= 0)
            torch._check(x.shape[0] > 0)
            torch._check(x.shape[0] < 4)
            torch._check(x.shape[0] <= 3)
            return torch.arange(0, y)

        f(torch.tensor([3]))
        f(torch.tensor([4]))
        self.assertEqual(cnts.frame_count, 1)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    # Translation validation changes the exception type, don't run with it
    @torch.fx.experimental._config.patch(translation_validation=False)
    def test_torch_check_is_size(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True)
        def f(x):
            y = x.item()
            torch._check_is_size(y)
            # Cannot branch on an unbacked SymInt
            if y == 0:
                assert False
            else:
                return torch.arange(0, y)

        self.assertRaises(torch._dynamo.exc.UserError, lambda: f(torch.tensor([3])))

    def test_assert(self):
        @torch.compile
        def fn1(x):
            assert x.shape != x.shape

        with self.assertRaises(AssertionError):
            a = torch.randn(10)
            fn1(a)

        def fn2(x):
            assert x.shape == x.shape
            return x.abs()

        torch._dynamo.testing.standard_test(self, fn=fn2, nargs=1, expected_ops=1)

    def test_config_obj(self):
        class Cfg:
            def __init__(self):
                self.val = 0.5
                self.count = 3

        def fn(x, cfg):
            for i in range(cfg.count):
                x = x + cfg.val
            return x

        cfg1 = Cfg()
        cfg1.val = 1.0
        cfg2 = Cfg()
        v = torch.zeros(1)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        v = opt_fn(v, cfg1)  # 3
        v = opt_fn(v, cfg2)  # 4.5
        cfg2.count = 1
        v = opt_fn(v, cfg2)  # 5
        cfg2.val = 2.0
        v = opt_fn(v, cfg2)  # 7
        self.assertEqual(v[0], 7)
        self.assertEqual(cnts.op_count, 8)

    def test_config_getattr_default(self):
        class Cfg:
            def __init__(self):
                self.val = 0.5
                self.count = 10

        def fn(x, cfg):
            if getattr(cfg, "just_add_7", False):
                return x + 7
            for i in range(cfg.count):
                x = x + cfg.val
            return x

        cfg1 = Cfg()
        v = torch.zeros(1)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertEqual(opt_fn(v, cfg1)[0], 5)
        self.assertEqual(opt_fn(v, cfg1)[0], 5)
        cfg1.just_add_7 = True
        self.assertEqual(opt_fn(v, cfg1)[0], 7)
        self.assertEqual(opt_fn(v, cfg1)[0], 7)
        cfg1.just_add_7 = False
        self.assertEqual(opt_fn(v, cfg1)[0], 5)
        self.assertEqual(opt_fn(v, cfg1)[0], 5)
        self.assertEqual(cnts.frame_count, 3)

    def test_size_input(self):
        def fn(x, s):
            a, b = s
            return x + (a - b)

        v = torch.zeros(10, 20)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertEqual(opt_fn(v, v.size())[0, 0], -10)
        self.assertEqual(opt_fn(v, (10, 20))[0, 0], -10)
        self.assertEqual(opt_fn(v, [10, 20])[0, 0], -10)
        # One recompile per differing input type
        self.assertEqual(cnts.frame_count, 3)

    def test_cell_output1(self):
        out = None

        def fn(a, b):
            nonlocal out
            out = a + b * 10

        v = torch.Tensor([100])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertIsNone(opt_fn(v, v))
        self.assertEqual(out[0], 1100)
        self.assertEqual(cnts.op_count, 2)

    def test_cell_output2(self):
        out = None

        def fn(a, b):
            nonlocal out
            c = unsupported(a, b)
            out = a + b * 10 + c

        v = torch.Tensor([100])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertIsNone(opt_fn(v, v))
        self.assertEqual(out[0], 1200)
        self.assertEqual(cnts.op_count, 3)

    def test_return_nested_function(self):
        out = None

        def fn(a, b):
            nonlocal out
            c = a + b
            d = a + 1.0

            def fn2(f: int = 7, g: float = 9.0):
                nonlocal out
                out = a + b * 10
                return c * f - d * g

            return fn2

        v1 = torch.Tensor([100])
        v2 = torch.Tensor([200])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        opt_fn_ret = torch._dynamo.optimize(cnts)(opt_fn(v1, v2))
        self.assertEqual(opt_fn_ret(1.5)[0], -459)
        self.assertEqual(out[0], 2100)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 7)

    def test_tensor_dict1(self):
        def fn(inputs):
            return inputs["a"] - inputs["b"] * 1.5

        v1 = torch.Tensor([100])
        v2 = torch.Tensor([200])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        self.assertEqual(opt_fn({"a": v1, "b": v2})[0], -200)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    def test_tensor_dict3(self):
        def fn(inputs_a, inputs_b):
            total = torch.zeros(1)
            input_keys = inputs_a.keys() | inputs_b.keys()
            for k in input_keys:
                if k in inputs_a:
                    total += inputs_a[k]
                if k in inputs_b:
                    total += inputs_b[k]
            return total

        v1 = torch.Tensor([100])
        v2 = torch.Tensor([200])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        self.assertEqual(
            opt_fn({"a": v1, "b": v2}, {"b": v1, "c": v2}),
            fn({"a": v1, "b": v2}, {"b": v1, "c": v2}),
        )
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 5)

    def test_tensor_dict2(self):
        def fn1(inputs):
            total = torch.zeros(1)
            for k, v in inputs.items():
                total += v
            return total

        def fn2(inputs):
            total = torch.zeros(1)
            for v in inputs.values():
                total += v
            return total

        def fn3(inputs):
            total = torch.zeros(1)
            for k in inputs.keys():
                total += inputs[k]
            return total

        v1 = torch.Tensor([100])
        v2 = torch.Tensor([200])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn1 = torch._dynamo.optimize(cnts, nopython=True)(fn1)
        opt_fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2)
        opt_fn3 = torch._dynamo.optimize(cnts, nopython=True)(fn3)
        self.assertEqual(opt_fn1({"a": v1, "b": v2})[0], 300)
        self.assertEqual(opt_fn2({"a": v1, "b": v2})[0], 300)
        self.assertEqual(opt_fn3({"a": v1, "b": v2})[0], 300)
        self.assertEqual(cnts.frame_count, 3)
        self.assertEqual(cnts.op_count, 9)

    def test_dictcomp(self):
        def fn1(inputs):
            return {k: v + 1 for k, v in inputs.items()}

        v1 = torch.Tensor([100])
        v2 = torch.Tensor([200])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn1 = torch._dynamo.optimize(cnts)(fn1)
        self.assertEqual(opt_fn1({"a": v1, "b": v2})["a"], 101)
        self.assertEqual(opt_fn1({"a": v1, "b": v2})["b"], 201)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    def test_listcomp(self):
        def fn2(inputs):
            return torch.sum(torch.cat([v + 1 for k, v in inputs.items()], 0))

        v1 = torch.Tensor([100])
        v2 = torch.Tensor([200])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn2 = torch._dynamo.optimize(cnts)(fn2)
        self.assertEqual(opt_fn2({"a": v1, "b": v2}), 302)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 4)

    def test_is_floating_point(self):
        def fn(a, b):
            x = a + 1.0
            if torch.is_floating_point(b):
                x = x + b
            return x + 2.0

        return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)

    def test_is_floating_point2(self):
        def fn(a, b):
            x = a + 1.0
            if b.is_floating_point():
                x = x + b
            return x + 2.0

        return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)

    def test_is_tensor(self):
        def fn(a, b):
            x = a + 1.0
            if torch.is_tensor(b):
                x = x + b
            return x + 2.0

        return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)

    def test_is_tensor2(self):
        def fn(x):
            if torch.is_tensor(x):
                return x + 1
            else:
                return torch.ones([2, 3])

        x1 = {"input": torch.rand(2, 3)}
        x2 = torch.rand(2, 3)
1690        ref1 = fn(x1)
1691        ref2 = fn(x2)
1692        opt_fn = torch._dynamo.optimize("eager")(fn)
1693        res1 = opt_fn(x1)
1694        res2 = opt_fn(x2)
1695        self.assertEqual(ref1, res1)
1696        self.assertEqual(ref2, res2)
1697
1698    def test_numel(self):
1699        def fn(a):
1700            return (a + a.numel() + torch.numel(a), a + a.nelement())
1701
1702        return torch._dynamo.testing.standard_test(
1703            self,
1704            fn=fn,
1705            nargs=1,
1706            expected_ops=3,
1707            expected_ops_dynamic=ifdynstaticdefault(3, 6),
1708        )
1709
1710    def test_pair(self):
1711        def fn(a):
1712            return (
1713                torch.zeros(torch.nn.modules.utils._pair(a.size()))
1714                + a
1715                + torch.ones(torch.nn.modules.utils._ntuple(3)(3)).sum()
1716            )
1717
1718        return torch._dynamo.testing.standard_test(
1719            self,
1720            fn=fn,
1721            nargs=1,
1722            expected_ops=5,
1723            expected_ops_dynamic=ifdynstaticdefault(5, 8),
1724        )
1725
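    # capture_scalar_outputs=True lets the .item() call be captured in the
    # graph (op_count 4 below); the next test flips the flag so .item() stays
    # outside the traced graph (op_count 2).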
1726    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
1727    def test_tensor_item_capture(self):
1728        def fn(a, b):
1729            return (a + b).sum().item()
1730
1731        v1 = torch.randn((10, 10))
1732        v2 = torch.randn((10, 10))
1733        correct = fn(v1, v2)
1734        cnts = torch._dynamo.testing.CompileCounter()
1735        opt_fn = torch._dynamo.optimize(cnts)(fn)
1736        self.assertEqual(opt_fn(v1, v2), correct)
1737        self.assertEqual(cnts.frame_count, 1)
1738        self.assertEqual(cnts.op_count, 4)
1739
1740    @patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
1741    def test_tensor_item_no_capture(self):
1742        def fn(a, b):
1743            return (a + b).sum().item()
1744
1745        v1 = torch.randn((10, 10))
1746        v2 = torch.randn((10, 10))
1747        correct = fn(v1, v2)
1748        cnts = torch._dynamo.testing.CompileCounter()
1749        opt_fn = torch._dynamo.optimize(cnts)(fn)
1750        self.assertEqual(opt_fn(v1, v2), correct)
1751        self.assertEqual(cnts.frame_count, 1)
1752        self.assertEqual(cnts.op_count, 2)
1753
1754    def test_namedtuple1(self):
1755        def fn(a, b):
1756            tmp = mytuple(a, b, a + b)
1757            return mytuple(tmp.a, tmp[1], tmp.ab + b)
1758
1759        v1 = torch.Tensor([10])
1760        v2 = torch.Tensor([20])
1761        cnts = torch._dynamo.testing.CompileCounter()
1762        opt_fn = torch._dynamo.optimize(cnts)(fn)
1763        self.assertEqual(opt_fn(v1, v2).ab, 50)
1764        self.assertEqual(cnts.frame_count, 1)
1765        self.assertEqual(cnts.op_count, 2)
1766
1767    def test_namedtuple2(self):
1768        def fn(packed):
1769            a, b, c = packed
1770            if hasattr(packed, "b"):
1771                b = packed.b + 1
1772            c = packed[2]
1773            return a + b + c
1774
1775        v1 = torch.Tensor([1])
1776        v2 = torch.Tensor([2])
1777        v3 = torch.Tensor([3])
1778        cnts = torch._dynamo.testing.CompileCounter()
1779        opt_fn = torch._dynamo.optimize(cnts)(fn)
1780        self.assertEqual(opt_fn(mytuple(v1, v2, v3))[0], 7)
1781        self.assertEqual(cnts.frame_count, 1)
1782        self.assertEqual(cnts.op_count, 3)
1783
1784    def test_namedtuple3(self):
1785        def fn(x, packed):
1786            if isinstance(packed, mytuple):
1787                return x + 1
1788            else:
1789                return x - 1
1790
1791        x = torch.rand([2, 3])
1792        packed = mytuple(1, 2, 3)
1793        ref = fn(x, packed)
1794        opt_fn = torch._dynamo.optimize("eager")(fn)
1795        res = opt_fn(x, packed)
1796        self.assertTrue(same(ref, res))
1797
1798    def test_range_input(self):
1799        def fn(a, rng):
1800            x = a
1801            for i in rng:
1802                x = x + i
1803            return x
1804
1805        def fn1(a):
1806            return fn(a, rng=range(3))
1807
1808        return torch._dynamo.testing.standard_test(
1809            self, fn=fn1, nargs=1, expected_ops=3
1810        )
1811
1812    def test_range_with_shape(self):
1813        def fn(a):
1814            for i in range(1, a.shape[0]):
1815                a += 1
1816            return a
1817
1818        return torch._dynamo.testing.standard_test(
1819            self,
1820            fn=fn,
1821            nargs=1,
1822            expected_ops=9,
1823        )
1824
1825    def test_build_tuple_unpack(self):
1826        def fn1(a, b, c):
1827            return a - b / c
1828
1829        def fn2(a, b, c):
1830            tmp1 = (a,)
1831            tmp2 = (b, c)
1832            args = (*tmp1, *tmp2)
1833            return fn1(*args)
1834
1835        def fn3(a, *args):
1836            return fn1(a, *args)
1837
1838        torch._dynamo.testing.standard_test(self, fn=fn2, nargs=3, expected_ops=2)
1839        torch._dynamo.testing.standard_test(self, fn=fn3, nargs=3, expected_ops=2)
1840
1841    def test_list_mul(self):
1842        def fn(count):
1843            head_mask = count * [None] * count
1844            return head_mask
1845
1846        cnts = torch._dynamo.testing.CompileCounter()
1847        opt_fn = torch._dynamo.optimize(cnts)(fn)
1848        self.assertEqual(opt_fn(2), [None] * 4)
1849        # TODO: the captured frame here is a bit goofy, because we don't
1850        # output anything and none of the traced operations have side
1851        # effects.  We probably need a better heuristic for bailing out of
1852        # dynamo when there are no outputs.
1853        if torch._dynamo.config.assume_static_by_default:
1854            self.assertExpectedInline(cnts.frame_count, """0""")
1855            self.assertExpectedInline(cnts.op_count, """0""")
1856        else:
1857            self.assertExpectedInline(cnts.frame_count, """1""")
1858            self.assertExpectedInline(cnts.op_count, """2""")
1859
1860    def test_list_slice_mul(self):
1861        def fn(count):
1862            a = [1, 2, 3]
1863            head_mask = count * a[1:] * count
1864            return head_mask
1865
1866        cnts = torch._dynamo.testing.CompileCounter()
1867        opt_fn = torch._dynamo.optimize(cnts)(fn)
1868        self.assertEqual(opt_fn(2), [2, 3] * 4)
1869        if torch._dynamo.config.assume_static_by_default:
1870            self.assertExpectedInline(cnts.frame_count, """0""")
1871            self.assertExpectedInline(cnts.op_count, """0""")
1872        else:
1873            self.assertExpectedInline(cnts.frame_count, """1""")
1874            self.assertExpectedInline(cnts.op_count, """2""")
1875
1876    def test_tuple_mul(self):
1877        def fn(count):
1878            head_mask = count * (2, 3) * count
1879            return head_mask
1880
1881        cnts = torch._dynamo.testing.CompileCounter()
1882        opt_fn = torch._dynamo.optimize(cnts)(fn)
1883        self.assertEqual(opt_fn(2), (2, 3) * 4)
1884        if torch._dynamo.config.assume_static_by_default:
1885            self.assertExpectedInline(cnts.frame_count, """0""")
1886            self.assertExpectedInline(cnts.op_count, """0""")
1887        else:
1888            self.assertExpectedInline(cnts.frame_count, """1""")
1889            self.assertExpectedInline(cnts.op_count, """2""")
1890
1891    def test_tuple_mul_with_shape(self):
1892        def fn(a):
1893            x = a.shape[0]
1894            y = 2 * (x, 3) * 2
1895            return a + y[4]
1896
1897        # expect 3 ops post folding for dynamic case: size, index, add
1898        torch._dynamo.testing.standard_test(
1899            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 3)
1900        )
1901
1902    def test_tuple_iadd_with_shape(self):
1903        def fn(a):
1904            output = (a + a.shape[0], a - a.shape[0])
1905            # tuple += tuple
1906            output += (a - a.shape[0], a + a.shape[0])
1907            # tuple += constant tuple
1908            output += (2, 3)
1909            return output
1910
1911        # expect 4 add / subs for static, 4 * 3 (size, index, math op) for dynamic
1912        torch._dynamo.testing.standard_test(
1913            self, fn, 1, expected_ops=4, expected_ops_dynamic=ifdynstaticdefault(4, 12)
1914        )
1915
1916    def test_list_iadd_with_shape(self):
1917        def fn(a):
1918            output = [a + a.shape[0], a - a.shape[0]]
1919            # list += list
1920            output += [a - a.shape[0], a + a.shape[0]]
1921            # list += tuple
1922            output += (a + a.shape[0], a - a.shape[0])
1923            return output
1924
1925        # expect 6 add / subs for static, 6 * 3 (size, index, math op) for dynamic
1927        torch._dynamo.testing.standard_test(
1928            self, fn, 1, expected_ops=6, expected_ops_dynamic=ifdynstaticdefault(6, 18)
1929        )
1930
1931    def test_list_iadd_side_effect(self):
1932        def fn(a, b):
1933            a += [b]
1934            torch._dynamo.graph_break()
1935            return a
1936
1937        a = [1, 2, 3]
1938        b = torch.ones(2, 2)
1939
1940        opt_fn = torch._dynamo.optimize("eager")(fn)
1941
1942        exp = fn(a, b)
1943
1944        a = [1, 2, 3]
1945        b = torch.ones(2, 2)
1946        act = opt_fn(a, b)
1947
1948        self.assertEqual(exp, act)
1949
1950    def test_user_getattr1(self):
1951        class MyConfig(dict):
1952            def __getattr__(self, name):
1953                return self[name]
1954
1955        def fn(cfg, x, y):
1956            return x + y + cfg.offset
1957
1958        x = torch.randn(10)
1959        cfg = MyConfig(offset=5)
1960        cnts = torch._dynamo.testing.CompileCounter()
1961        opt_fn = torch._dynamo.optimize(cnts)(fn)
1962        self.assertTrue(same(opt_fn(cfg, x, x), 2 * x + 5))
1963        self.assertEqual(cnts.frame_count, 1)
1964        self.assertEqual(cnts.op_count, 2)
1965
1966    def test_user_getattr2(self):
1967        class MyConfig:
1968            defined_on_class = 1
1969
1970            def __init__(self):
1971                self.defined_on_object = 2
1972
1973            def __getattr__(self, name):
1974                return 3
1975
1976        def fn(cfg, x):
1977            return x + cfg.defined_on_class - cfg.defined_on_object + cfg.not_defined
1978
1979        x = torch.randn(10)
1980        cfg = MyConfig()
1981        cnts = torch._dynamo.testing.CompileCounter()
1982        opt_fn = torch._dynamo.optimize(cnts)(fn)
1983        self.assertTrue(same(opt_fn(cfg, x), x + 1 - 2 + 3))
1984        self.assertEqual(cnts.frame_count, 1)
1985        self.assertEqual(cnts.op_count, 3)
1986
1987    def test_getset_descriptor(self):
1988        def fn(g, x):
1989            return g.__get__(x)
1990
1991        cnts = torch._dynamo.testing.CompileCounter()
1992        opt_fn = torch.compile(fullgraph=True, backend="eager")(fn)
1993        g = torch.Tensor.shape
1994
1995        res = opt_fn(g, torch.ones(2, 2))
1996        exp_res = fn(g, torch.ones(2, 2))
1997        self.assertEqual(res, exp_res)
1998
1999    def test_get_attr_function(self):
2000        def fn(g, x):
2001            return g(x)
2002
2003        cnts = torch._dynamo.testing.CompileCounter()
2004        opt_fn = torch._dynamo.optimize(cnts)(fn)
2005        g = torch.Tensor.shape.__get__
2006
2007        res = opt_fn(g, torch.ones(2, 2))
2008        exp_res = fn(g, torch.ones(2, 2))
2009        self.assertEqual(res, exp_res)
2010
2011    def test_user_getattribute(self):
2012        class MyObject:
2013            def __init__(self):
2014                self.custom_dict = {"a": torch.rand((2, 2))}
2015                self.my_number = 42
2016
2017            def __getattribute__(self, name):
2018                custom_dict = super().__getattribute__("custom_dict")
2019                if name in custom_dict:
2020                    return custom_dict[name]
2021                return super().__getattribute__(name)
2022
2023            def run(self, x):
2024                return self.my_number * x + self.a * x
2025
2026        def fn(obj, x):
2027            return obj.run(x)
2028
2029        obj = MyObject()
2030        x = torch.rand((2, 2))
2031        cnts = torch._dynamo.testing.CompileCounter()
2032        opt_fn = torch._dynamo.optimize(cnts)(fn)
2033        self.assertTrue(same(opt_fn(obj, x), fn(obj, x)))
2034
2035    def test_nn_module_getattr(self):
2036        class MyMod(torch.nn.Module):
2037            def __init__(self):
2038                super().__init__()
2039                self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
2040                self.other_attr = torch.rand((2, 2))
2041
2042            def __getattr__(self, name):
2043                custom_dict = self.custom_dict
2044                if name in custom_dict:
2045                    return custom_dict[name]
2046                return super().__getattr__(name)
2047
2048            def forward(self, x):
2049                return x @ self.other_attr + self.queue[-1]
2050
2051        x = torch.rand((2, 2))
2052        mod = MyMod()
2053        cnts = torch._dynamo.testing.CompileCounter()
2054        opt_mod = torch._dynamo.optimize(cnts)(mod)
2055        self.assertTrue(same(opt_mod(x), mod(x)))
2056        self.assertEqual(cnts.frame_count, 1)
2057        self.assertEqual(cnts.op_count, 2)
2058
2059    def test_nn_module_getattribute(self):
2060        class MyMod(torch.nn.Module):
2061            def __init__(self):
2062                super().__init__()
2063                self.my_number = 42
2064
2065            def __getattribute__(self, name):
2066                if name == "special_attr":
2067                    return torch.tensor([[1, 2], [3, 4]])
2068                return super().__getattribute__(name)
2069
2070            def forward(self, x):
2071                return self.my_number * x + self.special_attr * x
2072
2073        def fn(mod, x):
2074            return mod(x)
2075
2076        mod = MyMod()
2077        x = torch.rand((2, 2))
2078        cnts = torch._dynamo.testing.CompileCounter()
2079        opt_fn = torch._dynamo.optimize(cnts)(fn)
2080        self.assertTrue(same(opt_fn(mod, x), fn(mod, x)))
2081
2082    def test_constant_getattr(self):
2083        # https://github.com/pytorch/pytorch/issues/97480
2084        def fn():
2085            return getattr(None, "arg", 3)
2086
2087        cnt = torch._dynamo.testing.CompileCounter()
2088        optimized_fn = torch._dynamo.optimize(cnt)(fn)
2089        res = optimized_fn()
2090        self.assertTrue(same(res, 3))
2091
2092    def test_user_property(self):
2093        class MyConfig:
2094            @property
2095            def prop5(self):
2096                return 5
2097
2098        def fn(cfg, x, y):
2099            return x + y + cfg.prop5
2100
2101        x = torch.randn(10)
2102        cfg = MyConfig()
2103        cnts = torch._dynamo.testing.CompileCounter()
2104        opt_fn = torch._dynamo.optimize(cnts)(fn)
2105        self.assertTrue(same(opt_fn(cfg, x, x), 2 * x + 5))
2106        self.assertEqual(cnts.frame_count, 1)
2107        self.assertEqual(cnts.op_count, 2)
2108
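    # dataclasses.fields(), hasattr and getattr on a dataclass instance should
    # all be traceable in a single frame; adding an unexpected attribute later
    # triggers a guard failure (see the end of the test).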
2109    def test_dataclass_fields(self):
2110        @dataclasses.dataclass
2111        class MyDataClass:
2112            a: torch.Tensor
2113            b: torch.Tensor = None
2114            c: torch.Tensor = None
2115            d: torch.Tensor = None
2116            e: torch.Tensor = None
2117
2118        def fn(obj):
2119            class_fields = dataclasses.fields(obj)
2120            assert len(class_fields)
2121            assert all(field.default is None for field in class_fields[1:])
2122            other_fields_are_none = all(
2123                getattr(obj, field.name) is None for field in class_fields[1:]
2124            )
2125            assert not other_fields_are_none
2126
2127            if not hasattr(obj, "a"):
2128                return -1
2129            if hasattr(obj, "z"):
2130                return -2
2131
2132            total = getattr(obj, class_fields[0].name)
2133            for field in class_fields[1:]:
2134                v = getattr(obj, field.name)
2135                if v is not None:
2136                    total += v
2137
2138            return total
2139
2140        obj1 = MyDataClass(torch.randn(10), torch.randn(10), torch.randn(10))
2141        obj2 = MyDataClass(torch.randn(10), e=torch.randn(10))
2142        correct1 = fn(obj1)
2143        correct2 = fn(obj2)
2144
2145        cnts = torch._dynamo.testing.CompileCounter()
2146        opt_fn = torch._dynamo.optimize(cnts)(fn)
2147        self.assertTrue(same(opt_fn(obj1), correct1))
2148        self.assertEqual(cnts.frame_count, 1)
2149        self.assertEqual(cnts.op_count, 2)
2150
2151        torch._dynamo.reset()
2152        cnts = torch._dynamo.testing.CompileCounter()
2153        opt_fn = torch._dynamo.optimize(cnts)(fn)
2154        self.assertTrue(same(opt_fn(obj2), correct2))
2155        self.assertEqual(cnts.frame_count, 1)
2156        self.assertEqual(cnts.op_count, 1)
2157
2158        # guard failure
2159        obj2.z = True
2160        self.assertEqual(opt_fn(obj2), -2)
2161
2162    def test_dataclass_local_hasattr(self):
2163        cnt = CompileCounter()
2164        x = torch.randn(10)
2165
2166        @dataclasses.dataclass
2167        class MyDataClass:
2168            a: torch.Tensor
2169            b: torch.Tensor
2170
2171        @torch.compile(backend=cnt, fullgraph=True)
2172        def fn():
2173            obj = MyDataClass(x + 1, x - 1)
2174            if not hasattr(obj, "a"):
2175                return -1
2176            if hasattr(obj, "z"):
2177                return -2
2178            return obj
2179
2180        result = fn()
2181        self.assertIsInstance(result, MyDataClass)
2182        self.assertEqual(result.a, x + 1)
2183        self.assertEqual(result.b, x - 1)
2184        self.assertEqual(cnt.frame_count, 1)
2185        self.assertEqual(cnt.op_count, 2)
2186
2187    def test_catch_warnings1(self):
2188        cnt = CompileCounter()
2189
2190        @torch.compile(backend=cnt, fullgraph=True)
2191        def fn(x):
2192            with warnings.catch_warnings(record=True):
2193                return x.sin()
2194
2195        x = torch.randn(8)
2196        self.assertEqual(fn(x), x.sin())
2197        self.assertEqual(cnt.frame_count, 1)
2198
2199    def test_catch_warnings2(self):
2200        cnt = CompileCounter()
2201
2202        @torch.compile(backend=cnt, fullgraph=True)
2203        def fn(x):
2204            return x.sin(), warnings.catch_warnings(record=True)
2205
2206        x = torch.randn(8)
2207        _, a = fn(x)
2208        _, b = fn(x)
2209        self.assertEqual(cnt.frame_count, 1)
2210        self.assertIsInstance(a, warnings.catch_warnings)
2211        self.assertIsInstance(b, warnings.catch_warnings)
2212        self.assertIsNot(a, b)
2213
2214    def test_tensor_build_list_unpack(self):
2215        def fn(x):
2216            # seen in fastNLP_Bert
2217            return torch.cat([*x], dim=-1)
2218
2219        val = torch.randn([1, 1, 473, 768])
2220        correct = fn(val)
2221        cnts = torch._dynamo.testing.CompileCounter()
2222        opt_fn = torch._dynamo.optimize(cnts)(fn)
2223        self.assertTrue(same(opt_fn(val), correct))
2224        self.assertEqual(cnts.frame_count, 1)
2225        self.assertEqual(cnts.op_count, 2)
2226
2227    def test_numpy_int_constant(self):
2228        def fn(x, a, b):
2229            return x + (a % b)
2230
2231        args = [torch.randn(10), 4096, np.int64(8)]
2232        correct = fn(*args)
2233        cnts = torch._dynamo.testing.CompileCounter()
2234        opt_fn = torch._dynamo.optimize(cnts, dynamic=True, nopython=True)(fn)
2235        self.assertTrue(same(opt_fn(*args), correct))
2236        self.assertTrue(same(opt_fn(*args), correct))
2237        self.assertEqual(cnts.frame_count, 1)
2238        self.assertEqual(cnts.op_count, 2)
2239
2240    def test_numpy_subdtype(self):
2241        def fn(x, n):
2242            return np.issubdtype(type(n), np.integer) + x
2243
2244        args = [torch.randn(10), 4096]
2245        correct = fn(*args)
2246        cnts = torch._dynamo.testing.CompileCounter()
2247        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
2248        self.assertEqual(opt_fn(*args), correct)
2249        self.assertEqual(cnts.frame_count, 1)
2250
2251    def test_numpy_take_along_axis(self):
2252        def fn(x, i, a):
2253            return np.take_along_axis(x, i, a)
2254
2255        def sample_to_args(s):
2256            args = (s.input, *s.args)
2257            return tuple(a.numpy() if isinstance(a, torch.Tensor) else a for a in args)
2258
2259        samples = list(
2260            sample_inputs_take_along_dim(
2261                None, "cpu", torch.float32, requires_grad=False
2262            )
2263        )
2264        cnts = torch._dynamo.testing.CompileCounter()
2265        opt_fn = torch._dynamo.optimize(cnts)(fn)
2266        i = 1
2267        for sample in samples:
2268            args = sample_to_args(sample)
2269            if len(args) < 3:
2270                # if axis is None, second argument is treated as 1d array
2271                args = (args[0], np.ravel(args[1]), None)
2272            self.assertEqual(fn(*args), opt_fn(*args))
2273            self.assertEqual(cnts.frame_count, i)
2274            i += 1
2275
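    # Exercise every builtin operator that dynamo can put in an FX graph over
    # all tensor/ndarray operand combinations; each case should compile in a
    # single frame.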
2276    def test_numpy_torch_operators(self):
2277        def fn(op, t1, t2):
2278            return op(t1, t2)
2279
2280        from torch._dynamo.variables.builtin import BuiltinVariable
2281
2282        operators = BuiltinVariable._fx_graph_functions()
2283
2284        for op, t1_np, t2_np in itertools.product(
2285            operators, (True, False), (True, False)
2286        ):
2287            if op in [operator.eq, operator.ne]:
2288                # returns equivalent of torch.eq/ne
2289                continue
2290            if op is operator.getitem:
2291                # skip
2292                # Did you know that tensor[ndarray_of_floats] works?
2293                continue
2294            if op is operator.imatmul and (t1_np or t2_np):
2295                # skip
2296                # in numpy, in-place matmul does not work with
2297                # single-dimensional arrays
2298                continue
2299            t1 = torch.rand(5)
2300            if t1_np:
2301                t1 = t1.numpy()
2302            t2 = torch.rand(5)
2303            if t2_np:
2304                t2 = t2.numpy()
2305            try:
2306                # TODO try a bit harder
2307                result = op(t1, t2)
2308            except (RuntimeError, TypeError, IndexError):
2309                continue
2310            cnts = torch._dynamo.testing.CompileCounter()
2311            opt_fn = torch._dynamo.optimize(cnts)(fn)
2312            self.assertEqual(result, opt_fn(op, t1, t2), msg=f"{op=} {t1_np=} {t2_np=}")
2313            self.assertEqual(cnts.frame_count, 1, msg=f"{op=} {t1_np=} {t2_np=}")
2314            torch._dynamo.reset()
2315
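    # An explicit graph break splits fn into two compiled frames; ndarray
    # values must flow correctly across the break.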
2316    def test_numpy_ndarray_graph_break(self):
2317        def fn(x):
2318            a = x.numpy()
2319            b = a.real
2320            torch._dynamo.graph_break()
2321            c = np.multiply(b, 2.0)
2322            return c
2323
2324        cnts = torch._dynamo.testing.CompileCounter()
2325        opt_fn = torch._dynamo.optimize(cnts)(fn)
2326        for _ in range(10):
2327            x = torch.randn(3)
2328            ref = fn(x)
2329            res = opt_fn(x)
2330            self.assertEqual(ref, res)
2331        self.assertEqual(cnts.frame_count, 2)
2332
2333    def test_numpy_ndarray_graph_break_with_multiple_outputs(self):
2334        def fn(x, y):
2335            a = x.numpy()
2336            b = y.numpy()
2337            torch._dynamo.graph_break()
2338            return np.add(a, 1), np.add(b, 1)
2339
2340        cnts = torch._dynamo.testing.CompileCounter()
2341        opt_fn = torch._dynamo.optimize(cnts)(fn)
2342        for _ in range(10):
2343            x = torch.randn([1, 3])
2344            y = torch.randn([1, 3])
2345            ref = fn(x, y)
2346            res = opt_fn(x, y)
2347            self.assertEqual(ref, res)
2348        self.assertEqual(cnts.frame_count, 2)
2349
2350    def test_numpy_force(self):
2351        def fn(x):
2352            return x.numpy(force=False)
2353
2354        cnts = torch._dynamo.testing.CompileCounter()
2355        opt_fn = torch._dynamo.optimize(cnts)(fn)
2356        x = torch.randn(3)
2357        res = opt_fn(x)
2358        self.assertEqual(type(res), np.ndarray)
2359        self.assertEqual(cnts.frame_count, 1)
2360
2361        def fn(x):
2362            return x.numpy(force=True)
2363
2364        cnts = torch._dynamo.testing.CompileCounter()
2365        opt_fn = torch._dynamo.optimize(cnts)(fn)
2366        x = torch.randn(3, requires_grad=True)
2367        res = opt_fn(x)
2368        self.assertEqual(type(res), np.ndarray)
2369        self.assertEqual(cnts.frame_count, 1)
2370
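    # With dynamic=True the scalar argument becomes symbolic, so changing it
    # from 3 to 4 should not force a recompile (frame_count stays at 1).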
2371    def test_numpy_recompilation_scalar(self):
2372        def fn(x, a):
2373            return np.where(x < 0.5, a, x)
2374
2375        x = np.random.randn(8)
2376        cnts = torch._dynamo.testing.CompileCounter()
2377        opt_fn = torch._dynamo.optimize(cnts, dynamic=True)(fn)
2378
2379        ref = fn(x, 3)
2380        res = opt_fn(x, 3)
2381        self.assertEqual(ref, res)
2382
2383        ref = fn(x, 4)
2384        res = opt_fn(x, 4)
2385        self.assertEqual(ref, res)
2386
2387        self.assertEqual(cnts.frame_count, 1)
2388
2389    def test_tensor_interacts_with_numpy_ndarray(self):
2390        def fn(x, y):
2391            a = x.numpy()
2392            b = y.numpy()
2393            c = np.ones_like(a)
2394            d = np.ones_like(b)
2395            torch._dynamo.graph_break()
2396            return np.add(a, c), np.add(b, d)
2397
2398        cnts = torch._dynamo.testing.CompileCounter()
2399        opt_fn = torch._dynamo.optimize(cnts)(fn)
2400        for _ in range(10):
2401            x = torch.randn([1, 3])
2402            y = torch.randn([1, 3])
2403            ref = fn(x, y)
2404            res = opt_fn(x, y)
2405            self.assertEqual(ref, res)
2406        self.assertEqual(cnts.frame_count, 2)
2407
2408    def test_numpy_ndarray_works_with_builtin_function(self):
2409        def fn(x):
2410            v = x.sum() / len(x)
2411            return v
2412
2413        cnts = torch._dynamo.testing.CompileCounter()
2414        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
2415        for _ in range(10):
2416            x = np.random.randn(2, 3)
2417            ref = fn(x)
2418            res = opt_fn(x)
2419            self.assertEqual(ref, res)
2420        self.assertEqual(cnts.frame_count, 1)
2421
2422    def test_numpy_array_of_arrays(self):
2423        def fn(x, y):
2424            return np.array([x, y])
2425
2426        cnts = torch._dynamo.testing.CompileCounter()
2427        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
2428
2429        x, y = np.float64(1), np.float64(2)
2430        res = opt_fn(x, y)
2431        self.assertEqual(res, np.array([1, 2], dtype=float))
2432        self.assertEqual(type(res), np.ndarray)
2433        self.assertEqual(cnts.frame_count, 1)
2434
2435        x, y = np.arange(2), np.arange(2) + 2
2436        res = opt_fn(x, y)
2437        self.assertEqual(res, np.array([[0, 1], [2, 3]]))
2438        self.assertEqual(type(res), np.ndarray)
2439        self.assertEqual(cnts.frame_count, 2)
2440
2441    def test_numpy_readonly(self):
2442        @torch.compile(fullgraph=True)
2443        def fn(x):
2444            return x
2445
2446        x = np.broadcast_to(np.arange(3), (2, 3))
2447        self.assertFalse(x.flags.writeable)
2448
2449        with warnings.catch_warnings():
2450            warnings.simplefilter("error")
2451            y = fn(x)
2452        self.assertTrue(y.flags.writeable)  # XXX: differs from numpy
2453
2454    def test_numpy_tolist(self):
2455        def fn(x):
2456            return x.tolist()
2457
2458        cnts = torch._dynamo.testing.CompileCounter()
2459        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
2460
2461        x = np.arange(5)
2462        r = opt_fn(x)
2463
2464        self.assertEqual(r, [0, 1, 2, 3, 4])
2465        self.assertEqual(type(r), list)
2466        self.assertEqual(cnts.frame_count, 1)
2467
2468    def test_numpy_size_attr(self):
2469        def fn(x):
2470            return x.size + x
2471
2472        cnts = torch._dynamo.testing.CompileCounter()
2473        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
2474
2475        x = np.arange(5)
2476        r = opt_fn(x)
2477
2478        self.assertEqual(r, fn(x))
2479        self.assertEqual(type(r), np.ndarray)
2480        self.assertEqual(cnts.frame_count, 1)
2481
2482    def test_numpy_no_raise(self):
2483        def _inf_nan_preprocess(t, t_np):
2484            t_np = np.nan_to_num(t_np)
2485            return t, t_np
2486
2487        def fn():
2488            # square shapes to test
2489            test_cases = (
2490                (3, 3),
2491                (4, 4),
2492                (5, 5),
2493            )
2494
2495            for shape in test_cases:
2496                t = torch.randn(shape, dtype=torch.complex64)
2497                t_np = np.random.randn(*shape).astype(np.complex64)
2498
2499                _, t_np = _inf_nan_preprocess(t, t_np)
2500                print(t, t_np)  # Just a side effect so that compilation kicks in
2501
2502        cnt = CompileCounterWithBackend("inductor")
2503        fn = torch._dynamo.optimize(cnt)(fn)
2504        fn()
2505        self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1))
2506
2507    def test_mandelbrot_numpy(self):
2508        def mandelbrot_numpy(max_iter):
2509            # Define the boundaries of the complex plane
2510            xn = 450
2511            yn = 375
2512            xmin = -2.25
2513            xmax = 0.75
2514            ymin = -1.25
2515            ymax = 1.25
2516
2517            # Create the grid of complex numbers
2518            x_values = np.linspace(xmin, xmax, xn, dtype=np.float64)
2519            y_values = np.linspace(ymin, ymax, yn, dtype=np.float64)
2520            rx, iy = np.meshgrid(x_values, y_values, indexing="xy")
2521
2522            x = rx.copy()
2523            y = iy.copy()
2524            mask = np.zeros_like(x)
2525            for i in range(max_iter):
2526                x_prev = x
2527                y_prev = y
2528                x = x_prev**2 - y_prev**2 + rx
2529                y = 2 * x_prev * y_prev + iy
2530                inside = np.sqrt(x**2 + y**2) <= 2
2531                mask += inside
2532            return mask
2533
2534        cnts = torch._dynamo.testing.CompileCounter()
2535        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(mandelbrot_numpy)
2536        n_iter = torch._dynamo.config.cache_size_limit - 2
2537        for i in range(n_iter):
2538            x = i + 3
2539            ref = mandelbrot_numpy(x)
2540            res = opt_fn(x)
2541            self.assertEqual(ref, res)
2542        # We need to specialise the number as it's used in a for loop
2543        self.assertEqual(cnts.frame_count, n_iter)
2544
2545    def test_numpy_as_global(self):
2546        global x
2547        x = np.arange(10)
2548
2549        @torch.compile(fullgraph=True)
2550        def fn(y):
2551            return y + x + x
2552
2553        r = fn(np.arange(10))
2554        self.assertEqual(type(r), np.ndarray)
2555        self.assertEqual(r, x * 3)
2556        del x
2557
2558    def test_numpy_gt(self):
2559        x = np.arange(10)
2560
2561        @torch.compile
2562        def fn(y):
2563            return y >= 3
2564
2565        r = fn(x)
2566        self.assertEqual(type(r), np.ndarray)
2567        self.assertEqual(r, x >= 3)
2568
2569    def test_numpy_min(self):
2570        x = np.arange(10)
2571
2572        @torch.compile
2573        def fn(y):
2574            return min(y, 3), min(y, y - 1)
2575
2576        r1, r2 = fn(x)
2577        self.assertEqual(type(r1), np.ndarray)
2578        self.assertEqual(type(r2), np.ndarray)
2579        self.assertEqual(r1, np.minimum(x, 3))
2580        self.assertEqual(r2, np.minimum(x, x - 1))
2581
2582    def test_graph_break_correctly_when_passing_numpy_ndarray_to_torch_function(self):
2583        # from transformers/models/big_bird/modeling_big_bird.py
2584        def fn(x: int, y: torch.Tensor):
2585            ndarray_list = [np.ones([2, x])]
2586            ndarray = np.stack(ndarray_list, axis=0)
2587            tensor = torch.tensor(ndarray, dtype=torch.long)
2588            tensor.unsqueeze_(0)
2589            return tensor + y
2590
2591        cnts = torch._dynamo.testing.CompileCounter()
2592        opt_fn = torch._dynamo.optimize(cnts)(fn)
2593        for x in range(1, 10):
2594            y = torch.randn([1, 2, x])
2595            ref = fn(x, y)
2596            res = opt_fn(x, y)
2597            self.assertEqual(ref, res)
2598        # With static shapes by default it's traced once each for x = 1, x = 2 and then x = ks0;
2599        # for dynamic shapes it's x = 1 and then x = ks0
2600        self.assertEqual(cnts.frame_count, ifdynstaticdefault(3, 2))
2601
2602    def test_numpy_with_builtin_type(self):
2603        x = np.random.rand(5)
2604
2605        def fn(x):
2606            return (x * 5).astype(bool).astype(float).astype(int) + 8
2607
2608        cnts = torch._dynamo.testing.CompileCounter()
2609        opt_fn = torch._dynamo.optimize(cnts)(fn)
2610
2611        r = opt_fn(x)
2612        self.assertEqual(r.dtype, int)
2613        self.assertEqual(cnts.frame_count, 1)
2614
2615    def test_with_builtin_type(self):
2616        x = torch.randn(5)
2617
2618        def fn(x):
2619            return (x * 5).to(bool).to(float).to(int) + 8
2620
2621        cnts = torch._dynamo.testing.CompileCounter()
2622        opt_fn = torch._dynamo.optimize(cnts)(fn)
2623
2624        r = opt_fn(x)
2625        self.assertEqual(r.dtype, torch.int64)
2626        self.assertEqual(cnts.frame_count, 1)
2627
2628    def test_numpy_unique_f16(self):
2629        def fn():
2630            x = np.asarray([1, 1, 2, 2, 3], dtype=np.float16)
2631            return np.unique(x)
2632
2633        cnts = torch._dynamo.testing.CompileCounter()
2634        opt_fn = torch._dynamo.optimize(cnts)(fn)
2635
2636        r = opt_fn()
2637        self.assertEqual(r.dtype, np.float16)
2638        self.assertEqual(cnts.frame_count, 1)
2639
2640    def test_numpy_fallback_on_eager(self):
2641        def fn():
2642            return np.asarray(["L", "U"])
2643
2644        cnts = torch._dynamo.testing.CompileCounter()
2645        opt_fn = torch._dynamo.optimize(cnts)(fn)
2646
2647        r = opt_fn()
2648        self.assertEqual(cnts.frame_count, 0)  # graph break
2649        self.assertEqual(r, np.asarray(["L", "U"]))
2650
2651        # repeat with a different function
2652        def fn2():
2653            return np.random.choice(["L", "U"])
2654
2655        cnts2 = torch._dynamo.testing.CompileCounter()
2656        opt_fn2 = torch._dynamo.optimize(cnts2)(fn2)
2657
2658        r2 = opt_fn2()
2659        self.assertEqual(cnts2.frame_count, 0)  # graph break
2660        assert r2 in ("L", "U")
2661
2662    def test_trace_ndarray_frame(self):
2663        def fn(x):
2664            x = x**2
2665            print("graph break.")
2666            return 2 * x
2667
2668        counter = CompileCounter()
2669        compiled_fn = torch._dynamo.optimize(counter)(fn)
2670
2671        x = np.arange(8)
2672        self.assertEqual(fn(x), compiled_fn(x))
2673        self.assertEqual(counter.frame_count, 2)
2674
2675    def test_trace_ndarray_frame_2(self):
2676        # no tensors/ndarrays among the frame's inputs
2677        def fn(x):
2678            print("graph break.")
2679            return 2 * np.arange(x)
2680
2681        counter = CompileCounter()
2682        compiled_fn = torch._dynamo.optimize(counter)(fn)
2683
2684        x = 8
2685        self.assertEqual(fn(x), compiled_fn(x))
2686        self.assertEqual(counter.frame_count, 1)
2687
2688    def test_numpy_non_torch_dtype(self):
2689        # test that we gracefully graph break on dtypes
2690        # that do not have pytorch equivalents.
2691        def fn(x):
2692            return isinstance(x, torch.Tensor)
2693
2694        cnts = torch._dynamo.testing.CompileCounter()
2695        opt_fn = torch._dynamo.optimize(cnts)(fn)
2696
2697        # torch does not have the `uint16` dtype
2698        for x in [np.array([42], dtype=np.uint16), np.uint16(42), np.dtype("uint16")]:
2699            r = opt_fn(x)
2700
2701            self.assertEqual(r, False)
2702            self.assertEqual(cnts.frame_count, 0)  # graph break
2703
2704    def test_numpy_iter(self):
2705        # test that iteration over an ndarray produces ndarrays not bare tensors
2706        def fn(x):
2707            return [bm for bm in x]
2708
2709        cnts = torch._dynamo.testing.CompileCounter()
2710        opt_fn = torch._dynamo.optimize(cnts)(fn)
2711
2712        proba_map = np.arange(3)[:, None]
2713        res = opt_fn(proba_map)
2714
2715        self.assertEqual([type(r) for r in res], [np.ndarray, np.ndarray, np.ndarray])
2716        self.assertEqual(res, [np.array([0]), np.array([1]), np.array([2])])
2717        self.assertEqual(cnts.frame_count, 1)
2718
2719    # cache size limit needs to be larger than the `dtypes` list size
2720    @torch._dynamo.config.patch(cache_size_limit=12)
2721    def test_dtypes_no_graphbreaks(self):
2722        dtypes = [
2723            # floats
2724            float,
2725            np.float64,
2726            "float64",
2727            np.float32,
2728            "float32",
2729            # np.dtype('float64')   # XXX: this is not supported, yet
2730            # integers
2731            int,
2732            "int",
2733            np.intp,
2734            np.int32,
2735            np.uint8
2736            # np.dtype('int')       # XXX: as above
2737        ]
2738
2739        def fn(dt):
2740            return np.arange(5, dtype=dt)
2741
2742        for dtyp in dtypes:
2743            cnts = torch._dynamo.testing.CompileCounter()
2744            opt_fn = torch._dynamo.optimize(cnts)(fn)
2745
2746            val = fn(dtyp)
2747            opt_val = opt_fn(dtyp)
2748
2749            self.assertEqual(cnts.frame_count, 1)  # no graph break
2750
2751    # setting the config value makes the PRNG identical to numpy's
2752    # NB this may involve a graph break
2753    @torch._dynamo.config.patch(use_numpy_random_stream=True)
2754    def test_numpy_random_config_to_numpy(self):
2755        @torch.compile
2756        def fn():
2757            return np.random.uniform(size=13)
2758
2759        self.assertEqual(fn().shape, (13,))
2760
2761    def test_inplace_view_on_graph_input(self):
2762        # graph break when calling methods with inplace_view tag on graph input
2763        func_args_map = {
2764            lambda x: x.resize_(6).mul_(2): torch.ones(4),
2765            lambda x: x.t_().mul_(2): torch.rand(2, 3),
2766            lambda x: x.transpose_(0, 1).mul_(2): torch.rand(2, 3),
2767            lambda x: x.squeeze_().mul_(2): torch.rand(1, 2, 3),
2768            lambda x: x.unsqueeze_(0).mul_(2): torch.rand(2, 3),
2769            lambda x: x.resize_as_(torch.rand(200, 300)): torch.rand(2, 3),
2770            lambda x: x.swapaxes_(0, 1).mul_(2): torch.rand(2, 3),
2771            lambda x: x.swapdims_(0, 1).mul_(2): torch.rand(2, 3),
2772            lambda x: x.rename_("N", "C").mul_(2): torch.zeros(2, 3),
2773            lambda x: x.as_strided_((3, 2), (2, 1)).mul_(2): torch.zeros(2, 3),
2774            lambda x: x.detach_().mul_(2): torch.zeros(2, 3),
2775        }
2776        for func, args in func_args_map.items():
2777            args_clone = args.clone()
2778            cnts = torch._dynamo.testing.CompileCounter()
2779            opt_f = torch._dynamo.optimize(cnts)(func)
2780            self.assertTrue(same(func(args).shape, opt_f(args_clone).shape))
2781            self.assertEqual(cnts.frame_count, 1)
2782            self.assertEqual(cnts.op_count, 1)  # mul_
2783
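    # The out= tensor has the wrong shape, so the compiled function must
    # handle the resize and still match the eager result.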
2784    def test_out_variants_with_resizing_on_graph_inputs(self):
2785        def fn(x, y):
2786            return torch.cosh(x, out=y) + 1
2787
2788        x = torch.rand(2, 3)
2789        y = torch.rand(4)
2790
2791        cnts = torch._dynamo.testing.CompileCounter()
2792        opt_fn = torch.compile(fn, backend=cnts)
2793        self.assertTrue(same(fn(x, y), opt_fn(x.clone(), y.clone())))
2794        self.assertEqual(cnts.frame_count, 1)
2795
2796    def test_out_variants_with_resizing_on_graph_inputs_with_dynamic(self):
2797        # https://github.com/pytorch/pytorch/issues/120482
2798        class CustomModel(torch.nn.Module):
2799            def __init__(self):
2800                super().__init__()
2801
2802            def forward(self, inputs):
2803                return torch.outer(**inputs)
2804
2805        compile_fn = torch.compile(CustomModel(), fullgraph=True)
2806
2807        shapes = [(2, 1), (6, 1), (4, 1)]
2808        for shape in shapes:
2809            vec1, vec2 = shape
2810            input_tensor1 = torch.randn(vec1)
2811            input_tensor2 = torch.randn(vec2)
2812            out_tensor = torch.empty(shape)
2813            args = {"input": input_tensor1, "vec2": input_tensor2, "out": out_tensor}
2814            res = compile_fn(args)
2815            opt_res = res.clone()  # clone because the eager call below mutates the same out tensor
2816            res = CustomModel()(args)
2817            self.assertEqual(res, opt_res)
2818
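    # In-place mutation of the input dict (adding "c", popping "b") must be
    # replayed on the very same dict object that was passed in.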
2819    def test_dict_mutation_side_effect(self):
2820        def fn(d):
2821            d["c"] = d["a"] + d.pop("b")
2822            return d
2823
2824        args1 = {"a": torch.randn(10), "b": torch.randn(10)}
2825        args2 = dict(args1)
2826        assert fn(args1) is args1
2827        cnts = torch._dynamo.testing.CompileCounter()
2828        opt_fn = torch._dynamo.optimize(cnts)(fn)
2829        self.assertIs(opt_fn(args2), args2)
2830        self.assertTrue(same(args1, args2))
2831        self.assertEqual(cnts.frame_count, 1)
2832        self.assertEqual(cnts.op_count, 1)
2833
2834    def test_dict_order_keys(self):
2835        def fn(d):
2836            c = 0
2837            for v in d.values():
2838                c += v
2839            return c
2840
2841        args1 = {}
2842        args1["a"] = torch.rand(10)
2843        args1["b"] = torch.rand(10)
2844        cnts = torch._dynamo.testing.CompileCounter()
2845        opt_fn = torch._dynamo.optimize(cnts)(fn)
2846        self.assertEqual(fn(args1), opt_fn(args1))
2847        self.assertEqual(cnts.frame_count, 1)
2848        self.assertEqual(cnts.op_count, 2)
2849
2850        # A different order of keys recompiles
2851        args2 = {}
2852        args2["b"] = args1["b"]
2853        args2["a"] = args1["a"]
2854        self.assertEqual(fn(args2), opt_fn(args2))
2855        self.assertEqual(cnts.frame_count, 2)
2856        # Extra calls with an already-seen key order don't recompile
        opt_fn(args2)
2857        self.assertEqual(cnts.frame_count, 2)
2858
2859    def test_dict_namedtuple(self):
2860        def fn(d):
2861            return d[3] * 2
2862
2863        args1 = {collections.namedtuple: None, 3: torch.randn(3)}
2864        cnts = torch._dynamo.testing.CompileCounter()
2865        opt_fn = torch._dynamo.optimize(cnts)(fn)
2866        self.assertEqual(fn(args1), opt_fn(args1))
2867        self.assertEqual(cnts.frame_count, 1)
2868        # Test a failing namedtuple guard
2869        args2 = {2: None, 3: torch.randn(3)}
2870        self.assertEqual(fn(args2), opt_fn(args2))
2871        self.assertEqual(cnts.frame_count, 2)
2872
2873    def test_dict_order_keys_tensors(self):
2874        def fn(d, x):
2875            return d[x] + 3
2876
2877        args1 = {}
2878        x = torch.randn(10)
2879        y = torch.randn(10)
2880        z = torch.randn(10)
2881        args1[x] = y
2882        args1[3] = z
2883
2884        cnts = torch._dynamo.testing.CompileCounter()
2885        opt_fn = torch._dynamo.optimize(cnts)(fn)
2886        self.assertEqual(fn(args1, x), opt_fn(args1, x))
2887        self.assertEqual(cnts.frame_count, 1)
2888
2889        # Calling again doesn't recompile (same id and key order)
2890        opt_fn(args1, x)
2891        self.assertEqual(cnts.frame_count, 1)
2892        args2 = {}
2893        args2[3] = z
2894        args2[x] = y
2895
2896        # Different order recompiles
2897        self.assertEqual(fn(args2, x), opt_fn(args2, x))
2898        self.assertEqual(cnts.frame_count, 2)
2899
2900    def test_dict_order_keys_modules(self):
2901        def fn(d, x):
2902            return d[x](torch.ones(2, 2))
2903
2904        args1 = {}
2905        x = torch.nn.Linear(2, 2)
2906        y = torch.nn.Linear(2, 2)
2907        z = torch.nn.Linear(2, 2)
2908        args1[x] = y
2909        args1[3] = z
2910
2911        cnts = torch._dynamo.testing.CompileCounter()
2912        opt_fn = torch._dynamo.optimize(cnts)(fn)
2913        self.assertEqual(fn(args1, x), opt_fn(args1, x))
2914        self.assertEqual(cnts.frame_count, 1)
2915
2916        # Calling again doesn't recompile (same id and key order)
2917        opt_fn(args1, x)
2918        self.assertEqual(cnts.frame_count, 1)
2919        args2 = {}
2920        args2[3] = z
2921        args2[x] = y
2922
2923        # Different order recompiles
2924        self.assertEqual(fn(args2, x), opt_fn(args2, x))
2925        self.assertEqual(cnts.frame_count, 2)
2926
2927    def test_dunder_new_function_inlining(self):
2928        # https://github.com/pytorch/pytorch/issues/107460
2929
2930        counters.clear()
2931
2932        class ModelA(torch.nn.Module):
2933            def __init__(self):
2934                super().__init__()
2935
2936            def forward(self, x):
2937                return torch.tanh(x + 1)
2938
2939        class ModelB(torch.nn.Module):
2940            def __new__(cls):
2941                return ModelA()
2942
2943        class Model(torch.nn.Module):
2944            def __init__(self):
2945                super().__init__()
2946                self.layer = torch.nn.Linear(2, 2)
2947
2948            def forward(self, x):
2949                other = ModelB()
2950                return self.layer(x) + other(x)
2951
2952        x = torch.rand(2, 2)
2953        m = Model()
2954
2955        opt_m = torch.compile(backend="eager")(m)
2956        ref = m(x)
2957        res = opt_m(x)
2958        self.assertTrue(same(ref, res))
2959        self.assertEqual(len(counters["graph_break"]), 1)
2960        self.assertFalse("super() nn.Module.__init__" in counters["graph_break"])
2961
2962    def test_class_dunder_mro(self):
2963        class ModuleA(torch.nn.Module):
2964            pass
2965
2966        class ModuleB(ModuleA):
2967            pass
2968
2969        def fn(x, mod):
2970            if ModuleA in type(mod).__mro__:
2971                return x + 1
2972            else:
2973                return x - 1
2974
2975        x = torch.rand(2, 3)
2976        mod = ModuleB()
2977        opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
2978        ref = fn(x, mod)
2979        res = opt_fn(x, mod)
2980        self.assertTrue(same(ref, res))
2981
2982    def test_nested_wraps(self):
2983        def foo(x, y):
2984            def add(x, y):
2985                return x + y
2986
2987            @functools.wraps(add)
2988            def wrapped_call(x, y):
2989                return add(x, y)
2990
2991            return wrapped_call(x, y)
2992
2993        x = torch.randn(3, 3)
2994        y = torch.randn(3, 3)
2995
2996        o = torch.compile(foo, fullgraph=True, backend="eager")(x, y)
2997        self.assertEqual(o, x + y)
2998
2999        def foo(x, y):
3000            def nested_call(x, y):
3001                def mul(x, y):
3002                    return x * y
3003
3004                @functools.wraps(mul)
3005                def double_nested_call(x, y):
3006                    return mul(x, y)
3007
3008                return double_nested_call(x, y)
3009
3010            return nested_call(x, y)
3011
3012        o = torch.compile(foo, fullgraph=True, backend="eager")(x, y)
3013        self.assertEqual(o, x * y)
3014
3015    def test_module_deepcopy(self):
3016        m1 = torch.nn.Sequential(
3017            torch.nn.Linear(10, 10),
3018            torch.nn.ReLU(),
3019            torch.nn.Linear(10, 10),
3020            torch.nn.ReLU(),
3021        )
3022        m2 = torch.nn.Sequential(
3023            torch.nn.Linear(10, 10),
3024            torch.nn.ReLU(),
3025            torch.nn.Linear(10, 10),
3026            torch.nn.ReLU(),
3027        )
3028
3029        def fn(m, x):
3030            m_copy = copy.deepcopy(m)
3031            return m_copy(x)
3032
3033        v = torch.randn(10)
3034        correct1 = fn(m1, v)
3035        correct2 = fn(m2, v)
3036        cnts = torch._dynamo.testing.CompileCounter()
3037        opt_fn = torch._dynamo.optimize(cnts)(fn)
3038        for _ in range(10):
3039            self.assertTrue(same(opt_fn(m1, v), correct1))
3040        for _ in range(10):
3041            self.assertTrue(same(opt_fn(m2, v), correct2))
3042        self.assertEqual(cnts.frame_count, 1)
3043        self.assertEqual(cnts.op_count, 4)
3044
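    # type(seq)(...) should rebuild the same container type as the input, so a
    # list input yields a list and a tuple input yields a tuple (one
    # compilation each).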
3045    def test_type_copy(self):
3046        def fn(seq):
3047            a, b = seq
3048            return type(seq)([a + 1, b + 2, a + b])
3049
3050        args1 = [torch.randn(10), torch.randn(10)]
3051        args2 = (torch.randn(10), torch.randn(10))
3052        correct1 = fn(args1)
3053        correct2 = fn(args2)
3054        cnts = torch._dynamo.testing.CompileCounter()
3055        opt_fn = torch._dynamo.optimize(cnts)(fn)
3056        self.assertTrue(same(opt_fn(args1), correct1))
3057        self.assertTrue(same(opt_fn(args2), correct2))
3058        self.assertIsInstance(opt_fn(args1), list)
3059        self.assertIsInstance(opt_fn(args2), tuple)
3060        self.assertEqual(cnts.frame_count, 2)
3061        self.assertEqual(cnts.op_count, 6)
3062
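    # Attribute mutations on a pre-existing object must be applied to that
    # same object after the compiled call returns it.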
3063    def test_setattr_mutation1(self):
3064        class MyObj:  # noqa: B903
3065            def __init__(self, a, b):
3066                self.a = a
3067                self.b = b
3068
3069        def fn(obj):
3070            obj.c = obj.a * obj.b + 1
3071            obj.b = obj.a * obj.c + 2
3072            obj.a = obj.b * obj.c + 3
3073            obj.c = obj.a * obj.b + 4
3074            obj.b = obj.a * obj.c + 5
3075            obj.a = obj.b * obj.c + 6
3076            return obj
3077
3078        x1 = torch.randn(10)
3079        x2 = torch.randn(10)
3080        obj1 = MyObj(x1, x2)
3081        obj2 = MyObj(x1, x2)
3082        fn(obj2)
3083        cnts = torch._dynamo.testing.CompileCounter()
3084        opt_fn = torch._dynamo.optimize(cnts)(fn)
3085        self.assertIs(opt_fn(obj1), obj1)
3086        self.assertTrue(same(obj1.a, obj2.a))
3087        self.assertTrue(same(obj1.b, obj2.b))
3088        self.assertTrue(same(obj1.c, obj2.c))
3089        self.assertEqual(cnts.frame_count, 1)
3090        self.assertEqual(cnts.op_count, 12)
3091
3092    def test_setattr_mutation2(self):
3093        class MyObj:
3094            def __init__(self, x):
3095                self.a = x + 1
3096                self.b = x + 2
3097
3098        def fn(x):
3099            x = x / 3.0
3100            obj = MyObj(x)
3101            obj.c = obj.a * obj.b + 1
3102            obj.b = obj.a * obj.c + 2
3103            obj.a = obj.b * obj.c + 3
3104            return obj
3105
3106        x1 = torch.randn(10)
3107        obj2 = fn(x1)
3108
3109        cnts = torch._dynamo.testing.CompileCounter()
3110        opt_fn = torch._dynamo.optimize(cnts)(fn)
3111        obj1 = opt_fn(x1)
3112        self.assertTrue(same(obj1.a, obj2.a))
3113        self.assertTrue(same(obj1.b, obj2.b))
3114        self.assertTrue(same(obj1.c, obj2.c))
3115        self.assertEqual(cnts.frame_count, 1)
3116        self.assertEqual(cnts.op_count, 9)
3117
3118    def test_setattr_mutation3(self):
3119        # TODO(jansel): dead code eliminate the object creation
3120        class MyObj:
3121            def __init__(self, x):
3122                super().__init__()
3123                self.a = x + 1
3124                self.b = x + 2
3125
3126        def fn(x):
3127            x = x / 3.0
3128            obj = MyObj(x)
3129            obj.c = obj.a * obj.b + 1
3130            obj.b = obj.a * obj.c + 2
3131            obj.a = obj.b * obj.c + 3
3132            return obj.a, obj.b, obj.c
3133
3134        x1 = torch.randn(10)
3135        obj2 = fn(x1)
3136
3137        cnts = torch._dynamo.testing.CompileCounter()
3138        opt_fn = torch._dynamo.optimize(cnts)(fn)
3139        obj1 = opt_fn(x1)
3140        self.assertTrue(same(obj1, obj2))
3141        self.assertEqual(cnts.frame_count, 1)
3142        self.assertEqual(cnts.op_count, 9)
3143
3144    def test_object_setattr(self):
3145        @dataclasses.dataclass
3146        class A:
3147            x: torch.Tensor
3148
3149        def fn1(x) -> A:
3150            a = A(x)
3151            object.__setattr__(a, "x", x + 2)
3152            return a
3153
3154        x1 = torch.randn(10)
3155        obj11 = fn1(x1.clone())
3156
3157        cnts = torch._dynamo.testing.CompileCounter()
3158        opt_fn1 = torch._dynamo.optimize(cnts, nopython=True)(fn1)
3159        obj12 = opt_fn1(x1.clone())
3160        self.assertTrue(same(obj11.x, x1 + 2))
3161        self.assertTrue(same(obj12.x, x1 + 2))
3162        self.assertTrue(same(obj11.x, obj12.x))
3163        self.assertEqual(cnts.frame_count, 1)
3164
3165        @dataclasses.dataclass(frozen=True)
3166        class B:
3167            x: torch.Tensor
3168
3169        def fn2(x):
3170            b = B(x)
3171            return b
3172
3173        x2 = torch.randn(10)
3174        obj21 = fn2(x2.clone())
3175
3176        cnts = torch._dynamo.testing.CompileCounter()
3177        opt_fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2)
3178        obj22 = opt_fn2(x2.clone())
3179        self.assertTrue(same(obj21.x, x2))
3180        self.assertTrue(same(obj22.x, x2))
3181        self.assertTrue(same(obj21.x, obj22.x))
3182        self.assertEqual(cnts.frame_count, 0)
3183
3184        @dataclasses.dataclass(frozen=True)
3185        class C:
3186            x: torch.Tensor
3187
3188        def fn3(x):
3189            c = C(x)
3190            object.__setattr__(c, "x", x + 2)
3191            return c
3192
3193        x3 = torch.randn(10)
3194        obj31 = fn3(x3.clone())
3195
3196        cnts = torch._dynamo.testing.CompileCounter()
3197        opt_fn3 = torch._dynamo.optimize(cnts, nopython=True)(fn3)
3198        obj32 = opt_fn3(x3.clone())
3199        self.assertTrue(same(obj31.x, x3 + 2))
3200        self.assertTrue(same(obj32.x, x3 + 2))
3201        self.assertTrue(same(obj31.x, obj32.x))
3202        self.assertEqual(cnts.frame_count, 1)
3203
3204        @dataclasses.dataclass(frozen=True)
3205        class D:
3206            x: torch.Tensor
3207
3208            def __post_init__(self):
3209                object.__setattr__(self, "y", self.x + 2)
3210
3211        def fn4(x):
3212            d = D(x)
3213            return d
3214
3215        x4 = torch.randn(10)
3216        obj41 = fn4(x4.clone())
3217
3218        cnts = torch._dynamo.testing.CompileCounter()
3219        opt_fn4 = torch._dynamo.optimize(cnts, nopython=True)(fn4)
3220        obj42 = opt_fn4(x4.clone())
3221        self.assertTrue(same(obj41.x, x4))
3222        self.assertTrue(same(obj42.x, x4))
3223        self.assertTrue(same(obj41.x, obj42.x))
3224        self.assertTrue(same(obj41.y, x4 + 2))
3225        self.assertTrue(same(obj42.y, x4 + 2))
3226        self.assertTrue(same(obj41.y, obj42.y))
3227        self.assertEqual(cnts.frame_count, 1)
3228
3229    def test_user_defined_class_name(self):
3230        class MyClassFoo:
3231            pass
3232
3233        def fn1(a, b, c):
3234            tmp = MyClassFoo()
3235            if tmp.__class__.__name__ == "MyClassFoo":
3236                return a - b / c
3237
3238        torch._dynamo.testing.standard_test(self, fn=fn1, nargs=3)
3239
3240    def test_user_defined_class_python_type(self):
3241        class MyClass1:
3242            pass
3243
3244        class ExampleMeta(type):
3245            pass
3246
3247        class MyClass2(metaclass=ExampleMeta):
3248            pass
3249
3250        def fn(x, c):
3251            if isinstance(c, MyClass1):
3252                return x + 1
3253            elif isinstance(c, MyClass2):
3254                return x + 2
3255            else:
3256                return x + 3
3257
3258        x = torch.rand(3)
3259        opt_fn = torch._dynamo.optimize("eager")(fn)
3260        for c in [MyClass1, MyClass2]:
3261            ref = fn(x, c)
3262            res = opt_fn(x, c)
3263            self.assertTrue(same(ref, res))
3264
3265    def test_super_calling_with_metaclass(self):
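        # super() lookups through a class hierarchy that uses a metaclass,
        # including a graph break inside the classmethod, should still match
        # eager results.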
3266        class ExampleMeta(type):
3267            pass
3268
3269        class MyClass1(metaclass=ExampleMeta):
3270            coeff = 4  # Force the constant guard to test source in guards
3271
3272            @classmethod
3273            def add(cls, x):
3274                return x + 1
3275
3276        class MyClass2(MyClass1):
3277            @classmethod
3278            def add(cls, x):
3279                torch._dynamo.graph_break()
3280                return x + super().add(x) + super().coeff
3281
3282        def fn(x, obj):
3283            return x + obj.add(x)
3284
3285        x = torch.rand(3)
3286        obj = MyClass2()
3287        opt_fn = torch._dynamo.optimize("eager")(fn)
3288        ref = fn(x, obj)
3289        res = opt_fn(x, obj)
3290        self.assertTrue(same(ref, res))
3291
3292    def test_usr_cls_staticmethod(self):
3293        class Foo:
3294            @staticmethod
3295            def bar(a, b):
3296                return a + b
3297
3298        def fn(a, b):
3299            return Foo.bar(a, b) - 1
3300
3301        torch._dynamo.testing.standard_test(self, fn=fn, nargs=2)
3302
3303    def test_usr_cls_classmethod(self):
3304        class Foo:
3305            @classmethod
3306            def bar(cls, a, b):
3307                return a + b
3308
3309        def fn(a, b):
3310            return Foo.bar(a, b) - 1
3311
3312        torch._dynamo.testing.standard_test(self, fn=fn, nargs=2)
3313
3314    def test_dunder_methods(self):
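        # Arithmetic dunder methods on a user-defined class should be inlined;
        # expected_ops=4 counts the four underlying tensor ops on .val.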
3315        class Foo:
3316            def __init__(self, val):
3317                super().__init__()
3318                self.val = val
3319
3320            def __add__(self, other):
3321                return Foo(self.val + other.val)
3322
3323            def __mul__(self, other):
3324                return Foo(self.val * other.val)
3325
3326            def __truediv__(self, other):
3327                return Foo(self.val / other.val)
3328
3329            def __sub__(self, other):
3330                return Foo(self.val - other.val)
3331
3332        def fn(a, b, c):
3333            return Foo(a) + Foo(b) * Foo(c) / Foo(a) - Foo(b)
3334
3335        torch._dynamo.testing.standard_test(self, fn=fn, nargs=3, expected_ops=4)
3336
3337    def test_function_annotation(self):
3338        class Variable:
3339            pass
3340
3341        def fn(x):
3342            x = x / 3.0
3343
3344            def inner(y: typing.List[Variable]):
3345                return x + 1
3346
3347            return inner
3348
3349        x1 = torch.randn(10)
3350        obj2 = fn(x1)([])
3351
3352        cnts = torch._dynamo.testing.CompileCounter()
3353        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
3354        opt_fn_inner = torch._dynamo.optimize_assert(cnts)(opt_fn(x1))
3355        obj1 = opt_fn_inner([])
3356        self.assertTrue(same(obj1, obj2))
3357        self.assertEqual(cnts.frame_count, 2)
3358        self.assertEqual(cnts.op_count, 2)
3359
3360    def test_nested_closure(self):
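        # Closure cells captured across several levels of nested functions
        # (including a default-argument tensor) should be resolved correctly.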
3361        v0 = torch.randn(10)
3362
3363        def fn1():
3364            v1 = torch.randn(10)
3365
3366            def fn2(*args, **kwargs):
3367                assert len(args) == 1
3368                assert len(kwargs) == 1
3369                v2 = torch.randn(10) + args[0] + kwargs["b"]
3370
3371                def fn3(v3=torch.randn(10)):
3372                    def fn4():
3373                        return v0 + v1 + v2 + v3 + 1
3374
3375                    return fn4
3376
3377                return fn3
3378
3379            return fn2(1, b=2)()
3380
3381        cnts = torch._dynamo.testing.CompileCounter()
3382        opt_fn1 = torch._dynamo.optimize_assert(cnts)(fn1)
3383        tmp1 = torch._dynamo.optimize_assert(cnts)(opt_fn1())
3384        tmp2 = torch._dynamo.optimize_assert(cnts)(opt_fn1())
3385        self.assertEqual(tmp1().shape, (10,))
3386        self.assertTrue(same(tmp1(), tmp1()))
3387        self.assertFalse(same(tmp1(), tmp2()))
3388        self.assertEqual(cnts.frame_count, 2)
3389        self.assertEqual(cnts.op_count, 9)
3390
3391    def test_nested_closure_mutation(self):
3392        def fn1():
3393            v1 = torch.randn(10)
3394
3395            def fn2():
3396                v2 = torch.randn(10)
3397
3398                def fn3():
3399                    nonlocal v1, v2
3400                    v1 += 1
3401                    v2 += 2
3402                    return v1 + v2
3403
3404                return fn3
3405
3406            rv = fn2()
3407            rv()
3408            rv()
3409            return rv
3410
3411        torch.manual_seed(9000)
3412        counter1 = fn1()
3413        result1 = [counter1(), counter1(), counter1()]
3414
3415        torch.manual_seed(9000)
3416        cnts = torch._dynamo.testing.CompileCounter()
3417        opt_fn1 = torch._dynamo.optimize_assert(cnts)(fn1)
3418        counter2 = torch._dynamo.optimize_assert(cnts)(opt_fn1())
3419        result2 = [counter2(), counter2(), counter2()]
3420        result1.append(counter1())
3421        result2.append(counter2())
3422
3423        self.assertTrue(same(result1, result2))
3424        self.assertEqual(cnts.frame_count, 2)
3425        self.assertEqual(cnts.op_count, 11)
3426
3427    def test_write_to_closures_in_inlining(self):
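        # Writes to a nonlocal cell inside an inlined function must update the
        # real cell, so later eager calls to counter() observe the new state.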
3428        out = []
3429        for use_dynamo in [False, True]:
3430
3431            def make_counter():
3432                x = torch.randn(10)
3433
3434                def counter():
3435                    nonlocal x
3436                    x = x + 1
3437                    return x
3438
3439                return counter
3440
3441            torch.manual_seed(0)
3442            counter = make_counter()
3443            if not use_dynamo:
3444                out.append(counter() + counter())
3445            else:
3446                cnts = torch._dynamo.testing.CompileCounter()
3447
3448                @torch._dynamo.optimize(cnts, nopython=True)
3449                def fn(counter):
3450                    return counter() + counter()
3451
3452                out.append(fn(counter))
3453                self.assertEqual(cnts.frame_count, 1)
3454                self.assertEqual(cnts.op_count, 3)
3455                self.assertFalse(same(counter() + counter(), out[-1]))
3456
3457        self.assertTrue(same(out[0], out[1]))
3458
3459    def test_closure_out_of_scope_cell(self):
3460        cell1 = torch.rand(1).item()
3461        cell2 = torch.rand(3, 3)
3462
3463        def indirect():
3464            return direct()
3465
3466        def direct():
3467            def inner():
3468                return cell1 + 1, cell2 + 3
3469
3470            return inner()
3471
3472        cnts = torch._dynamo.testing.CompileCounter()
3473        opt_fn = torch._dynamo.optimize(cnts)(indirect)
3474        result1, result2 = opt_fn()
3475        self.assertAlmostEqual(cell1 + 1, result1)
3476        self.assertTrue(torch.allclose(cell2 + 3, result2))
3477        self.assertEqual(cnts.frame_count, 1)
3478        self.assertEqual(cnts.op_count, 1)
3479
3480    def test_closure_out_of_scope_cell_with_mutation(self):
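        # Mutations of cells that live outside the compiled frame must persist
        # between calls; each iteration should see the accumulated updates.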
3481        cell1 = torch.rand(1).item()
3482        orig1 = cell1
3483        cell2 = torch.rand(3, 3)
3484        orig2 = cell2.clone()
3485
3486        def indirect():
3487            return direct()
3488
3489        def direct():
3490            def inner():
3491                nonlocal cell1, cell2
3492                x = cell2 + 1
3493                cell1 += 1
3494                cell2 += 10
3495                x = x + cell2
3496                return cell1, cell2, x
3497
3498            return inner()
3499
3500        cnts = torch._dynamo.testing.CompileCounter()
3501        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(indirect)
3502        for i in range(1, 4):
3503            result1, result2, _ = opt_fn()
3504            self.assertAlmostEqual(orig1 + 1 * i, result1)
3505            self.assertTrue(torch.allclose(orig2 + 10 * i, result2))
3506            self.assertEqual(cnts.frame_count, 1)
3507            self.assertEqual(cnts.op_count, 3)
3508            cnts.clear()
3509
3510    def test_closure_with_mutation_and_graph_break(self):
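        # The tensor comparison triggers a graph break before `backup` is
        # bound; the closure write in subfunc must still work across the
        # resulting two frames.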
3511        def fn():
3512            x = torch.zeros(1)
3513
3514            def subfunc():
3515                x[0] = backup
3516
3517            if x[0] >= -1e5:
3518                pass
3519
3520            backup = 1
3521            subfunc()
3522            return x
3523
3524        cnts = torch._dynamo.testing.CompileCounter()
3525        opt_fn = torch._dynamo.optimize(cnts)(fn)
3526        expected = fn()
3527        actual = opt_fn()
3528        self.assertTrue(same(expected, actual))
3529        self.assertEqual(cnts.frame_count, 2)
3530
3531    def test_closure_out_of_scope_cell_with_cond(self):
3532        # Test closure with out-of-scope cell variable, used in a cond
3533        # where the two branches read different closure variables
3534        from functorch.experimental.control_flow import cond
3535
3536        def g(x):
3537            return x
3538
3539        class ModuleCondDeep(torch.nn.Module):
3540            def forward(self, pred, x):
3541                return self._indirection(pred, x)
3542
3543            def _indirection(self, pred, x):
3544                return self.indirection(pred, x)
3545
3546            def indirection(self, pred, x):
3547                def true_fn(y):
3548                    return y + 2
3549
3550                def false_fn(y):
3551                    return y - 2
3552
3553                def shallow(x):
3554                    return x * 2
3555
3556                def deep(x):
3557                    # y = g(x)
3558                    y = x
3559                    return cond(
3560                        x[0][0] > 0,
3561                        true_fn,
3562                        false_fn,
3563                        [y],
3564                    )
3565
3566                return cond(pred, shallow, deep, [x])
3567
3568        mod = ModuleCondDeep()
3569        opt_mod = torch._dynamo.optimize("eager")(mod)
3570        inp = torch.randn(3, 3)
3571        exp1 = mod(torch.tensor(False), inp)
3572        actual1 = opt_mod(torch.tensor(False), inp)
3573        exp2 = mod(torch.tensor(True), inp)
3574        actual2 = opt_mod(torch.tensor(True), inp)
3575        self.assertTrue(torch.allclose(exp1, actual1))
3576        self.assertTrue(torch.allclose(exp2, actual2))
3577
3578    def test_top_package_import(self):
3579        def fn(x):
3580            import torch.fx
3581
3582            assert not isinstance(x, torch.fx.Proxy)
3583            return torch.sin(x)
3584
3585        x = torch.randn(4, 5)
3586        ref = fn(x)
3587        cnts = torch._dynamo.testing.CompileCounter()
3588        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
3589        res = opt_fn(x)
3590        self.assertTrue(same(ref, res))
3591
3592    def test_typing_typevar(self):
3593        def fn(x):
3594            def sumt(y: torch.Tensor) -> torch.Tensor:
3595                return torch.sum(y)
3596
3597            def foo(c: typing.Callable[[T], T], y: T) -> T:
3598                return c(y)
3599
3600            return foo(sumt, x)
3601
3602        x = torch.randn(3)
3603        ref = fn(x)
3604        cnts = torch._dynamo.testing.CompileCounter()
3605        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
3606        res = opt_fn(x)
3607        self.assertTrue(same(ref, res))
3608        self.assertEqual(cnts.frame_count, 1)
3609
3610    def test_typing_union_and_optional(self):
3611        def fn(x):
3612            a = torch.jit.annotate(typing.Dict[str, typing.Optional[torch.Tensor]], {})
3613            b = torch.jit.annotate(
3614                typing.Dict[str, typing.Union[torch.Tensor, None]], {}
3615            )
3616            return a, b, x + 1
3617
3618        x = torch.randn(3)
3619        ref = fn(x)
3620        opt_fn = torch._dynamo.optimize("eager", nopython=False)(fn)
3621        res = opt_fn(x)
3622        self.assertTrue(same(ref, res))
3623
3624    def test_optimize_on_module(self):
3625        class MockModule(torch.nn.Module):
3626            def __init__(self):
3627                super().__init__()
3628                self.relu = torch.nn.ReLU()
3629
3630            def custom_member(self):
3631                # Just for checking that the Dynamo-returned mod object can
3632                # redirect calls to this method
3633                pass
3634
3635            def forward(self, x):
3636                return self.relu(x)
3637
3638        cnts1 = torch._dynamo.testing.CompileCounter()
3639        mod = MockModule()
3640        optimized_mod = torch._dynamo.optimize(cnts1, nopython=True)(mod)
3641
3642        a = torch.randn(10)
3643        ref = mod(a)
3644        res = optimized_mod(a)
3645
3646        optimized_mod.custom_member()
3647
3648        self.assertTrue(same(ref, res))
3649
3650    def test_nested_optimize_decorator(self):
3651        cnts2 = torch._dynamo.testing.CompileCounter()
3652        cnts3 = torch._dynamo.testing.CompileCounter()
3653
3654        @torch._dynamo.run()
3655        def fn1(x):
3656            return torch.sin(x) * 10
3657
3658        @torch._dynamo.optimize(cnts2, nopython=True)
3659        def fn2(x):
3660            return fn1(x) + 1
3661
3662        @torch._dynamo.optimize(cnts3, nopython=True)
3663        def fn3(x):
3664            return torch.relu(fn2(x))
3665
3666        fn3(torch.randn(4, 5))
3667        self.assertEqual(cnts2.frame_count, 0)
3668        self.assertEqual(cnts3.frame_count, 1)
3669        self.assertEqual(cnts3.op_count, 4)
3670
3671    def test_nested_optimize_run(self):
3672        cnts = torch._dynamo.testing.CompileCounter()
3673
3674        @torch._dynamo.optimize(cnts, nopython=True)
3675        def fn(x):
3676            return torch.relu(torch.cos(x) + torch.sin(x))
3677
3678        fn(torch.randn(4))
3679        self.assertEqual(cnts.frame_count, 1)
3680
3681        fn(torch.randn(4, 4))
3682        self.assertEqual(cnts.frame_count, 2)
3683
3684        # Test that run works on a decorated fn
3685        fn = torch._dynamo.run(fn)
3686        fn(torch.randn(4, 4, 4))
3687        self.assertEqual(cnts.frame_count, 2)
3688
3689    def test_nested_optimize(self):
3690        cnts1 = torch._dynamo.testing.CompileCounter()
3691        cnts2 = torch._dynamo.testing.CompileCounter()
3692
3693        def fn(x):
3694            return torch.relu(torch.cos(x) + torch.sin(x))
3695
3696        fn1 = torch._dynamo.optimize(cnts1, nopython=True)(fn)
3697        fn2 = torch._dynamo.optimize(cnts2, nopython=True)(fn1)
3698
3699        # The first optimize in the nesting should be ignored
3700        fn2(torch.randn(4))
3701        self.assertEqual(cnts2.frame_count, 1)
3702        self.assertEqual(cnts1.frame_count, 0)
3703
3704        # Since the fn code object is already compiled, calling fn1 should
3705        # directly call the compiled_fn callable.
3706        torch._dynamo.run()(fn1)(torch.randn(4))
3707        self.assertEqual(cnts1.frame_count, 0)
3708
3709        # Test same behavior by reversing the calls
3710        torch._dynamo.reset()
3711        cnts1 = torch._dynamo.testing.CompileCounter()
3712        cnts2 = torch._dynamo.testing.CompileCounter()
3713        fn1 = torch._dynamo.optimize(cnts1, nopython=True)(fn)
3714        fn2 = torch._dynamo.optimize(cnts2, nopython=True)(fn1)
3715        fn1(torch.randn(4))
3716        self.assertEqual(cnts1.frame_count, 1)
3717        torch._dynamo.run()(fn2)(torch.randn(4))
3718        self.assertEqual(cnts2.frame_count, 0)
3719
3720    def test_torch_size(self):
3721        cnts = torch._dynamo.testing.CompileCounter()
3722
3723        def fn(x):
3724            output_size = torch.Size([10, 10])
3725            x = x.view(*output_size)
3726            return (x,)
3727
3728        x = torch.randn(100, requires_grad=True)
3729        x_clone = x.clone()
3730        ref = fn(x)
3731
3732        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
3733        res = opt_fn(x_clone)
3734
3735        self.assertTrue(same(ref, res))
3736
3737    def test_torch_size_numel(self):
3738        cnts = torch._dynamo.testing.CompileCounter()
3739
3740        def fn():
3741            return torch.Size([10, 8]).numel()
3742
3743        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
3744        num = torch.Size([10, 8]).numel()
3745        self.assertEqual(opt_fn(), num)
3746
3747    def test_torch_size_numel_dynamic(self):
3748        cnts = torch._dynamo.testing.CompileCounter()
3749
3750        def fn(x):
3751            return x.size().numel()
3752
3753        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
3754        x = torch.rand(10, 1, 8, 1)
3755        expect = fn(x)
3756        self.assertEqual(opt_fn(x), expect)
3757
3758    def test_shape_type(self):
3759        cnts = torch._dynamo.testing.CompileCounter()
3760
3761        def fn(x):
3762            return x + (type(x.shape) == torch.Size)
3763
3764        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
3765        x = torch.zeros(())
3766        self.assertEqual(opt_fn(x), fn(x))
3767
3768    def test_size_dim(self):
3769        cnts = torch._dynamo.testing.CompileCounter()
3770
3771        def fn(x, dim):
3772            return x.size(dim=dim)
3773
3774        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
3775        x = torch.empty([4, 9, 8])
3776        self.assertEqual(opt_fn(x, 1), 9)
3777        self.assertEqual(opt_fn(x, -2), 9)
3778
3779    def test_stride_dim(self):
3780        cnts = torch._dynamo.testing.CompileCounter()
3781
3782        def fn(x, dim):
3783            return x.stride(dim=dim)
3784
3785        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
3786        x = torch.empty([4, 9, 8])
3787        self.assertEqual(opt_fn(x, 0), 72)
3788        self.assertEqual(opt_fn(x, -2), 8)
3789
3790    def test_torch_seed(self):
3791        from torch._dynamo.utils import counters
3792
3793        cnts = torch._dynamo.testing.CompileCounter()
3794        counters.clear()
3795
3796        def fn(x):
3797            attention_seed = int(torch.seed() % sys.maxsize)
3798            torch.manual_seed(attention_seed)
3799            return (x,)
3800
3801        x = torch.randn(10, requires_grad=True)
3802        ref = fn(x)
3803
3804        # Python code is needed here, since torch.manual_seed graph-breaks.
3805        # Refs: https://github.com/pytorch/pytorch/issues/107187
3806        opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn)
3807        res = opt_fn(x)
3808
3809        self.assertTrue(same(ref, res))
3810        # Only the torch.seed call is turned into an FX graph.
3811        self.assertEqual(cnts.op_count, 1)
3812        self.assertEqual(cnts.frame_count, 1)
3813        # Graph breaks at manual_seed.
3814        self.assertEqual(len(counters["graph_break"]), 1)
3815
3816    def test_is_tensor_like(self):
3817        cnts = torch._dynamo.testing.CompileCounter()
3818
3819        def f(x):
3820            if torch.overrides.is_tensor_like(x):
3821                return (x * 2,)
3822            return (torch.ones(10) + x,)
3823
3824        x = torch.randn(10)
3825        ref0 = f(x)
3826        ref1 = f(4)
3827        opt_f = torch._dynamo.optimize(cnts, nopython=True)(f)
3828        res0 = opt_f(x)
3829        res1 = opt_f(4)
3830        self.assertTrue(same(ref0, res0))
3831        self.assertTrue(same(ref1, res1))
3832
3833    def test_is_tensor_like2(self):
3834        class MyTensor:
3835            @classmethod
3836            def __torch_function__(cls, func, types, args=(), kwargs=None):
3837                if kwargs is None:
3838                    kwargs = {}
3839
3840                if func is torch.max:
3841                    return torch.tensor(123)
3842                return func(*args, **kwargs)
3843
3844        def fn(x):
3845            if torch.overrides.is_tensor_like(x):
3846                return torch.max(x)
3847            else:
3848                return torch.zeros(1)
3849
3850        x = MyTensor()
3851        ref0 = fn(x)
3852        ref1 = fn(4)
3853        opt_fn = torch._dynamo.optimize("eager")(fn)
3854        res0 = opt_fn(x)
3855        res1 = opt_fn(4)
3856        self.assertTrue(same(ref0, res0))
3857        self.assertTrue(same(ref1, res1))
3858
3859    def test_tensor_data(self):
3860        def fn(x, y):
3861            return x[y.data]
3862
3863        x = torch.rand(8)
3864        y = torch.ones(8).to(torch.int)
3865        ref = fn(x, y)
3866        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
3867        res = opt_fn(x, y)
3868        self.assertTrue(same(ref, res))
3869
3870    def test_tensor_layout(self):
3871        def fn(x):
3872            return torch.zeros(
3873                [x.size()[0], x.size()[1]],
3874                dtype=x.dtype,
3875                layout=x.layout,
3876                device=x.device,
3877            )
3878
3879        x = torch.rand(2, 3)
3880        ref = fn(x)
3881        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
3882        res = opt_fn(x)
3883        self.assertTrue(same(ref, res))
3884
3885    def test_version_ci(self):
3886        # Temporary test to check that the CI torch version is set correctly.
3887        self.assertTrue(hasattr(torch, "_subclasses"))
3888
3889    @unittest.skipIf(not TEST_CUDA, "requires cuda")
3890    def test_rand(self):
3891        cnts = torch._dynamo.testing.CompileCounter()
3892        device = "cuda"
3893
3894        def fn():
3895            return torch.randn(10, device=device)
3896
3897        torch.manual_seed(10)
3898        ref_run1 = fn()
3899
3900        torch.manual_seed(10)
3901        ref_run2 = fn()
3902        self.assertTrue(same(ref_run1, ref_run2))
3903
3904        torch.manual_seed(10)
3905        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
3906        res = opt_fn()
3907
3908        self.assertTrue(same(res, ref_run1))
3909
3910    def test_slice_input(self):
3911        cnts = torch._dynamo.testing.CompileCounter()
3912
3913        def getitem(a, idx):
3914            if isinstance(idx, slice):
3915                return (
3916                    torch.zeros(1),
3917                    a[idx]
3918                    + [
3919                        100,
3920                    ],
3921                )
3922            else:
3923                return (torch.zeros(1), a[idx])
3924
3925        layers = list(range(10))
3926        ref0 = getitem(layers, slice(0, 2, 1))
3927        ref1 = getitem(layers, 2)
3928        ref2 = getitem(layers, slice(3, 8, 2))
3929        opt_getitem = torch._dynamo.optimize(cnts, nopython=True)(getitem)
3930        res0 = opt_getitem(layers, slice(0, 2, 1))
3931        res1 = opt_getitem(layers, 2)
3932        res2 = opt_getitem(layers, slice(3, 8, 2))
3933
3934        self.assertTrue(ref0 == res0)
3935        self.assertTrue(ref1 == res1)
3936        self.assertTrue(ref2 == res2)
3937
3938    def test_grad(self):
3939        cnts = torch._dynamo.testing.CompileCounter()
3940
3941        def fn(a, b):
3942            out = a * b
3943            out.sum().backward()
3944            real_out = torch.sigmoid(a.grad + b)
3945            return real_out
3946
3947        inps = [torch.randn(4, requires_grad=True) for _ in range(2)]
3948        for inp in inps:
3949            inp.grad = None
3950        ref = fn(*inps)
3951
3952        for inp in inps:
3953            inp.grad = None
3954        opt_fn = torch._dynamo.optimize(cnts)(fn)
3955        res = opt_fn(*inps)
3956
3957        self.assertTrue(same(ref, res))
3958
3959    @torch._dynamo.config.patch(guard_nn_modules=True)
3960    def test_source_non_input_grad_access(self):
3961        # This test creates a model, and accesses the grads
3962        # from its parameter. This means that within dynamo,
3963        # the tensor we are reading the grad from HAS a source,
3964        # but is not known to graphargs.
3965        cnts = torch._dynamo.testing.CompileCounter()
3966
3967        class TrivialModel(torch.nn.Module):
3968            def __init__(self):
3969                super().__init__()
3970                self.linear = torch.nn.Linear(2, 1)
3971
3972            def forward(self, x):
3973                return self.linear(x)
3974
3975        def fn(a, b):
3976            outs = []
3977            for param in model.parameters():
3978                outs.append(torch.ones(param.grad.size()))
3979            return outs, param.grad + 1
3980
3981        model = TrivialModel()
3982        # Eager
3983        a = torch.ones([2, 2], requires_grad=True)
3984        b = torch.ones([2, 2])
3985        out = model(a)
3986        out_sum = out.sum()
3987        out_sum.backward()
3988        ref = fn(a, b)
3989
3990        # Compiled
3991        model = TrivialModel()
3992        a = torch.ones([2, 2], requires_grad=True)
3993        b = torch.ones([2, 2])
3994        out = model(a)
3995        out_sum = out.sum()
3996        out_sum.backward()
3997
3998        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
3999        res = opt_fn(a, b)
4000
4001        self.assertTrue(same(ref, res))
4002        self.assertEqual(cnts.frame_count, 1)
4003        self.assertEqual(cnts.op_count, 3)
4004
4005    def test_intermediary_tensor_grad_access(self):
4006        # This test creates a model, and accesses the grads
4007        # from its parameters and an entirely intermediary tensor.
4008        cnts = torch._dynamo.testing.CompileCounter()
4009
4010        def fn(a, b):
4011            intermediary = torch.ones(2, 2)
4012            c = a + intermediary
4013            outs = []
4014            outs.append(intermediary.grad)
4015            return outs
4016
4017        # Eager
4018        a = torch.ones([2, 2], requires_grad=True)
4019        b = torch.ones([2, 2])
4020        ref = fn(a, b)
4021
4022        # Compiled
4023        a = torch.ones([2, 2], requires_grad=True)
4024        b = torch.ones([2, 2])
4025        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
4026        res = opt_fn(a, b)
4027        self.assertTrue(same(ref, res))
4028        self.assertEqual(cnts.frame_count, 1)
4029        self.assertEqual(cnts.op_count, 2)
4030
4031    def test_clone_sparse_input(self):
4032        for layout in [
4033            torch.sparse_coo,
4034            torch.sparse_csr,
4035            torch.sparse_csc,
4036            torch.sparse_bsr,
4037            torch.sparse_bsc,
4038        ]:
4039            for sparse_input in self.generate_simple_inputs(
4040                layout,
4041                device="cpu",
4042                dtype=torch.float64,
4043                index_dtype=torch.int64,
4044            ):
4045                # Invoke the dynamo clone input method directly.
4046                sparse_copy = torch._dynamo.utils.clone_input(sparse_input)
4047                # Make sure sparse clone is successful.
4048                self.assertEqual(sparse_input, sparse_copy)
4049
4050    def test_tensor_is_contiguous(self):
4051        def fn(x):
4052            input = torch.randn((1, 16, 1, 1))
4053            weight = torch.randn((8, 16, 3, 3))
4054            weight = weight.to(memory_format=x)
4055            output = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1)
4056            return output.is_contiguous(memory_format=x)
4057
4058        opt_fn = torch._dynamo.optimize("eager")(fn)
4059        for x in [torch.contiguous_format, torch.channels_last]:
4060            self.assertEqual(fn(x), opt_fn(x))
4061
4062    def test_python_slice(self):
4063        def f1(input):
4064            y = 0
4065            for i, x in enumerate(input[2:], 1):
4066                y = y + x
4067            return y
4068
4069        def f2(input):
4070            y = 0
4071            for i, x in enumerate(input.shape[2:], 1):
4072                y = y + x
4073            return y
4074
4075        cnts = torch._dynamo.testing.CompileCounter()
4076        opt_f1 = torch._dynamo.optimize(cnts)(f1)
4077        opt_f2 = torch._dynamo.optimize(cnts)(f2)
4078        res1 = opt_f1([1, 2, 3, 5])
4079        res2 = opt_f2(torch.rand([2, 3, 4, 5]))
4080
4081        self.assertEqual(res1, 8)
4082        self.assertEqual(res2, 9)
4083
4084    def test_enum_as_dict_key(self):
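        # Enum members and their string lookalikes used as dict keys must
        # survive the graph break; the same frames should be reused across
        # iterations (no recompiles).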
4085        class MyEnum(enum.Enum):
4086            FOO = 10
4087            BAR = 20
4088
4089        def fn(x):
4090            y = x + 2
4091            z = {
4092                MyEnum.FOO: torch.tensor(1),
4093                MyEnum.BAR: 10,
4094                "MyEnum.BAR": torch.tensor(8),
4095                5: torch.rand(3),
4096            }
4097            torch._dynamo.graph_break()
4098            a = z[MyEnum.FOO] + z["MyEnum.BAR"]
4099            b = y * 2
4100            return a, b
4101
4102        cnts = torch._dynamo.testing.CompileCounter()
4103        opt_fn = torch._dynamo.optimize(cnts)(fn)
4104        for _ in range(10):
4105            x = torch.rand(3)
4106            ref = fn(x)
4107            res = opt_fn(x)
4108            self.assertTrue(same(ref, res))
4109        self.assertEqual(cnts.frame_count, 2)
4110
4111    def test_enum_as_dict_key_with_overloaded_str(self):
4112        class MyEnum(enum.Enum):
4113            FOO = 10
4114            BAR = 20
4115
4116            def __str__(self):
4117                return self.value
4118
4119        def fn(x):
4120            y = x + 2
4121            z = {
4122                MyEnum.FOO: torch.tensor(1),
4123                MyEnum.BAR: 10,
4124                "MyEnum.BAR": torch.tensor(8),
4125                5: torch.rand(3),
4126            }
4127            torch._dynamo.graph_break()
4128            a = z[MyEnum.FOO] + z["MyEnum.BAR"]
4129            b = y * 2
4130            return a, b
4131
4132        cnts = torch._dynamo.testing.CompileCounter()
4133        opt_fn = torch._dynamo.optimize(cnts)(fn)
4134        for _ in range(10):
4135            x = torch.rand(3)
4136            ref = fn(x)
4137            res = opt_fn(x)
4138            self.assertTrue(same(ref, res))
4139        self.assertEqual(cnts.frame_count, 2)
4140
4141    def test_const_dict_variable_python_type(self):
4142        from torch._dynamo.variables import ConstantVariable, ConstDictVariable
4143
4144        make_key = ConstantVariable.create
4145
4146        d1 = {
4147            make_key("a"): ConstantVariable.create(10),
4148            make_key("b"): ConstantVariable.create(20),
4149        }
4150        d2 = collections.OrderedDict(
4151            [
4152                (make_key("x"), ConstantVariable.create(12)),
4153                (make_key("y"), ConstantVariable.create(22)),
4154            ]
4155        )
4156        self.assertEqual(ConstDictVariable(d1).python_type(), dict)
4157        self.assertEqual(
4158            ConstDictVariable(d2, collections.OrderedDict).python_type(),
4159            collections.OrderedDict,
4160        )
4161
4162    def test_builtin_subclasses_as_method_on_class_type(self):
4163        class Foo:
4164            def __init__(self, name):
4165                self.name_ = name
4166
4167            def get_name(self):
4168                return "Foo " + self.name_
4169
4170        class Bar(Foo):
4171            def __init__(self, name):
4172                self.name_ = name
4173
4174            def get_name(self):
4175                return "Bar " + self.name_
4176
4177        class Baz(Foo):
4178            def __init__(self, name):  # noqa: B903
4179                self.name_ = name
4180
4181            def get_name(self):
4182                return "Baz " + self.name_
4183
4184        subs_of_foo_reg = Foo.__subclasses__()
4185
4186        counter = CompileCounter()
4187
4188        @torch._dynamo.optimize_assert(counter)
4189        def fn():
4190            return Foo.__subclasses__()
4191
4192        subs_of_foo_optim = fn()
4193
4194        self.assertEqual(len(subs_of_foo_reg), 2)
4195        self.assertEqual(subs_of_foo_reg, subs_of_foo_optim)
4196
4197    def test_builtin_subclasses_as_method_on_var(self):
4198        class Foo:
4199            def __init__(self, name):
4200                self.name_ = name
4201
4202            def get_name(self):
4203                return "Foo " + self.name_
4204
4205        class Bar(Foo):
4206            def __init__(self, name):
4207                self.name_ = name
4208
4209            def get_name(self):
4210                return "Bar " + self.name_
4211
4212        class Baz(Bar):
4213            def __init__(self, name):
4214                self.name_ = name
4215
4216            def get_name(self):
4217                return "Baz " + self.name_
4218
4219        subs_of_foo_reg = Foo.__subclasses__()
4220        sub_of_foo_subclass_var_reg = subs_of_foo_reg[0].__subclasses__()
4221
4222        sub_of_foo_subclass_var_optim = []
4223        counter = CompileCounter()
4224
4225        @torch._dynamo.optimize_assert(counter)
4226        def fn():
4227            return Foo.__subclasses__()
4228
4229        @torch._dynamo.optimize_assert(counter)
4230        def fn_single(subs_of_foo_optim):
4231            return subs_of_foo_optim[0].__subclasses__()
4232
4233        subs_of_foo_optim = fn()
4234        sub_of_foo_subclass_var_optim = fn_single(subs_of_foo_optim)
4235
4236        self.assertEqual(len(sub_of_foo_subclass_var_optim), 1)
4237        self.assertEqual(sub_of_foo_subclass_var_optim, sub_of_foo_subclass_var_reg)
4238
4239    def test_builtin_str_on_user_defined_function(self):
4240        def another_fn():
4241            pass
4242
4243        def fn():
4244            return "another_fn" in str(another_fn)
4245
4246        opt_fn = torch._dynamo.optimize(nopython=True)(fn)
4247        self.assertTrue(opt_fn())
4248
4249    def test_enum_no_graphbreaks(self):
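        # The enum identity check should be resolved at trace time, so the op
        # count reflects which branch was taken and no graph break occurs.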
4250        class Foo(enum.Enum):
4251            FOO = 0
4252            BAR = 1
4253
4254        def fn(x, foo):
4255            if foo is Foo.FOO:
4256                x = torch.add(x, 1.0)
4257            x = torch.mul(x, 1.0)
4258            return x
4259
4260        x = torch.randn(1)
4261        cnts = torch._dynamo.testing.CompileCounter()
4262        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
4263        opt_fn(x, Foo.FOO)
4264        self.assertEqual(cnts.op_count, 2)
4265
4266        torch._dynamo.reset()
4267        cnts = torch._dynamo.testing.CompileCounter()
4268        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
4269        opt_fn(x, Foo.BAR)
4270        self.assertEqual(cnts.op_count, 1)
4271
4272    def test_repeat_interleave_graphbreaks(self):
4273        def fn_no_breaks(x):
4274            # no breaks on self_int
4275            x += 1
4276            x = torch.repeat_interleave(x, 2, 3)
4277            x += 1
4278            return x
4279
4280        def fn_has_breaks(x):
4281            # breaks on self_Tensor
4282            x += 1
4283            x = torch.repeat_interleave(x, torch.tensor(2), 3)
4284            x += 1
4285            return x
4286
4287        x = torch.randn([4, 16, 1, 64])
4288
4289        cnts = torch._dynamo.testing.CompileCounter()
4290        opt_fn = torch._dynamo.optimize(cnts)(fn_no_breaks)
4291        opt_fn(x)
4292        self.assertEqual(cnts.frame_count, 1)
4293
4294        torch._dynamo.reset()
4295        cnts = torch._dynamo.testing.CompileCounter()
4296        opt_fn = torch._dynamo.optimize(cnts)(fn_has_breaks)
4297        opt_fn(x)
4298        self.assertEqual(cnts.frame_count, 2)
4299
4300    def test_id_guarded_object(self):
4301        class UDO:
4302            @torch.compile(backend="eager")
4303            def call(self, x, ref_id):
4304                self_id = id(self)
4305                if self_id == ref_id:
4306                    x = torch.mul(x, 1.0)
4307                else:
4308                    x = torch.mul(x, 0)
4309                return x
4310
4311        # Make sure we do recompile when id(self) is executed on
4312        # different self objects.
4313        x = torch.ones(2)
4314        obj1 = UDO()
4315        obj1_id = id(obj1)
4316        self.assertEqual(obj1.call(x, obj1_id), torch.ones(2))
4317
4318        obj2 = UDO()
4319        # If we do not install ID_MATCH: ___check_obj_id(L['self'], xxx), this fails.
4320        self.assertEqual(obj2.call(x, obj1_id), torch.zeros(2))
4321
4322    def test_id_guarded_module(self):
4323        class M(torch.nn.Module):
4324            def forward(self, x, ref_id):
4325                self_id = id(self)
4326                if self_id == ref_id:
4327                    x = torch.mul(x, 1.0)
4328                else:
4329                    x = torch.mul(x, 0)
4330                return x
4331
4332        cnts = torch._dynamo.testing.CompileCounter()
4333
4334        # Make sure we do recompile when id(self) is executed on
4335        # different self objects.
4336        x = torch.ones(2)
4337        m1 = M()
4338        m1_id = id(m1)
4339        opt_m1 = torch._dynamo.optimize(cnts, nopython=True)(m1)
4340        self.assertEqual(opt_m1(x, m1_id), torch.ones(2))
4341        self.assertEqual(opt_m1(x, m1_id), torch.ones(2))
4342
4343        self.assertEqual(cnts.frame_count, 1)
4344        self.assertEqual(cnts.op_count, 1)
4345
4346        m2 = M()
4347        opt_m2 = torch._dynamo.optimize(cnts, nopython=True)(m2)
4348        # If we do not install ID_MATCH: ___check_obj_id(L['self'], xxx), this fails.
4349        self.assertEqual(opt_m2(x, m1_id), torch.zeros(2))
4350        self.assertEqual(cnts.frame_count, 2)
4351        self.assertEqual(cnts.op_count, 2)
4352
4353    def test_id_of_nn_module(self):
4354        class M(torch.nn.Module):
4355            def forward(self, x, ref_id):
4356                self_id = id(self)
4357                if self_id == ref_id:
4358                    x = torch.mul(x, 1.0)
4359                x = torch.add(x, 1.0)
4360                return x
4361
4362        m = M().eval()
4363        data = torch.randn(1)
4364        cnts = torch._dynamo.testing.CompileCounter()
4365        correct_ref_id = id(m)
4366        opt_m = torch._dynamo.optimize(cnts, nopython=True)(m)
4367        opt_m(data, correct_ref_id)
4368        # Extra op is the recorded equality test (although once
4369        # the trace is flattened this is dead!)
4370        if torch._dynamo.config.assume_static_by_default:
4371            self.assertExpectedInline(cnts.op_count, """2""")
4372        else:
4373            self.assertExpectedInline(cnts.op_count, """2""")
4374
4375        torch._dynamo.reset()
4376        cnts = torch._dynamo.testing.CompileCounter()
4377        incorrect_ref_id = id(m) + 1
4378        opt_m = torch._dynamo.optimize(cnts, nopython=True)(m)
4379        opt_m(data, incorrect_ref_id)
4380        if torch._dynamo.config.assume_static_by_default:
4381            self.assertExpectedInline(cnts.op_count, """1""")
4382        else:
4383            self.assertExpectedInline(cnts.op_count, """1""")
4384
4385    def test_inline_func_jump_on_tensor_condition(self):
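        # A data-dependent branch on a tensor inside an inlined function;
        # both branch outcomes must match eager.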
4386        def f1(input):
4387            if input == 0:
4388                return input + 1
4389            else:
4390                return input + 2
4391
4392        def f2(input):
4393            return f1(input)
4394
4395        cnts = torch._dynamo.testing.CompileCounter()
4396        opt_f2 = torch._dynamo.optimize(cnts)(f2)
4397        res1 = opt_f2(torch.tensor([1.0]))
4398        res2 = opt_f2(torch.tensor([0.0]))
4399
4400        self.assertEqual(res1, 3)
4401        self.assertEqual(res2, 1)
4402
4403    def test_frozenset_torch_func_contains(self):
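        # Membership of a torch function in a frozenset is resolved at trace
        # time, so the branch taken is reflected directly in the op count.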
4404        funcs = frozenset([torch.add])
4405
4406        def fn(x, func):
4407            if func in funcs:
4408                x = torch.add(x, 1.0)
4409            x = torch.mul(x, 1.0)
4410            return x
4411
4412        x = torch.randn(1)
4413        cnts = torch._dynamo.testing.CompileCounter()
4414        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
4415        opt_fn(x, torch.add)
4416        self.assertEqual(cnts.op_count, 2)
4417
4418        torch._dynamo.reset()
4419        cnts = torch._dynamo.testing.CompileCounter()
4420        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
4421        opt_fn(x, torch.mul)
4422        self.assertEqual(cnts.op_count, 1)
4423
4424    def test_inline_list_mutation(self):
4425        def f1(x):
4426            x.append(torch.ones(8))
4427            return x
4428
4429        def f2():
4430            x = [torch.ones(6)]
4431            f1(x)
4432            return x
4433
4434        res1 = f2()
4435        cnts = torch._dynamo.testing.CompileCounter()
4436        opt_f2 = torch._dynamo.optimize(cnts)(f2)
4437        res2 = opt_f2()
4438        self.assertTrue(same(res1, res2))
4439
4440    def test_inline_dict_mutation(self):
4441        def f1(d):
4442            d["c"] = d["a"] + d.pop("b")
4443            return d
4444
4445        def f2():
4446            d = {"a": torch.ones(5), "b": torch.ones(5)}
4447            f1(d)
4448            return d
4449
4450        res1 = f2()
4451        cnts = torch._dynamo.testing.CompileCounter()
4452        opt_f2 = torch._dynamo.optimize(cnts)(f2)
4453        res2 = opt_f2()
4454        self.assertTrue(same(res1, res2))
4455
4456    def test_inline_local_dict_clear(self):
4457        def f(d):
4458            d.clear()
4459            return d
4460
4461        inp = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)}
4462        out = torch.compile(f, backend="eager", fullgraph=True)(inp)
4463        self.assertEqual(len(out), 0)
4464        self.assertEqual(len(inp), 0)
4465
4466    def test_inline_module_attr_dict_clear(self):
4467        class MyMod(torch.nn.Module):
4468            def __init__(self):
4469                super().__init__()
4470                self.a = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)}
4471
4472            def forward(self):
4473                self.a.clear()
4474                return self.a
4475
4476        m = MyMod()
4477        out = torch.compile(m, backend="eager", fullgraph=True)()
4478        self.assertEqual(len(out), 0)
4479        self.assertEqual(len(m.a), 0)
4480
4481    def test_inline_user_defined_dict_attr_clear(self):
4482        class MyMod:
4483            def __init__(self):
4484                self.a = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)}
4485
4486        def f(obj, inp):
4487            ret = len(obj.a) + inp
4488            obj.a.clear()
4489            return obj.a, ret
4490
4491        m = MyMod()
4492        before_len = len(m.a)
4493        t_inp = torch.ones(1)
4494        d, ret = torch.compile(f, backend="eager", fullgraph=True)(m, t_inp)
4495        self.assertEqual(len(m.a), 0)
4496        self.assertEqual(len(d), 0)
4497        self.assertEqual(ret, t_inp + before_len)
4498
4499    def test_recursive_inline_list_mutation(self):
4500        def f1(x, y):
4501            x.append(torch.tensor([1.1]))
4502            y.append(torch.tensor([1.2]))
4503            return x, y
4504
4505        def f2(x, y):
4506            x.append(torch.tensor([2.1]))
4507            y.append(torch.tensor([2.2]))
4508            f1(x, y)
4509            return x, y
4510
4511        def f3(x):
4512            x.append(torch.tensor([3.1]))
4513            y = [torch.tensor([3.2])]
4514            f2(x, y)
4515            return x, y
4516
4517        def f4():
4518            x = [torch.tensor([4.1])]
4519            return f3(x)
4520
4521        res1 = f4()
4522        cnts = torch._dynamo.testing.CompileCounter()
4523        opt_f4 = torch._dynamo.optimize(cnts)(f4)
4524        res2 = opt_f4()
4525        self.assertTrue(same(res1, res2))
4526
4527    def test_sample_input(self):
4528        from torch.testing._internal.common_methods_invocations import SampleInput
4529
4530        def fn(sample):
4531            if isinstance(sample.input, torch.Tensor):
4532                return sample.input * 2
4533            return torch.zeros(())
4534
4535        sample = SampleInput(torch.ones(2))
4536        ref = fn(sample)
4537
4538        opt_fn = torch._dynamo.optimize("eager")(fn)
4539        res = opt_fn(sample)
4540
4541        self.assertTrue(same(ref, res))
4542
4543    def test_release_input_memory(self):
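        # Compiled artifacts must not keep the input tensor alive; the weakref
        # should be cleared once the last user reference is deleted.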
4544        x = torch.rand([4])
4545        x_ref = weakref.ref(x)
4546
4547        cnts = torch._dynamo.testing.CompileCounter()
4548
4549        @torch._dynamo.optimize(cnts)
4550        def foo(x):
4551            return x + x
4552
4553        out = foo(x)
4554        self.assertTrue(same(out, x + x))
4555        del x
4556        self.assertIs(x_ref(), None)
4557
4558    def test_release_module_memory(self):
4559        mod = torch.nn.Linear(10, 10)
4560        x = torch.rand([10, 10])
4561        mod_weight_ref = weakref.ref(mod.weight)
4562        mod_ref = weakref.ref(mod)
4563
4564        # Modules that are passed into torch._dynamo optimized functions
4565        # will normally be held onto through the generated GraphModule,
4566        # which contains the modules. Remove the reference in this backend
4567        # and test that no additional references are being held.
4568        class NoLeakBackend:
4569            def __call__(self, gm: torch.fx.GraphModule, example_inputs):
4570                gm.mod = None
4571
4572                def foo(*args, **kwargs):
4573                    return (1,)
4574
4575                return foo
4576
4577        no_leak_backend = NoLeakBackend()
4578
4579        @torch._dynamo.optimize(no_leak_backend)
4580        def foo(mod, x):
4581            return mod(x)
4582
4583        foo(mod, x)
4584        del mod
4585        del x
4586        self.assertIsNone(mod_ref())
4587        self.assertIsNone(mod_weight_ref())
4588
4589    def test_release_scope_memory(self):
4590        def inner(y):
4591            y
4592
4593        inner = torch._dynamo.optimize("eager")(inner)
4594
4595        p_ref = None
4596
4597        x = torch.randn((10, 10))
4598        inner(x)
4599
4600        p_ref = weakref.ref(x)
4601        self.assertTrue(p_ref() is not None)
4602        del x
4603        self.assertTrue(p_ref() is None)
4604
4605    def test_update_locals_and_stack_uses_shared_cache(self):
4606        def fn(x):
4607            perm = [0, 3, 5]
4608            perm = list(range(min(perm))) + perm
4609            perm.extend(i for i in range(x.dim()) if i not in perm)
4610            return perm
4611
4612        x = torch.rand([2, 2, 2, 2, 2, 2])
4613        res1 = fn(x)
4614        cnts = torch._dynamo.testing.CompileCounter()
4615        opt_fn = torch._dynamo.optimize(cnts)(fn)
4616        res2 = opt_fn(x)
4617        self.assertTrue(same(res1, res2))
4618
4619    def test_dict_reconstruct_keeps_original_order(self):
4620        def fn():
4621            modules = collections.OrderedDict([("act", torch.nn.ReLU())])
4622            module_dict = torch.nn.ModuleDict(modules)
4623
4624            next_modules = {"fc4": torch.nn.Linear(5, 6), "act3": torch.nn.Sigmoid()}
4625            modules.update(next_modules.items())
4626            module_dict.update(next_modules)
4627            return modules, module_dict
4628
4629        cnts = torch._dynamo.testing.CompileCounter()
4630        opt_fn = torch._dynamo.optimize(cnts)(fn)
4631        modules, module_dict = opt_fn()
4632
4633        self.assertEqual(len(module_dict), len(modules))
4634        for k1, m2 in zip(modules, module_dict.children()):
4635            self.assertTrue(modules[k1] is m2)
4636
4637    def test_side_effects_codegen_update_mutated(self):
4638        # Codegen that updates mutated variables via side effects
4639        # should run after the codegen for the stack values.
4640        def f1(x):
4641            alist = [x]
4642            alist.append(x + 1)
4643            alist[0].sum().item()  # graph break
4644            res = alist.pop()
4645            res.sum().item()  # graph break
4646            return res
4647
4648        def f2(a, b):
4649            d = {"a": a + 1, "b": b + 2}
4650            x = d.pop("b")
4651            x.sum().item()  # graph break
4652            y = d["a"] + x
4653            y.sum().item()  # graph break
4654            d["c"] = y
4655            return d
4656
4657        x = torch.rand([2, 3])
4658        a = torch.rand([5, 6])
4659        b = torch.rand([5, 6])
4660        res11 = f1(x)
4661        res21 = f2(a, b)
4662        cnts = torch._dynamo.testing.CompileCounter()
4663        opt_f1 = torch._dynamo.optimize(cnts)(f1)
4664        opt_f2 = torch._dynamo.optimize(cnts)(f2)
4665        res12 = opt_f1(x)
4666        res22 = opt_f2(a, b)
4667        self.assertTrue(same(res11, res12))
4668        self.assertTrue(same(res21, res22))
4669
4670    def test_list_append_return_none(self):
4671        def fn(x):
4672            alist = []
4673            blist = alist.append(x + 1)
4674            return alist, blist
4675
4676        x = torch.tensor([2.3])
4677        res = fn(x)
4678        cnts = torch._dynamo.testing.CompileCounter()
4679        opt_fn = torch._dynamo.optimize(cnts)(fn)
4680        res2 = opt_fn(x)
4681        self.assertEqual(res, res2)
4682
4683    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
4684    def test_tensor_ctor_list_of_tensor(self):
4685        def fn(x):
4686            return torch.tensor([x], dtype=torch.int64)
4687
4688        x = torch.tensor(20)
4689        res = fn(x)
4690        cnts = torch._dynamo.testing.CompileCounter()
4691        opt_fn = torch._dynamo.optimize(cnts)(fn)
4692        res2 = opt_fn(x)
4693        self.assertEqual(res, res2)
4694        self.assertEqual(cnts.frame_count, 1)
4695
4696    def test_tensor_types(self):
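        # isinstance checks against the legacy tensor types (FloatTensor, ...)
        # should succeed for tensors created with the matching dtype.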
4697        def fn(dtype, tensor_type):
4698            x = torch.empty(4, dtype=dtype)
4699            assert isinstance(x, tensor_type)
4700
4701        opt_fn = torch._dynamo.optimize("eager")(fn)
4702        opt_fn(torch.float32, torch.FloatTensor)
4703        opt_fn(torch.float64, torch.DoubleTensor)
4704        opt_fn(torch.float16, torch.HalfTensor)
4705        opt_fn(torch.bfloat16, torch.BFloat16Tensor)
4706        opt_fn(torch.uint8, torch.ByteTensor)
4707        opt_fn(torch.int8, torch.CharTensor)
4708        opt_fn(torch.int64, torch.LongTensor)
4709        opt_fn(torch.int, torch.IntTensor)
4710        opt_fn(torch.int16, torch.ShortTensor)
4711        opt_fn(torch.bool, torch.BoolTensor)
4712
4713    def test_nan(self):
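        # Guards on a float('nan') argument must not force a recompile on the
        # second call even though nan != nan.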
4714        def f(x, n):
4715            return x * 2 + n
4716
4717        x = torch.randn(4)
4718        n = float("nan")
4719
4720        cnts = torch._dynamo.testing.CompileCounter()
4721        opt_f = torch._dynamo.optimize(cnts)(f)
4722        opt_f(x, n)
4723        opt_f(x, n)
4724        self.assertEqual(cnts.frame_count, 1)
4725
4726    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
4727    def test_item(self):
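        # With capture_scalar_outputs=True, Tensor.item() should be captured
        # in the graph instead of causing a graph break.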
4728        class MyMod(torch.nn.Module):
4729            def forward(self, x):
4730                z = torch.max(x)
4731                return z.int().item()
4732
4733        x = torch.tensor([[10.6763, 11.7445, -2.2369]])
4734        model = MyMod()
4735        y = torch._dynamo.optimize("eager", nopython=True)(model)(x)
4736
4737        self.assertEqual(y, 11)
4738
4739    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
4740    def test_item_changes(self):
4741        class MyMod(torch.nn.Module):
4742            def forward(self, x):
4743                z = torch.max(x)
4744                return z.int().item()
4745
4746        x = torch.tensor([[10.6763, 11.7445, -2.2369]])
4747        model = MyMod()
4748        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
4749        y = opt_model(x)
4750        z = opt_model(torch.tensor([[y - 5, y + 10, y + 50]]))
4751
4752        self.assertEqual(y, 11)
4753        self.assertEqual(z, 61)
4754
4755    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
4756    def test_item_changes_new_shape(self):
4757        class MyMod(torch.nn.Module):
4758            def forward(self, x):
4759                z = torch.max(x)
4760                return z.int().item()
4761
4762        x = torch.tensor([[10.6763, 11.7445, -2.2369]])
4763        model = MyMod()
4764        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
4765        y = opt_model(x)
4766        z = opt_model(torch.tensor([[y - 5, y + 50], [y + 5, y - 50]]))
4767
4768        self.assertEqual(y, 11)
4769        self.assertEqual(z, 61)
4770
4771    @unittest.skip("https://github.com/pytorch/pytorch/issues/99726")
4772    def test_cross_entropy_loss_fancy_ctor1(self):
4773        rand_5 = torch.randn(5)
4774        rand_3_5 = torch.randn(3, 5)
4775        target = torch.empty(3, dtype=torch.long).random_(5)
4776
4777        loss = torch.nn.CrossEntropyLoss(
4778            weight=rand_5, reduce=False, label_smoothing=0.5
4779        )
4780        opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss)
4781        input = rand_3_5
4782        dynamo_output = opt_loss(input, target)
4783
4784        loss = torch.nn.CrossEntropyLoss(
4785            weight=rand_5, reduce=False, label_smoothing=0.5
4786        )
4787        input = rand_3_5
4788        output = loss(input, target)
4789
4790        self.assertTrue(torch.allclose(dynamo_output, output))
4791
4792    def test_cross_entropy_loss_fancy_ctor2(self):
4793        rand_3_5 = torch.randn(3, 5)
4794        target = torch.empty(3, dtype=torch.long).random_(5)
4795
4796        loss = torch.nn.CrossEntropyLoss(reduce=False, label_smoothing=0.5)
4797        opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss)
4798        input = rand_3_5
4799        dynamo_output = opt_loss(input, target)
4800
4801        loss = torch.nn.CrossEntropyLoss(reduce=False, label_smoothing=0.5)
4802        input = rand_3_5
4803        output = loss(input, target)
4804
4805        self.assertTrue(torch.allclose(dynamo_output, output))
4806
4807    def test_cross_entropy_loss_simple_ctor(self):
4808        output = None
4809        rand_3_5 = torch.randn(3, 5)
4810        target = torch.empty(3, dtype=torch.long).random_(5)
4811
4812        loss = torch.nn.CrossEntropyLoss()
4813        opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss)
4814        input = rand_3_5
4815        dynamo_output = opt_loss(input, target)
4816
4817        loss = torch.nn.CrossEntropyLoss()
4818        input = rand_3_5
4819        output = loss(input, target)
4820
4821        self.assertTrue(torch.allclose(dynamo_output, output))
4822
4823    def test_nn_functional_reduction(self):
4824        def fn(loss, reduction):
4825            reduction_enum = F._Reduction.get_enum(reduction)
4826            if reduction_enum == 0:
4827                return loss
4828            elif reduction_enum == 1:
4829                return loss.mean()
4830            elif reduction_enum == 2:
4831                return loss.sum()
4832
4833        x = torch.rand([3, 5])
4834        y = "mean"
4835        ref = fn(x, y)
4836        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
4837        res = opt_fn(x, y)
4838        self.assertTrue(torch.allclose(ref, res))
4839
4840    def test_large_reduction_list(self):
4841        dtype = torch.float32
4842        device = "cpu"
4843
4844        def check_sum_all(tensor: torch.Tensor) -> None:
4845            pylist = tensor.reshape(-1).tolist()
4846            self.assertTrue(same(tensor.sum(), torch.tensor(sum(pylist))))
4847
4848        check_sum_all(torch.randn(200000, dtype=dtype, device=device))
4849
4850    def test_raise_on_backend_error(self):
4851        def my_compiler(gm, _):
4852            raise RuntimeError("duck!")
4853
4854        @torch._dynamo.optimize(my_compiler)
4855        def fn(a, b):
4856            return a + b / (a - b)
4857
4858        self.assertRaises(
4859            torch._dynamo.exc.BackendCompilerFailed,
4860            lambda: fn(torch.randn(10), torch.randn(10)),
4861        )
4862
4863    def test_named_parameters(self):
4864        n_embd = 768
4865        block_size = 128
4866        vocab_size = 65
4867        embd_pdrop = 0.1
4868
4869        class MyModel2(torch.nn.Module):
4870            def __init__(self):
4871                super().__init__()
4872                self.tok_emb = torch.nn.Embedding(vocab_size, n_embd)
4873                self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd))
4874                self.drop = torch.nn.Dropout(embd_pdrop)
4875
4876            def forward(self, x):
4877                return x
4878
4879        class MyModel(torch.nn.Module):
4880            def __init__(self):
4881                super().__init__()
4882                self.tok_emb = torch.nn.Embedding(vocab_size, n_embd)
4883                self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd))
4884                self.drop = torch.nn.Dropout(embd_pdrop)
4885                self.submod2 = MyModel2()
4886
4887            def forward(self, x):
4888                return x
4889
4890        # Regular
4891        params = []
4892        mod = MyModel()
4893        actual_params = list(mod.named_parameters())
4894
4895        @torch._dynamo.optimize("eager", nopython=True)
4896        def fn():
4897            return list(mod.named_parameters())
4898
4899        params = fn()
4900
4901        self.assertEqual(len(actual_params), len(params))
4902        for idx in range(len(params)):
4903            k_a, v_a = actual_params[idx]
4904            k, v = params[idx]
4905            self.assertEqual(k_a, k)
4906            self.assertTrue(torch.allclose(v_a, v))
4907
4908        # Prefix
4909        params = []
4910        mod = MyModel()
4911        actual_params = list(mod.named_parameters(prefix="foo"))
4912
4913        @torch._dynamo.optimize("eager", nopython=True)
4914        def fn1():
4915            return list(mod.named_parameters(prefix="foo"))
4916
4917        params = fn1()
4918
4919        self.assertEqual(len(actual_params), len(params))
4920        for idx in range(len(params)):
4921            k_a, v_a = actual_params[idx]
4922            k, v = params[idx]
4923            self.assertEqual(k_a, k)
4924            self.assertTrue(torch.allclose(v_a, v))
4925
4926    @torch._dynamo.config.patch(guard_nn_modules=True)
4927    def test_module_complex_iter(self):
4928        n_embd = 768
4929        block_size = 128
4930        vocab_size = 65
4931        embd_pdrop = 0.1
4932
4933        class FakeGPT(torch.nn.Module):
4934            def __init__(self):
4935                super().__init__()
4936                self.tok_emb = torch.nn.Embedding(vocab_size, n_embd)
4937                self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd))
4938                self.drop = torch.nn.Dropout(embd_pdrop)
4939                self.ln_f = torch.nn.LayerNorm(n_embd)
4940                self.head = torch.nn.Linear(n_embd, vocab_size, bias=False)
4941
4942                self.block_size = block_size
4943                self.names = []
4944
4945            def forward(self, idx, targets=None):
4946                b, t = idx.size()
4947                assert (
4948                    t <= self.block_size
4949                ), "Cannot forward, model block size is exhausted."
4950
4951                # forward the GPT model
4952                token_embeddings = self.tok_emb(
4953                    idx
4954                )  # each index maps to a (learnable) vector
4955                position_embeddings = self.pos_emb[
4956                    :, :t, :
4957                ]  # each position maps to a (learnable) vector
4958                x = self.drop(token_embeddings + position_embeddings)
4959                x = self.blocks(x)
4960                x = self.ln_f(x)
4961                logits = self.head(x)
4962
4963                # if we are given some desired targets also calculate the loss
4964                loss = None
4965                if targets is not None:
4966                    loss = F.cross_entropy(
4967                        logits.view(-1, logits.size(-1)), targets.view(-1)
4968                    )
4969
4970                return logits, loss
4971
4972            def foo(self, memo=None, prefix="", remove_duplicate=False):
4973                for mn, m in self.named_modules(
4974                    memo=memo, prefix=prefix, remove_duplicate=remove_duplicate
4975                ):
4976                    for pn, p in self.named_parameters():
4977                        fpn = f"{mn}.{pn}" if mn else pn
4978                        self.names.append(fpn)
4979
4980        # Test plain recurse
4981        model_a = FakeGPT()
4982        model_a.foo()
4983        a_names = model_a.names
4984
4985        model_b = FakeGPT()
4986        opt_model_b = torch._dynamo.optimize("eager", nopython=True)(model_b)
4987        opt_model_b.foo()
4988
4989        self.assertEqual(a_names, model_b.names)
4990
4991        # Test with prefix
4992        model_a = FakeGPT()
4993        model_a.foo(prefix="abc")
4994        a_names = model_a.names
4995
4996        model_b = FakeGPT()
4997        opt_model_b = torch._dynamo.optimize("eager", nopython=True)(model_b)
4998        opt_model_b.foo(prefix="abc")
4999
5000        self.assertEqual(a_names, model_b.names)
5001
5002    def test_numpy_variable_isinstance(self):
5003        def fn(x, m):
5004            if isinstance(m, np.ndarray):
5005                return x + 1
5006            else:
5007                return x - 1
5008
5009        x = torch.tensor([2.3])
5010        m = np.array([1, 2, 3])
5011        ref = fn(x, m)
5012        cnts = torch._dynamo.testing.CompileCounter()
5013        opt_fn = torch._dynamo.optimize(cnts)(fn)
5014        res = opt_fn(x, m)
5015        self.assertEqual(ref, res)
5016
5017        # Now test the other path
5018        ref = fn(x, x)
5019        res = opt_fn(x, x)
5020        self.assertEqual(ref, res)
5021
5022    def test_tensor_dot_grad_no_graph_break(self):
5023        def fn(a, b):
5024            y = 3 * a**3 - b**2
5025            y.backward(gradient=torch.tensor([1.0, 1.0]))
5026            b.grad.zero_()
5027            return a.grad, b.grad
5028
5029        a = torch.tensor([2.0, 3.0], requires_grad=True)
5030        b = torch.tensor([6.0, 4.0], requires_grad=True)
5031        cnts = torch._dynamo.testing.CompileCounter()
5032        opt_fn = torch._dynamo.optimize(cnts)(fn)
5033        _, b_grad = opt_fn(a, b)
5034        self.assertTrue(same(b_grad, torch.tensor([0.0, 0.0])))
5035        self.assertEqual(cnts.frame_count, 2)
5036
5037    def test_torch_nn_parameter_isinstance(self):
5038        def fn(x):
5039            a = torch.nn.Parameter(torch.rand(2, 3))
5040            if isinstance(a, torch.Tensor):
5041                return x + 1
5042            else:
5043                return x - 1
5044
5045        x = torch.tensor([2.5])
5046        ref = fn(x)
5047        opt_fn = torch._dynamo.optimize("eager")(fn)
5048        res = opt_fn(x)
5049        self.assertEqual(ref, res)
5050
5051    def _optimize_then_check_exp(
5052        self, foo, args, cnt, exp_out, exp_frame_count, exp_n_cached_backend
5053    ):
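            # Note: exp_n_cached_backend is currently unused by this helper.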
5054        opt_out = torch._dynamo.optimize(backend=cnt)(foo)(*args)
5055        self.assertEqual(exp_out, opt_out)
5056        self.assertEqual(cnt.frame_count, exp_frame_count)
5057
5058    def test_backend_match_guard(self):
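            # Dynamo guards on the active compiler backend: re-running with the same
            # backend hits the cache, while switching backends triggers a recompile
            # (reflected in the frame_count expectations below).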
5059        x = torch.randn([3, 4])
5060
5061        def foo(x):
5062            return x.sin() + x.cos()
5063
5064        def foo_graph_break(x):
5065            a = x.sin()
5066            torch._dynamo.graph_break()
5067            b = x.cos()
5068            return a + b
5069
5070        eager_record_backend = torch._dynamo.testing.EagerAndRecordGraphs()
5071        backends = [eager_record_backend, "eager"]
5072
5073        # We intentionally don't reset dynamo for each backend so that we can test
5074        # 1. dynamo doesn't recompile when backend stays the same, i.e. frame_count doesn't increase
5075        # 2. dynamo recompiles when backend changes, i.e. frame_count is non-zero for next backend
5076        def test_recompile(foo, *, exp_frame_count):
5077            eager_result = foo(x)
5078            for i, backend in enumerate(backends):
5079                cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
5080                # Run the optimized function multiple times to make sure dynamo
5081                # doesn't recompile; specifically, frame_count doesn't increase.
5082                # the number of cached backends is i + 2 because we have the optimizing backend + None
5083                self._optimize_then_check_exp(
5084                    foo, (x,), cnt, eager_result, exp_frame_count, i + 2
5085                )
5086                self._optimize_then_check_exp(
5087                    foo, (x,), cnt, eager_result, exp_frame_count, i + 2
5088                )
5089                self._optimize_then_check_exp(
5090                    foo, (x,), cnt, eager_result, exp_frame_count, i + 2
5091                )
5092
5093        test_recompile(foo, exp_frame_count=1)
5094        torch._dynamo.reset()
5095        test_recompile(foo_graph_break, exp_frame_count=2)
5096
5097    def test_backend_match_guard_multi_threads(self):
5098        x = torch.randn([3, 4])
5099
5100        def foo(x):
5101            return x.sin() + x.cos()
5102
5103        def compile_then_check_exp(foo, args, cnt, eager_result, exp_frame_count):
5104            for i in range(3):
5105                opt_out = torch._dynamo.optimize(backend=cnt)(foo)(*args)
5106                self.assertEqual(opt_out, eager_result)
5107            self.assertEqual(cnt.frame_count, exp_frame_count)
5108            thread_success[threading.current_thread()] = True
5109
5110        eager_record_backend = torch._dynamo.testing.EagerAndRecordGraphs()
5111        backends = [eager_record_backend, "eager"]
5112
5113        # Test dynamo recompiles but only caches a single backend for each thread
5114        eager_result = foo(x)
5115        # cached backends per thread: cnt and None; each thread is expected to compile foo once
5116        exp_frame_count = 1
5117        threads = []
5118        thread_success = {}
5119        for i, backend in enumerate(backends):
5120            cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
5121            thread = threading.Thread(
5122                target=compile_then_check_exp,
5123                args=(
5124                    foo,
5125                    (x,),
5126                    cnt,
5127                    eager_result,
5128                    exp_frame_count,
5129                ),
5130            )
5131            threads.append(thread)
5132            thread.start()
5133
5134        # Wait for all threads to finish
5135        for thread in threads:
5136            thread.join()
5137
5138        self.assertEqual(len(thread_success), len(threads))
5139
5140    def test_dynamo_min_operator_with_shape(self):
5141        @torch._dynamo.optimize("eager", nopython=True)
5142        def f(x, a):
5143            return min(x.shape[0], a)
5144
5145        result = f(torch.ones(6), 3)
5146        self.assertEqual(result, 3)
5147
5148    def test_onnx_shape_as_tensor(self):
5149        @torch._dynamo.optimize("eager", nopython=True)
5150        def f(x):
5151            return 1 + torch._shape_as_tensor(x)[0]
5152
5153        gm, _ = torch._dynamo.export(f)(torch.ones(6))
5154
5155        input_one_dim = torch.ones(6)
5156        input_two_dims = torch.ones(7, 4)
5157        self.assertEqual(f(input_one_dim), 7)
5158        self.assertEqual(f(input_two_dims), 8)
5159        self.assertEqual(f(input_two_dims), 8)
5160
5161        @torch._dynamo.optimize("eager", nopython=True)
5162        def f_onnx(x):
5163            return 1 + torch.onnx.operators.shape_as_tensor(x)[0]
5164
5165        self.assertEqual(f_onnx(input_one_dim), 7)
5166        self.assertEqual(f_onnx(input_two_dims), 8)
5167        self.assertEqual(f_onnx(input_two_dims), 8)
5168
5169    def test_cond(self):
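            # control_flow.cond traces both branches into subgraphs and dispatches on
            # the runtime value of the predicate tensor, so a single compiled graph
            # serves both the False and True calls below.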
5170        from functorch.experimental.control_flow import cond
5171
5172        def true_fn(x):
5173            return x.sin()
5174
5175        def false_fn(x):
5176            return x.cos()
5177
5178        def f(pred, x):
5179            return cond(pred, true_fn, false_fn, [x])
5180
5181        opt_fn = torch._dynamo.optimize("eager")(f)
5182        a = opt_fn(torch.tensor(False), torch.tensor([0.25, 0.25]))
5183        self.assertTrue(same(torch.cos(torch.tensor([0.25, 0.25])), a))
5184        b = opt_fn(torch.tensor(True), torch.tensor([0.25, 0.25]))
5185        self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), b))
5186
5187    def test_nonzero_static(self):
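            # nonzero_static returns a fixed-shape (size, ndim) result: missing rows
            # are padded with fill_value (default -1) and surplus nonzero indices are
            # dropped, keeping the output shape static.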
5188        # invalid size
5189        with self.assertRaisesRegex(
5190            RuntimeError, "nonzero_static: 'size' must be an non-negative integer"
5191        ):
5192            torch.nonzero_static(torch.tensor([8]), size=-2)
5193
5194        with self.assertRaisesRegex(
5195            RuntimeError, "nonzero_static: 'size' must be an non-negative integer"
5196        ):
5197            torch.nonzero_static(torch.tensor([8]), size=-2, out=torch.tensor(0))
5198
5199        # nonzero_static.out: out dtype mismatch
5200        input_tensor = torch.tensor([8])
5201        static_size = 1
5202        out_tensor = torch.empty((static_size, input_tensor.dim()), dtype=torch.float)
5203        with self.assertRaisesRegex(
5204            RuntimeError, "nonzero_static: Expected out tensor to have scalar type Long"
5205        ):
5206            torch.nonzero_static(input_tensor, size=static_size, out=out_tensor)
5207
5208        # nonzero_static.out: out resize (shrink)
5209        input_tensor = torch.tensor([8])
5210        static_size = 1
5211        out_tensor = torch.empty((10, 10, 10, 10), dtype=torch.long)
5212        self.assertTrue(
5213            same(
5214                torch.nonzero_static(input_tensor, size=static_size, out=out_tensor),
5215                torch.tensor([0]),
5216            )
5217        )
5218        self.assertTrue(
5219            same(
5220                out_tensor,
5221                torch.tensor([0]),
5222            )
5223        )
5224
5225        # nonzero_static.out: out resize (enlarge)
5226        input_tensor = torch.tensor([8])
5227        static_size = 1
5228        out_tensor = torch.empty((0), dtype=torch.long)
5229        self.assertTrue(
5230            same(
5231                torch.nonzero_static(input_tensor, size=static_size, out=out_tensor),
5232                torch.tensor([0]),
5233            )
5234        )
5235        self.assertTrue(
5236            same(
5237                out_tensor,
5238                torch.tensor([0]),
5239            )
5240        )
5241
5242        # 0 rank
5243        input_tensor = torch.tensor(6)
5244        static_size = 2
5245        self.assertTrue(
5246            same(
5247                torch.nonzero_static(input_tensor, size=static_size),
5248                torch.empty((static_size, input_tensor.dim()), dtype=torch.long),
5249            )
5250        )
5251
5252        # 0 size
5253        input_tensor = torch.tensor([[[1]]])
5254        static_size = 0
5255        self.assertTrue(
5256            same(
5257                torch.nonzero_static(input_tensor, size=static_size),
5258                torch.empty((static_size, input_tensor.dim()), dtype=torch.long),
5259            )
5260        )
5261
5262        # 1D input
5263        input_tensor = torch.tensor([0, 8])
5264        static_size = 1
5265        self.assertTrue(
5266            same(
5267                torch.nonzero_static(input_tensor, size=static_size),
5268                torch.tensor([1]),
5269            )
5270        )
5271
5272        input_tensor = torch.tensor([8, 0])
5273        static_size = 2
5274        self.assertTrue(
5275            same(
5276                torch.nonzero_static(input_tensor, size=static_size),
5277                torch.tensor([[0], [-1]]),  # padded with default fill_value "-1"
5278            )
5279        )
5280
5281        # 2D input
5282        input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]])
5283        static_size = 5
5284        fill_value = -100
5285        self.assertTrue(
5286            torch._dynamo.utils.same(
5287                torch.nonzero_static(
5288                    input_tensor, size=static_size, fill_value=fill_value
5289                ),
5290                torch.tensor(
5291                    [
5292                        [0, 0],
5293                        [1, 0],
5294                        [1, 1],
5295                        [fill_value, fill_value],
5296                        [fill_value, fill_value],
5297                    ]
5298                ),
5299            )
5300        )
5301        input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]])
5302        static_size = 2
5303        fill_value = -100
5304        self.assertTrue(
5305            torch._dynamo.utils.same(
5306                torch.nonzero_static(
5307                    input_tensor, size=static_size, fill_value=fill_value
5308                ),
5309                torch.tensor([[0, 0], [1, 0]]),
5310            )
5311        )
5312
5313        # 3D input
5314        input_tensor = torch.tensor([[[0, 0], [0, -3]], [[0, 0], [5, 0]]])
5315        static_size = 4
5316        fill_value = -999
5317        self.assertTrue(
5318            torch._dynamo.utils.same(
5319                torch.nonzero_static(
5320                    input_tensor,
5321                    size=static_size,
5322                    fill_value=fill_value,
5323                ),
5324                torch.tensor(
5325                    [
5326                        [0, 1, 1],
5327                        [1, 1, 0],
5328                        [fill_value, fill_value, fill_value],
5329                        [fill_value, fill_value, fill_value],
5330                    ]
5331                ),
5332            )
5333        )
5334
5335    def test_cond_with_quantization(self):
5336        from functorch.experimental.control_flow import cond
5337
5338        class MyModule(torch.nn.Module):
5339            def __init__(self):
5340                super().__init__()
5341                example_inputs = (torch.randn(5, 5),)
5342                self.model = torch.nn.Linear(5, 5)
5343                self.quantized_model = prepare_qat_fx(
5344                    self.model, qconfig_dict, example_inputs=example_inputs
5345                )
5346
5347            def forward(self, pred, x):
5348                def true_fn(x):
5349                    return x.sin() + self.quantized_model(x)
5350
5351                def false_fn(x):
5352                    return x.cos() + self.model(x)
5353
5354                return cond(pred, true_fn, false_fn, [x])
5355
5356        module = MyModule()
5357        opt_m = torch._dynamo.optimize("eager", nopython=True)(module)
5358        x = torch.rand((5, 5))
5359        pred = torch.tensor(True)
5360        self.assertTrue(same(module(pred, x), opt_m(pred, x)))
5361        pred = torch.tensor(False)
5362        self.assertTrue(same(module(pred, x), opt_m(pred, x)))
5363
5364    def test_map_with_quantization(self):
5365        from functorch.experimental.control_flow import map
5366
5367        class MyModule(torch.nn.Module):
5368            def __init__(self):
5369                super().__init__()
5370                example_inputs = (torch.randn(5, 5),)
5371                self.model = torch.nn.Linear(5, 5)
5372                self.quantized_model = prepare_qat_fx(
5373                    self.model, qconfig_dict, example_inputs=example_inputs
5374                )
5375
5376            def forward(self, x):
5377                def body(x):
5378                    return x.sin() + self.quantized_model(x)
5379
5380                return map(body, x)
5381
5382        module = MyModule()
5383        opt_m = torch._dynamo.optimize("eager", nopython=True)(module)
5384        x = torch.rand((5, 5))
5385        self.assertTrue(same(module(x), opt_m(x)))
5386
5387    def test_cond_side_effects(self):
5388        from functorch.experimental.control_flow import cond
5389
5390        c = 0
5391
5392        def true_fn(x):
5393            return x - c
5394
5395        def false_fn(x):
5396            return x + c
5397
5398        def f(pred, x):
5399            nonlocal c
5400            c = 1
5401            return cond(pred, true_fn, false_fn, [x])
5402
5403        opt_fn = torch._dynamo.optimize("eager")(f)
5404        c = 0
5405        a = opt_fn(torch.tensor(False), torch.tensor([0.25, 0.25]))
5406        self.assertTrue(same(torch.tensor([1.25, 1.25]), a))
5407
5408    def test_map_side_effects(self):
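            # Mutating state captured from an outer scope inside the mapped body is not
            # supported; the expected error message depends on whether inbuilt nn.Module
            # inlining is enabled.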
5409        from functorch.experimental.control_flow import map
5410
5411        class Module(torch.nn.Module):
5412            def __init__(self):
5413                super().__init__()
5414                self.w = torch.tensor(1)
5415
5416            def forward(self, xs):
5417                def body(x):
5418                    self.w += 1
5419                    return x
5420
5421                return map(body, xs)
5422
5423        mod = Module()
5424
5425        error_message = ""
5426        if torch._dynamo.config.inline_inbuilt_nn_modules:
5427            error_message = r"HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)"
5428        else:
5429            error_message = "Can't inplace modify module params/buffers"
5430
5431        with self.assertRaisesRegex(Unsupported, error_message):
5432            opt_fn = torch._dynamo.optimize("eager", nopython=True)(mod)
5433            opt_fn(torch.randn(3, 2))
5434
5435    def test_cond_nested(self):
5436        from functorch.experimental.control_flow import cond
5437
5438        def true_fn_nested(x):
5439            return x * 10
5440
5441        def false_fn_nested(x):
5442            return x * -1
5443
5444        def true_fn(pred2, x):
5445            return x.sin()
5446
5447        def false_fn(pred2, x):
5448            return x + cond(pred2, true_fn_nested, false_fn_nested, [x])
5449
5450        def f(pred, pred2, x):
5451            return cond(pred, true_fn, false_fn, [pred2, x])
5452
5453        cc = torch._dynamo.testing.CompileCounter()
5454        opt_fn = torch._dynamo.optimize(cc)(f)
5455        true_true_sin = opt_fn(
5456            torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25])
5457        )
5458        self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin))
5459
5460        true_false_sin = opt_fn(
5461            torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25])
5462        )
5463        self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin))
5464
5465        false_true_sum_mult = opt_fn(
5466            torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25])
5467        )
5468        self.assertTrue(
5469            same(torch.tensor([2.75, 2.75]), false_true_sum_mult)
5470        )  # * 10 then add x
5471
5472        false_false_sum_neg = opt_fn(
5473            torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25])
5474        )
5475        self.assertTrue(
5476            same(torch.tensor([0.0, 0.0]), false_false_sum_neg)
5477        )  # * -1 then add x
5478        self.assertTrue(cc.frame_count, 2)
5479
5480    def test_cond_export(self):
5481        from functorch.experimental.control_flow import cond
5482
5483        def true_fn_nested(x):
5484            return x * 10
5485
5486        def false_fn_nested(x):
5487            return x * -1
5488
5489        def true_fn(pred2, x):
5490            return x.sin()
5491
5492        def false_fn(pred2, x):
5493            return x + cond(pred2, true_fn_nested, false_fn_nested, [x])
5494
5495        def f(pred, pred2, x):
5496            return cond(pred, true_fn, false_fn, [pred2, x])
5497
5498        graph, guard = torch._dynamo.export(f)(
5499            torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25])
5500        )
5501        true_true_sin = graph(
5502            torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25])
5503        )
5504        self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin))
5505
5506        true_false_sin = graph(
5507            torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25])
5508        )
5509        self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin))
5510
5511        false_true_sum_mult = graph(
5512            torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25])
5513        )
5514        self.assertTrue(
5515            same(torch.tensor([2.75, 2.75]), false_true_sum_mult)
5516        )  # * 10 then add x
5517
5518        false_false_sum_neg = graph(
5519            torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25])
5520        )
5521        self.assertTrue(
5522            same(torch.tensor([0.0, 0.0]), false_false_sum_neg)
5523        )  # * -1 then add x
5524
5525    def test_cond_export_single_arg(self):
5526        from functorch.experimental.control_flow import cond
5527
5528        def true_fn(x):
5529            return x
5530
5531        def false_fn(x):
5532            return x.sin()
5533
5534        def f(pred, x):
5535            return cond(pred, true_fn, false_fn, [x])
5536
5537        graph, guard = torch._dynamo.export(f)(
5538            torch.tensor(False), torch.tensor([0.25, 0.25])
5539        )
5540        true_mirror = graph(torch.tensor(True), torch.tensor([0.25, 0.25]))
5541        self.assertTrue(same(torch.tensor([0.25, 0.25]), true_mirror))
5542        true_mirror_2 = graph(torch.tensor(True), torch.tensor([0.33, 0.33, 0.33]))
5543        self.assertTrue(same(torch.tensor([0.33, 0.33, 0.33]), true_mirror_2))
5544
5545        false_sin = graph(torch.tensor(False), torch.tensor([0.5, 0.5]))
5546        self.assertTrue(same(torch.sin(torch.tensor([0.5, 0.5])), false_sin))
5547
5548    def test_enum_guards(self):
5549        class MyEnum(enum.Enum):
5550            FOO = 10
5551            BAR = 20
5552
5553        def fn(x, y):
5554            if y == MyEnum.FOO:
5555                return x + 1
5556            else:
5557                return x - 1
5558
5559        x = torch.rand(3)
5560        y = MyEnum.BAR
5561        ref = fn(x, y)
5562        opt_fn = torch.compile(backend="eager")(fn)
5563        res = opt_fn(x, y)
5564        self.assertTrue(same(ref, res))
5565
5566    def test_duplicate_graph_break_log(self):
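            # With verbose=True every occurrence of a graph break is logged; with
            # verbose=False duplicate graph-break messages are reported only once, as
            # asserted below.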
5567        torch._logging.set_logs(graph_breaks=True)
5568
5569        @torch._dynamo.optimize("eager")
5570        def f1(a, b):
5571            f2(a, b)
5572
5573        def f2(a, b):
5574            c = a + b
5575            print("break")
5576            return a + b + c
5577
5578        @torch._dynamo.optimize("eager")
5579        def g1(a, b):
5580            g2(a, b)
5581
5582        def g2(a, b):
5583            c = a + b
5584            print("break")
5585            return a + b + c
5586
5587        def count_graph_break_msgs(msgs):
5588            return sum(msg.find("Graph break") != -1 for msg in msgs)
5589
5590        with self.assertLogs(
5591            logger="torch._dynamo", level=logging.DEBUG
5592        ) as log, torch._dynamo.config.patch(verbose=True):
5593            f1(torch.randn(10), torch.randn(10))
5594            self.assertGreater(count_graph_break_msgs(log.output), 1)
5595
5596        with self.assertLogs(
5597            logger="torch._dynamo", level=logging.DEBUG
5598        ) as log, torch._dynamo.config.patch(verbose=False):
5599            g1(torch.randn(10), torch.randn(10))
5600            self.assertEqual(count_graph_break_msgs(log.output), 1)
5601
5602        # reset logging state
5603        torch._logging.set_logs()
5604
5605    def test_inplace_param_update(self):
5606        def fn(param, y):
5607            prev_grad = torch.is_grad_enabled()
5608            try:
5609                torch.set_grad_enabled(False)
5610                torch.set_grad_enabled(True)
5611                torch.set_grad_enabled(False)
5612                param.add_(y)
5613            finally:
5614                torch.set_grad_enabled(prev_grad)
5615
5616        y = torch.randn(4)
5617        x = torch.nn.Parameter(torch.randn(4))
5618        fn(x, y)
5619
5620        cnts = torch._dynamo.testing.CompileCounter()
5621        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
5622        opt_fn(x, y)
5623        self.assertEqual(cnts.frame_count, 1)
5624        self.assertEqual(cnts.op_count, 3)
5625
5626    @unittest.skipIf(
5627        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
5628        "Can't run fused SDPA on this platform",
5629    )
5630    def test_parsing_sdpa(self):
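            # Exercises scaled_dot_product_attention through several positional/keyword
            # calling conventions so dynamo must parse each signature variant.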
5631        class MyModule(torch.nn.Module):
5632            def forward(self, query, key, value):
5633                out = F.scaled_dot_product_attention(query, key, value, None, 0, True)
5634                out = F.scaled_dot_product_attention(
5635                    query, key, value, None, 0, True, scale=8
5636                )
5637                out = F.scaled_dot_product_attention(
5638                    query=query,
5639                    key=key,
5640                    value=value,
5641                    attn_mask=None,
5642                    dropout_p=0,
5643                    is_causal=True,
5644                )
5645                out = F.scaled_dot_product_attention(
5646                    query,
5647                    key=key,
5648                    value=value,
5649                    attn_mask=None,
5650                    dropout_p=0,
5651                    is_causal=True,
5652                )
5653                out = F.scaled_dot_product_attention(
5654                    query, key, value, None, dropout_p=0, is_causal=True
5655                )
5656                out = F.scaled_dot_product_attention(query, key, value, None, scale=8)
5657                return out
5658
5659        device = "cuda"
5660        dtype = torch.float16
5661        seq_len_q = 1
5662        seq_len_k = 1
5663        head_dim = 8
5664        query = torch.ones(
5665            1, 8, seq_len_q, head_dim, device=device, dtype=dtype, requires_grad=True
5666        )
5667        key = torch.ones(
5668            1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True
5669        )
5670        value = torch.ones(
5671            1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True
5672        )
5673        module = MyModule()
5674        opt_mod = torch._dynamo.optimize("inductor")(module)
5675        opt_mod(query, key, value)
5676
5677    def test_generate_tensor_from_list_of_numpy_primitive_type(self):
5678        # Test building a tensor from a list of numpy scalars, e.g. torch.LongTensor([np.int64(...), ...])
5679        def fn():
5680            x = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64)
5681            y = [x[0], x[2], x[4]]
5682            return torch.LongTensor(y)
5683
5684        ref = fn()
5685        res = torch.compile(fullgraph=True)(fn)()
5686        self.assertEqual(ref, res)
5687
5688    def test_object_classmethod(self):
5689        class C:
5690            @classmethod
5691            def fn(cls, x):
5692                return x + x
5693
5694        @torch._dynamo.optimize("eager", nopython=True)
5695        def f():
5696            return C().fn(torch.ones(2, 3))
5697
5698        self.assertTrue(torch.allclose(f(), torch.tensor([2.0])))
5699
5700    def test_object_staticmethod(self):
5701        class C:
5702            @staticmethod
5703            def fn(x):
5704                return x + x
5705
5706        @torch._dynamo.optimize("eager", nopython=True)
5707        def f():
5708            return C().fn(torch.ones(2, 3))
5709
5710        self.assertTrue(torch.allclose(f(), torch.tensor([2.0])))
5711
5712    def test_user_function_variable_supports_enum_argument(self):
5713        class Foo(enum.Enum):
5714            FOO = 0
5715            BAR = 1
5716
5717        def gn(x, y=Foo.FOO):
5718            if y is Foo.FOO:
5719                return x
5720            else:
5721                return x + 1
5722
5723        def fn(x):
5724            return gn(x)
5725
5726        x = torch.randn(2, 3)
5727        ref = fn(x)
5728        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
5729        res = opt_fn(x)
5730        self.assertTrue(torch.allclose(ref, res))
5731
5732    def test_user_function_variable_supports_type_abcmeta_argument(self):
5733        class Foo(metaclass=abc.ABCMeta):
5734            @abc.abstractclassmethod
5735            def read(self):  # noqa: B027
5736                pass
5737
5738        class Bar(Foo):
5739            def read(self):
5740                return "Hello World!"
5741
5742        class Baz:
5743            pass
5744
5745        def gn(x, tys=(Bar, Baz)):
5746            if Bar in tys:
5747                return x - 1
5748            else:
5749                return x + 1
5750
5751        def fn(x):
5752            return gn(x)
5753
5754        x = torch.randn(2, 3)
5755        ref = fn(x)
5756        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
5757        res = opt_fn(x)
5758        self.assertTrue(torch.allclose(ref, res))
5759
5760    def test_user_function_variable_supports_function_argument(self):
5761        # Test that user-defined function default arguments can be:
5762        # 1. user-defined functions (e.g. add1)
5763        # 2. torch functions (e.g. torch.sin)
5764        # 3. python builtin functions (e.g. operator.neg)
5765        def add1(x):
5766            return x + 1
5767
5768        def gn(x, f1=add1, f2=torch.sin, f3=operator.neg):
5769            return f3(f2(f1(x)))
5770
5771        def fn(x):
5772            return gn(x)
5773
5774        x = torch.randn(2, 3)
5775        ref = fn(x)
5776        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
5777        res = opt_fn(x)
5778        self.assertTrue(torch.allclose(ref, res))
5779
5780    def test_typing_variable_isinstance(self):
5781        def fn(x, m):
5782            if isinstance(m, typing.Mapping):
5783                return x + 1
5784            else:
5785                return x - 1
5786
5787        x = torch.randn(2, 3)
5788        m = {"x": torch.randn(3)}
5789        ref = fn(x, m)
5790        opt_fn = torch._dynamo.optimize("eager")(fn)
5791        res = opt_fn(x, m)
5792        self.assertTrue(torch.allclose(ref, res))
5793
5794    @torch._dynamo.config.patch(guard_nn_modules=True)
5795    def test_repro_graph_breaks_in__get_item_by_idx(self):
5796        class Mod(torch.nn.Module):
5797            def __init__(self):
5798                super().__init__()
5799                self.mod = torch.nn.Sequential(
5800                    torch.nn.Linear(3, 3), torch.nn.Linear(3, 3)
5801                )
5802
5803            def forward(self, x):
5804                return self.mod[0](x)
5805
5806        m = Mod()
5807        graph, _ = torch._dynamo.export(m)(torch.randn(3, 3))
5808
5809    @torch._dynamo.config.patch(guard_nn_modules=True)
5810    def test_nn_sequential_invocation(self):
5811        with freeze_rng_state():
5812
5813            class TestModel(torch.nn.Module):
5814                def __init__(self) -> None:
5815                    super().__init__()
5816                    self.linears = torch.nn.Sequential(
5817                        torch.nn.Linear(2, 2),
5818                        torch.nn.Linear(2, 2),
5819                        torch.nn.Linear(2, 2),
5820                        torch.nn.Linear(2, 2),
5821                    )
5822
5823                def forward(self, x):
5824                    all_but_last = self.linears[:-1]
5825                    return all_but_last(x)
5826
5827            m = TestModel()
5828            x = torch.rand((2, 2))
5829            real = m(x)
5830            graph, _ = torch._dynamo.export(m)(x)
5831            dynamo_result = graph(x)
5832            self.assertTrue(same(real, dynamo_result))
5833
5834    @torch._dynamo.config.patch(guard_nn_modules=True)
5835    def test_nn_sequential_invocation_reposition_indices(self):
5836        with freeze_rng_state():
5837
5838            class TestModel(torch.nn.Module):
5839                def __init__(self) -> None:
5840                    super().__init__()
5841                    self.linears = torch.nn.Sequential(
5842                        torch.nn.Linear(2, 2),
5843                        torch.nn.Linear(2, 2),
5844                        torch.nn.Linear(2, 2),
5845                        torch.nn.Linear(2, 2),
5846                    )
5847
5848                def forward(self, x):
5849                    all_but_last = self.linears[1:3]
5850                    return all_but_last(x)
5851
5852            m = TestModel()
5853            x = torch.rand((2, 2))
5854            real = m(x)
5855            graph, _ = torch._dynamo.export(m)(x)
5856            dynamo_result = graph(x)
5857            self.assertTrue(same(real, dynamo_result))
5858
5859    def test_error_on_nested_fx_trace(self):
5860        input = torch.rand(2, 3)
5861
5862        def f(x):
5863            return x + x
5864
5865        real = f(input)
5866
5867        optimized = torch._dynamo.optimize("eager")(f)
5868        self.assertTrue(same(optimized(input), real))
5869
5870        with self.assertRaisesRegex(RuntimeError, "Detected that you are using FX"):
5871            gm = torch.fx.symbolic_trace(optimized)
5872
5873    @patch.object(torch._dynamo.config, "error_on_nested_fx_trace", False)
5874    def test_no_error_on_nested_fx_trace(self):
5875        input = torch.rand(2, 3)
5876
5877        def f(x):
5878            return x + x
5879
5880        real = f(input)
5881
5882        optimized = torch._dynamo.optimize("eager")(f)
5883        self.assertTrue(same(optimized(input), real))
5884
5885        # should not error
5886        gm = torch.fx.symbolic_trace(optimized)
5887        self.assertTrue(same(gm(input), real))
5888
5889    def test_not_dynamic_scope(self):
5890        def f(y):
5891            x = 1
5892
5893            def g():
5894                x = 2
5895                return lambda: x
5896
5897            return y + g()()
5898
5899        input = torch.zeros(1)
5900        real = f(input)
5901        optimized = torch._dynamo.optimize("eager")(f)
5902        opt = optimized(input)
5903        self.assertTrue(same(opt, real))
5904
5905    def test_inference_mode(self):
5906        @torch.inference_mode()
5907        def func(x, y):
5908            return x.add(1.0) + y
5909
5910        x = torch.ones(4, requires_grad=True)
5911        y = torch.ones(4, requires_grad=True)
5912        ref = func(x, y)
5913        opt_func = torch._dynamo.optimize("eager")(func)
5914
5915        x1 = torch.ones(4, requires_grad=True)
5916        res = opt_func(x1, y)
5917        self.assertTrue(same(ref, res))
5918        self.assertTrue(same(x, x1))
5919
5920    def test_if_cond_nn_mod1(self):
5921        class MockModule(torch.nn.Module):
5922            def __init__(self, output_relu=True):
5923                super().__init__()
5924                self.relu = torch.nn.ReLU() if output_relu else None
5925
5926            def forward(self, x):
5927                x = torch.sin(x)
5928                if self.relu:
5929                    x = self.relu(x)
5930                return x
5931
5932        model = MockModule()
5933        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
5934
5935        x = torch.rand(4)
5936        ref = model(x)
5937        res = opt_model(x)
5938        self.assertTrue(same(ref, res))
5939
5940        model = MockModule(output_relu=False)
5941        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
5942
5943        x = torch.rand(4)
5944        ref = model(x)
5945        res = opt_model(x)
5946        self.assertTrue(same(ref, res))
5947
5948    def test_if_cond_nn_mod2(self):
5949        class MockModule(torch.nn.Module):
5950            def __init__(self):
5951                super().__init__()
5952                self.layer = torch.nn.Sequential()
5953
5954            def forward(self, x):
5955                if self.layer:
5956                    return x + 1
5957                else:
5958                    return x - 1
5959
5960        model = MockModule()
5961        x = torch.rand(4)
5962        ref = model(x)
5963        opt_model = torch.compile(backend="eager")(model)
5964        res = opt_model(x)
5965        self.assertTrue(same(ref, res))
5966
5967    def test_if_cond_nn_mod3(self):
5968        def fn(x):
5969            if torch.nn.ModuleList():
5970                return x + 1
5971            else:
5972                return x - 1
5973
5974        x = torch.rand(4)
5975        ref = fn(x)
5976        opt_fn = torch.compile(backend="eager")(fn)
5977        res = opt_fn(x)
5978        self.assertTrue(same(ref, res))
5979
5980    def test_if_cond_user_defined_object(self):
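            # Branching on a user-defined object makes dynamo resolve its truthiness at
            # trace time (via __bool__ or __len__ when defined, otherwise the default of
            # True); the frame_count check below counts the distinct compilations the
            # six calls produce.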
5981        # obj.__bool__ does not exist
5982        class A:  # noqa: B903
5983            def __init__(self, x):
5984                self.x = x
5985
5986        # obj.__bool__ is a function and returns a bool
5987        class B:
5988            def __init__(self, x):
5989                self.x = x
5990
5991            def __bool__(self):
5992                return self.x > 0
5993
5994        # obj.__bool__ is not a function
5995        class C:
5996            def __init__(self, x):
5997                self.x = x
5998                self.__bool__ = False
5999
6000        def fn(x, obj):
6001            if not obj:
6002                return x + 1
6003            else:
6004                return x - 1
6005
6006        x = torch.rand(4)
6007        cnts = torch._dynamo.testing.CompileCounter()
6008        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
6009        obj1 = A(0.5)
6010        obj2 = B(0.5)
6011        obj3 = B(-0.5)
6012        obj4 = C(0.5)
6013        for obj in [obj1, obj2, obj3, obj4, obj3, obj2]:
6014            ref = fn(x, obj)
6015            res = opt_fn(x, obj)
6016            self.assertTrue(same(ref, res))
6017        self.assertEqual(cnts.frame_count, 4)
6018
6019    def test_if_cond_user_defined_object2(self):
6020        # obj.__bool__ is a function but returns a non-bool type
6021        class MyObj:
6022            def __init__(self, x):
6023                self.x = x
6024
6025            def __bool__(self):
6026                self.x = 1.2
6027                return self.x
6028
6029        def fn(a, obj):
6030            if not obj:
6031                return a + obj.x
6032            else:
6033                return a - obj.x
6034
6035        x = torch.rand(4)
6036        obj = MyObj(0.5)
6037        opt_fn = torch._dynamo.optimize("eager")(fn)
6038        try:
6039            opt_fn(x, obj)
6040            self.fail("Expected a TypeError from a non-bool __bool__ return value")
6041        except TypeError as e:
6042            self.assertIn("__bool__ should return bool, returned float", str(e))
6043
6044    def test_if_cond_user_defined_object3(self):
6045        # obj.__bool__ does not exist, but obj.__len__ does
6046        class A:  # noqa: B903
6047            def __init__(self, x):
6048                self.x = x
6049
6050            def __len__(self):
6051                return len(self.x)
6052
6053        # obj.__bool__ takes precedence over obj.__len__
6054        class B:
6055            def __init__(self, x):
6056                self.x = x
6057
6058            def __bool__(self):
6059                return False
6060
6061            def __len__(self):
6062                return len(self.x)
6063
6064        def fn(x, obj):
6065            if not obj:
6066                return x + 1
6067            else:
6068                return x - 1
6069
6070        x = torch.rand(4)
6071        opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
6072        obj1 = A([1, 2, 3])
6073        obj2 = A([])
6074        obj3 = B([1, 2, 3])
6075        obj4 = B([])
6076        for obj in [obj1, obj2, obj3, obj4]:
6077            ref = fn(x, obj)
6078            res = opt_fn(x, obj)
6079            self.assertTrue(same(ref, res))
6080
6081    def test_class_has_instancecheck_method(self):
6082        class A:
6083            pass
6084
6085        class ExampleMeta(type):
6086            def __instancecheck__(cls, instance):
6087                return True
6088
6089        class B(metaclass=ExampleMeta):
6090            pass
6091
6092        def fn(x, obj):
6093            if isinstance(obj, B):
6094                return x + 1
6095            else:
6096                return x - 1
6097
6098        x = torch.rand(4)
6099        obj = A()
6100        ref = fn(x, obj)
6101        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
6102        res = opt_fn(x, obj)
6103        self.assertTrue(same(ref, res))
6104
6105    def test_torch_cuda_is_available(self):
6106        def fn(x):
6107            if torch.cuda.is_available():
6108                return x + 1
6109            else:
6110                return x - 1
6111
6112        x = torch.rand(4)
6113        ref = fn(x)
6114        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
6115        res = opt_fn(x)
6116        self.assertTrue(same(ref, res))
6117
6118    def test_variable_tracker_recursively_contains(self):
6119        # VariableTracker.recursively_contains should be updated correctly when mutation happens
6120        def fn(x):
6121            data = [[None] * 3] * 3
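                # The outer list above holds three references to the same inner list, so
                # the writes below are visible through every row; this exercises tracking
                # of mutations to aliased lists.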
6122            for i in range(3):
6123                if i == 0:
6124                    data[0][i] = x
6125                else:
6126                    data[0][i] = data[0][i - 1] + 1
6127            return data[0][-1]
6128
6129        x = torch.rand(4)
6130        ref = fn(x)
6131        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
6132        res = opt_fn(x)
6133        self.assertTrue(same(ref, res))
6134
6135    @unittest.skipIf(not TEST_CUDA, "requires cuda")
6136    @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn")
6137    def test_torch_cudnn_is_acceptable(self):
6138        def fn(x):
6139            if torch.backends.cudnn.is_acceptable(tensor=x):
6140                return x + 1
6141            return x
6142
6143        x = torch.rand(4).cuda()
6144        ref = fn(x)
6145        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
6146        res = opt_fn(x)
6147        self.assertTrue(same(ref, res))
6148
6149    @unittest.skipIf(not TEST_CUDA, "requires cuda")
6150    @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn")
6151    def test_torch_cudnn_is_acceptable_bad_inputs(self):
6152        def fn1(x):
6153            if torch.backends.cudnn.is_acceptable("invalid"):
6154                return x + 1
6155            return x
6156
6157        def fn2(x):
6158            if torch.backends.cudnn.is_acceptable(x, 3.14):
6159                return x + 1
6160            return x
6161
6162        with self.assertRaisesRegex(
6163            AssertionError, "Expect input to cudnn.is_acceptable to be a tensor"
6164        ):
6165            x1 = torch.rand(4).cuda()
6166            opt_fn1 = torch._dynamo.optimize("eager", nopython=True)(fn1)
6167            res1 = opt_fn1(x1)
6168
6169        with self.assertRaisesRegex(
6170            AssertionError, "Expect 1 input to cudnn.is_acceptable"
6171        ):
6172            x2 = torch.rand(4).cuda()
6173            opt_fn2 = torch._dynamo.optimize("eager", nopython=True)(fn2)
6174            res = opt_fn2(x2)
6175
6176    @unittest.skipIf(not TEST_CUDA, "requires cuda")
6177    def test_get_device(self):
6178        def fn(x, y):
6179            x = x + 1
6180            y = y + 1
6181            return x.get_device(), y.get_device()
6182
6183        x = torch.rand(4, device="cuda")
6184        y = torch.rand(4, device="cpu")
6185        ref = fn(x, y)
6186        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
6187        res = opt_fn(x, y)
6188        self.assertTrue(same(ref, res))
6189
6190    def test_disable_flag(self):
6191        cnt = torch._dynamo.testing.CompileCounter()
6192
6193        with patch.dict(os.environ, {"TORCH_COMPILE_DISABLE": "1"}):
6194
6195            def fn(x, y):
6196                x = x + 1
6197                y = y + 1
6198
6199            opt_fn = torch._dynamo.optimize(cnt)
6200            opt_fn = torch._dynamo.optimize(cnt)(fn)
                opt_fn(torch.randn(4), torch.randn(4))
6201        self.assertEqual(cnt.frame_count, 0)
6202
6203    def test_is_compiling(self):
6204        def f1():
6205            if torch._dynamo.is_compiling():
6206                return torch.ones(2, 2)
6207            else:
6208                return torch.zeros(2, 2)
6209
6210        def f2():
6211            if torch._utils.is_compiling():
6212                return torch.ones(2, 2)
6213            else:
6214                return torch.zeros(2, 2)
6215
6216        def f3():
6217            if torch.compiler.is_compiling():
6218                return torch.ones(2, 2)
6219            else:
6220                return torch.zeros(2, 2)
6221
6222        def f4():
6223            if torch.compiler.is_dynamo_compiling():
6224                return torch.ones(2, 2)
6225            else:
6226                return torch.zeros(2, 2)
6227
6228        for f in [f1, f2, f3, f4]:
6229            opt_f = torch._dynamo.optimize("eager")(f)
6230
6231            self.assertEqual(f(), torch.zeros(2, 2))
6232            self.assertEqual(opt_f(), torch.ones(2, 2))
6233
6234    def test_torch_generator_set_state(self):
6235        def fn():
6236            default_state = torch.default_generator.get_state()
6237            x = torch.rand([2, 3])
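                # get_state() returns a uint8 ByteTensor, so the dtype comparison below
                # is always true and simply forces the extra multiply; restoring the
                # state afterwards makes the second rand() reproduce the first draw,
                # hence the x == y * 2 check.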
6238            if default_state.dtype != "float32":
6239                x = x * 2
6240            torch._dynamo.graph_break()
6241            torch.default_generator.set_state(default_state)
6242            y = torch.rand([2, 3])
6243            return x, y
6244
6245        opt_fn = torch._dynamo.optimize("eager")(fn)
6246        x, y = opt_fn()
6247        self.assertEqual(x, y * 2)
6248
6249    def test_torch_distributions_lazy_property(self):
6250        def fn(x):
6251            return torch.distributions.Categorical(probs=x).entropy()
6252
6253        opt_fn = torch._dynamo.optimize("eager")(fn)
6254        x = torch.rand([4, 4])
6255        self.assertEqual(opt_fn(x), fn(x))
6256
6257    def test_guard_failure_fn(self):
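            # guard_fail_fn is invoked when a cached frame's guards fail and a recompile
            # occurs; it receives details about the failing guard, which the closure
            # below records for inspection.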
6258        def fn(x, y, k):
6259            x = x + 1
6260            y = y + 1
6261            return x * y * k
6262
6263        x = torch.tensor([0.5, 0.5])
6264        y = torch.tensor([1.0, 1.0])
6265
6266        guard_failure = None
6267
6268        def guard_failures(failure):
6269            nonlocal guard_failure
6270            guard_failure = failure
6271
6272        opt_fn = torch._dynamo.optimize(
6273            "eager", nopython=True, guard_fail_fn=guard_failures
6274        )(fn)
6275
6276        x2 = torch.tensor([0.5, 0.5, 1.0])
6277        y2 = torch.tensor([0.5, 0.5, 0.5])
6278
6279        opt_fn(x, y, 3)
6280        opt_fn(x2, y2, 5)
6281
6282        if (
6283            not torch._dynamo.config.specialize_int
6284            and not torch._dynamo.config.assume_static_by_default
6285        ):
6286            # guard_fail_fn isn't actually exercised in this configuration, but it's
6287            # still nice to confirm that no guard failure occurred.
6288            self.assertTrue(guard_failure is None)
6289        else:
6290            self.assertTrue(guard_failure is not None)
6291
6292    def test_guard_failure_fn_shape_control(self):
6293        def fn(x, y):
6294            if x.shape[0] < 3:
6295                if y.shape[0] < 3:
6296                    return x * y
6297                else:
6298                    return x + y
6299            else:
6300                return -1
6301
6302        x = torch.randn([2, 2])
6303        y = torch.randn([2, 2])
6304
6305        guard_failure = None
6306
6307        def guard_failures(failure):
6308            nonlocal guard_failure
6309            guard_failure = failure
6310
6311        opt_fn = torch._dynamo.optimize(
6312            "eager", nopython=True, guard_fail_fn=guard_failures
6313        )(fn)
6314
6315        x2 = torch.randn([5, 5])
6316        y2 = torch.randn([5, 5])
6317
6318        opt_fn(x, y)
6319        opt_fn(x2, y2)
6320
6321        self.assertTrue(guard_failure is not None)
6322        first_guard_failure = guard_failure[0].partition("\n")[0]
6323        if torch._dynamo.config.assume_static_by_default:
6324            self.assertIn(
6325                """tensor 'L['x']' size mismatch at index 0. expected 2, actual 5""",
6326                first_guard_failure,
6327            )
6328        else:
6329            self.assertIn("""2 <= L['x'].size()[0] <= 2""", first_guard_failure)
6330
6331    def test_guard_failure_fn2(self):
6332        def fn(x, y):
6333            x = x + 1
6334            y = y + 1
6335            return x * y
6336
6337        x = torch.tensor([0.5, 0.5])
6338        y = torch.tensor([1.0, 1.0])
6339
6340        guard_failure = None
6341
6342        def guard_failures(failure):
6343            nonlocal guard_failure
6344            guard_failure = failure
6345
6346        opt_fn = torch._dynamo.optimize(
6347            "eager", nopython=True, guard_fail_fn=guard_failures
6348        )(fn)
6349
6350        x2 = torch.tensor([0.5, 0.5, 1.0])
6351        y2 = torch.tensor([0.5, 0.5, 0.5])
6352
6353        opt_fn(x, y)
6354        opt_fn(x2, y2)
6355
6356        if torch._dynamo.config.assume_static_by_default:
6357            self.assertIn(
6358                """tensor 'L['x']' size mismatch at index 0. expected 2, actual 3""",
6359                guard_failure[0],
6360            )
6361        else:
6362            self.assertTrue(guard_failure is None)
6363
6364    def test_guard_failure_fn_tensor_iter(self):
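        # Iterating over a tensor installs a guard on len(x), so a tensor with
        # a different leading dimension should trigger a guard failure even
        # with dynamic shapes.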
6365        def fn(x):
6366            for y in x:
6367                y.add_(1.0)
6368            return y
6369
6370        guard_failure = None
6371
6372        def guard_failures(failure):
6373            nonlocal guard_failure
6374            guard_failure = failure
6375
6376        opt_fn = torch._dynamo.optimize(
6377            "eager", nopython=True, guard_fail_fn=guard_failures
6378        )(fn)
6379
6380        args1 = torch.randn(10, 10)
6381        out = fn(args1)
6382        opt_out = opt_fn(args1)
6383        self.assertTrue(same(out, opt_out))
6384
6385        args2 = torch.randn(9, 10)
6386        out = fn(args2)
6387        opt_out = opt_fn(args2)
6388        self.assertTrue(same(out, opt_out))
6389
6390        # a guard failure is expected for both static and dynamic shapes
6391        self.assertTrue(guard_failure is not None)
6392        self.assertIn(
6393            """len(L['x']) == 10""",
6394            guard_failure[0],
6395        )
6396
6397    def test_restore_graphstate(self):
6398        # This function does some guard accumulation,
6399        # and then rolls back due to control flow.
6400        # The idea is that if one were printing guards as they appear,
6401        # they would see this function insert a guard that does not show up in
6402        # the final set of guards, because we rolled back past it.
6403        def nested_fn(s):
6404            if x[0] < 10:
6405                return s * s
6406            return s
6407
6408        def fn(x, y):
6409            x = x + 1
6410            y = nested_fn(y)
6411            y = y + 10
6412            return x * y
6413
6414        all_guards = []
6415
6416        def guard_export_print(guards):
6417            nonlocal all_guards
6418            all_guards.extend(guards)
6419
6420        opt_fn = torch._dynamo.optimize("eager", guard_export_fn=guard_export_print)(fn)
6421
6422        x = torch.tensor([0.5, 0.5])
6423        y = torch.tensor([1.0, 1.0])
6424        opt_fn(x, y)
6425
6426        for guard in all_guards:
6427            # The guard created for the rolled-back frame should not appear here
6428            self.assertTrue(guard.name != "nested_fn.__closure__[0].cell_contents")
6429
6430    def test_call_parent_non_class_methods_from_child(self):
6431        class A:
6432            a = 4
6433
6434            def add(self, x):
6435                return x + 10
6436
6437            def mul(self, x):
6438                return x * 0.1
6439
6440        class B(A):
6441            coeff = 4
6442
6443            def add(self, x):
6444                return x + 20
6445
6446            @classmethod
6447            def cube(cls, x):
6448                return cls.coeff * x * x * x
6449
6450            def mul(self, x):
6451                return super().mul(x) * x * 0.2
6452
6453        class C(B):
6454            def add(self, x):
6455                b = super().cube(x)
6456                c = A.add(self, x)
6457                d = B.mul(self, x)
6458                e = super(B, self).add(x)
6459                f = super().a * x
6460                return b + c + d + e + f
6461
6462        x = torch.rand(4)
6463        fn = C().add
6464        ref = fn(x)
6465        cnt = torch._dynamo.testing.CompileCounter()
6466        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
6467        res = opt_fn(x)
6468        self.assertTrue(same(ref, res))
6469        self.assertEqual(cnt.frame_count, 1)
6470
6471        # Check recompilation
6472        A.a = 5
6473        ref = fn(x)
6474        res = opt_fn(x)
6475        self.assertTrue(same(ref, res))
6476        # Ensure that super guard checks are working as expected
6477        res = opt_fn(x)
6478        self.assertEqual(cnt.frame_count, 2)
6479
6480    def test_builder_for_class_with_metaclass(self):
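        # A class created with a custom metaclass should still be handled
        # correctly when used in an isinstance() check inside the compiled
        # function.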
6481        class ExampleMeta(type):
6482            pass
6483
6484        class MyClass(metaclass=ExampleMeta):
6485            pass
6486
6487        def fn(x, y):
6488            if isinstance(y, MyClass):
6489                return x + 1
6490            else:
6491                return x - 1
6492
6493        x = torch.rand([4, 4])
6494        y = MyClass()
6495        ref = fn(x, y)
6496        opt_fn = torch._dynamo.optimize("eager")(fn)
6497        res = opt_fn(x, y)
6498        self.assertTrue(same(ref, res))
6499
6500    def test_tuple_from_tuple_iter(self):
6501        def inner_fn(*args):
6502            acc = torch.ones(10, 10)
6503            for arg in args:
6504                acc.add_(arg)
6505
6506            return acc
6507
6508        @torch._dynamo.optimize("eager")
6509        def fn(inputs, params):
6510            y = tuple(inputs) + tuple(params)
6511            return inner_fn(*y)
6512
6513        inputs = [torch.randn(10, 10) for _ in range(3)]
6514
6515        fn(inputs, iter(tuple(inputs)))
6516
6517        def fn(params):
6518            y = tuple(params)
6519            return inner_fn(*y)
6520
6521        opt_fn = torch._dynamo.optimize("eager")(fn)
6522        inputs = [torch.randn(10, 10) for _ in range(3)]
6523        self.assertTrue(same(fn(iter(tuple(inputs))), opt_fn(iter(tuple(inputs)))))
6524
6525        # Force recompilation
6526        inputs = [torch.randn(10, 10) for _ in range(4)]
6527        self.assertTrue(same(fn(iter(tuple(inputs))), opt_fn(iter(tuple(inputs)))))
6528
6529    def test_torch_package_working_with_trace(self):
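        # Saving a module with torch.package and compiling the reloaded copy
        # should work end to end.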
6530        # from torch._dynamo.test_case import run_tests
6531
6532        inputs = [torch.randn([2, 2]), torch.randn([2, 2])]
6533
6534        optimized_model = torch._dynamo.optimize(backend="eager")(
6535            MyPickledModule(torch.randn([2, 2]))
6536        )
6537        from torch import package
6538
6539        path = "/tmp/MyPickledModule.pt"
6540        package_name = "MyPickledModule"
6541        resource_name = "MyPickledModule.pkl"
6542
6543        model = MyPickledModule(torch.randn([2, 2]))
6544
6545        with package.PackageExporter(path) as exp:
6546            exp.extern("**")
6547            exp.save_pickle(package_name, resource_name, model)
6548
6549        imp = package.PackageImporter(path)
6550        loaded_model = imp.load_pickle(package_name, resource_name)
6551
6552        optimized_loaded_model = torch._dynamo.optimize("eager")(loaded_model)(*inputs)
6553
6554    def test_shape_and_tuple_equality(self):
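        # Comparing x.size() against a plain tuple must compile without a
        # graph break (nopython=True).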
6555        def fn(x, y, t):
6556            z = x * y
6557            if x.size() == t:
6558                return z.cos()
6559            return z.sin()
6560
6561        torch._dynamo.optimize("eager", nopython=True)(fn)(
6562            torch.randn([4, 4]), torch.randn([4, 4]), (4, 4)
6563        )
6564
6565    def test_int_list(self):
6566        # if assume_static_by_default == True: the int list is specialized
6567        # otherwise: the int list is unspecialized
6568        def fn(x, y):
6569            return torch.sin(x + y[1] % 2)
6570
6571        x = torch.randn(6)
6572        cnt = torch._dynamo.testing.CompileCounter()
6573        opt_fn = torch._dynamo.optimize(cnt)(fn)
6574        for i in range(10, 25, 3):
6575            y = [i, i + 1, i + 2]
6576            ref = fn(x, y)
6577            res = opt_fn(x, y)
6578            self.assertTrue(same(ref, res))
6579        if torch._dynamo.config.assume_static_by_default:
6580            if torch._dynamo.config.automatic_dynamic_shapes:
6581                self.assertExpectedInline(cnt.frame_count, """2""")
6582            else:
6583                self.assertExpectedInline(cnt.frame_count, """5""")
6584        else:
6585            self.assertExpectedInline(cnt.frame_count, """1""")
6586
6587    def test_patched_builtin_functions(self):
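        # Patching builtins.isinstance should be reflected in the compiled
        # function, and restoring the original should trigger a recompile.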
6588        import builtins
6589
6590        # Cache the original builtin function ids
6591        torch._dynamo.trace_rules._builtin_function_ids()
6592
6593        class MyClass:
6594            pass
6595
6596        builtin_isinstance = builtins.isinstance
6597
6598        def patched_isinstance(obj, classinfo) -> bool:
6599            if builtin_isinstance(obj, MyClass):
6600                return False
6601            else:
6602                return builtin_isinstance(obj, classinfo)
6603
6604        def fn(x, y):
6605            if isinstance(y, MyClass):
6606                return x + 1
6607            else:
6608                return x - 1
6609
6610        x = torch.ones(2, 3)
6611        y = MyClass()
6612
6613        try:
6614            ref = fn(x, y)
6615            # Monkey patch builtin function
6616            builtins.isinstance = patched_isinstance
6617            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
6618            res = opt_fn(x, y)
6619            self.assertTrue(same(ref, x + 1))
6620            self.assertTrue(same(res, x - 1))
6621        finally:
6622            builtins.isinstance = builtin_isinstance
6623
6624        # check recompilation because builtins.isinstance is now unpatched
6625        opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
6626        res = opt_fn(x, y)
6627        self.assertTrue(same(res, x + 1))
6628
6629    # specifically test for tensor.attribute -> torch.something()
6630    def test_real_imag_tensor_attribute(self):
6631        def fn(x, y):
6632            a = x.real
6633            b = x.imag
6634            return torch.mul(torch.add(a, y), b)
6635
6636        x_real = torch.rand((4, 4))
6637        x_imag = torch.rand((4, 4))
6638        x = torch.complex(x_real, x_imag)
6639        y = torch.rand((4, 4))
6640
6641        ref = fn(x, y)
6642        opt_fn = torch._dynamo.optimize("eager")(fn)
6643        res = opt_fn(x, y)
6644        self.assertTrue(same(ref, res))
6645
6646    def test_cast(self):
6647        from typing import cast
6648
6649        def fn(x):
6650            return cast(torch.Tensor, torch.add(x, 1.0))
6651
6652        opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
6653
6654        ref = fn(torch.ones(2, 2))
6655        res = opt_fn(torch.ones(2, 2))
6656
6657        self.assertTrue(same(ref, res))
6658
6659    def test_T_tensor_attribute(self):
6660        def fn(x, y):
6661            a = x.T
6662            return torch.add(a, y)
6663
6664        x = torch.rand((4, 4))
6665        y = torch.rand((4, 4))
6666
6667        ref = fn(x, y)
6668        opt_fn = torch._dynamo.optimize("eager")(fn)
6669        res = opt_fn(x, y)
6670        self.assertTrue(same(ref, res))
6671
6672    def test_recursive_tensor_attribute(self):
6673        def fn(x, y):
6674            a = x.real.T
6675            b = x.imag
6676            return torch.mul(torch.add(a, y), b)
6677
6678        x_real = torch.rand((4, 4))
6679        x_imag = torch.rand((4, 4))
6680        x = torch.complex(x_real, x_imag)
6681        y = torch.rand((4, 4))
6682
6683        ref = fn(x, y)
6684        opt_fn = torch._dynamo.optimize("eager")(fn)
6685        res = opt_fn(x, y)
6686        self.assertTrue(same(ref, res))
6687
6688    def test_assigning_function_to_object_attribute(self):
6689        # user-defined functions assigned as an object's attributes are not converted to bound methods
6690        def my_add(*args):
6691            a, b = args
6692            return a + b
6693
6694        class MyClass:
6695            def __init__(self, func):
6696                self.add = func
6697
6698        obj = MyClass(my_add)
6699
6700        def fn(x):
6701            return obj.add(x, 2)
6702
6703        x = torch.rand(2, 3)
6704        ref = fn(x)
6705        opt_fn = torch.compile(backend="eager")(fn)
6706        res = opt_fn(x)
6707        self.assertTrue(same(ref, res))
6708
6709    def test_assigning_function_to_class_attribute(self):
6710        # user-defined functions assigned as a class's attributes are converted to bound methods
6711        def my_add(*args):
6712            obj, a, b = args
6713            return obj.x + a + b
6714
6715        class MyClass:
6716            add = my_add
6717
6718            def __init__(self, x):
6719                self.x = x
6720
6721        obj = MyClass(0.5)
6722
6723        def fn(x):
6724            return obj.add(x, 2)
6725
6726        x = torch.rand(2, 3)
6727        ref = fn(x)
6728        opt_fn = torch.compile(backend="eager")(fn)
6729        res = opt_fn(x)
6730        self.assertTrue(same(ref, res))
6731
6732    def test_tagging_tensors_simple(self):
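        # Custom attributes attached to input tensors should show up in the
        # exported graph's placeholder meta["tensor_dict"].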
6733        def foo(x, y):
6734            return x * y, x, y
6735
6736        a = torch.randn([3, 3])
6737        a.tag = "a"
6738        a.frog = "ribbity ribbit"
6739        b = torch.randn([3, 3])
6740        b.tag = "b"
6741        b.frog = "ribbit"
6742
6743        exported = torch._dynamo.export(foo)(a, b)
6744        out_graph = exported[0]
6745
6746        nodes = list(out_graph.graph.nodes)
6747        placeholders = [node for node in nodes if node.op == "placeholder"]
6748        all_tags = []
6749        all_frogs = []
6750        for placeholder in placeholders:
6751            if "tensor_dict" in placeholder.meta:
6752                all_tags.append(placeholder.meta["tensor_dict"]["tag"])
6753                all_frogs.append(placeholder.meta["tensor_dict"]["frog"])
6754
6755        self.assertEqual(all_tags, ["a", "b"])
6756        self.assertEqual(all_frogs, ["ribbity ribbit", "ribbit"])
6757
6758    def test_tagging_tensors_mix_used_unused_structure(self):
6759        def pre_attention_state_ops(input, mems, state):
6760            lc_key = state[0]
6761            lc_val = state[1]
6762            bar = []
6763            for i in range(0, 4):
6764                bar2 = []
6765                for j in range(0, 3):
6766                    bar2.append(
6767                        lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
6768                    )
6769                bar.append(bar2)
6770
6771            return bar
6772
6773        mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
6774        state = [
6775            torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
6776            torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
6777        ]
6778        i = torch.tensor(
6779            [
6780                [0.0313, -0.1487, -0.3846, -0.5321],
6781                [-1.7073, 1.3331, -0.0890, -1.4935],
6782                [-0.8314, -0.1862, -0.5935, 1.5232],
6783            ]
6784        )
6785
6786        mems.tag = "MEMS"
6787        i.tag = "FOO"
6788        state[0].tag = "STATE_0"
6789        state[1].tag = "HMMM"
6790
6791        exported = torch._dynamo.export(pre_attention_state_ops)(i, mems, state)
6792        out_graph = exported[0]
6793
6794        nodes = list(out_graph.graph.nodes)
6795        placeholders = [node for node in nodes if node.op == "placeholder"]
6796        all_tags = []
6797        for placeholder in placeholders:
6798            if "tensor_dict" in placeholder.meta:
6799                all_tags.append(placeholder.meta["tensor_dict"]["tag"])
6800
6801        self.assertEqual(all_tags, ["STATE_0", "HMMM"])
6802
6803    def test_get_custom_tensor_attribute(self):
6804        def fn(x):
6805            return x.custom_attr * x
6806
6807        x = torch.rand((2, 2))
6808        x.custom_attr = 3.14
6809        ref = fn(x)
6810        opt_fn = torch._dynamo.optimize("eager")(fn)
6811        res = opt_fn(x)
6812        self.assertTrue(same(ref, res))
6813
6814    def test_set_custom_tensor_attribute(self):
6815        def fn(x):
6816            x.custom_attr = 3.14
6817            return x.custom_attr * x
6818
6819        x = torch.rand((2, 2))
6820        ref = fn(x)
6821        opt_fn = torch._dynamo.optimize("eager")(fn)
6822        res = opt_fn(x)
6823        self.assertTrue(same(ref, res))
6824
6825    def test_unhandled_exception_in_dynamo(self):
6826        # traceback.format_exc() approximates an unhandled exception
6827        def f(a):
6828            a += 1
6829            raise RuntimeError("smoge")
6830            return a
6831
6832        opt_fn = torch._dynamo.optimize("eager")(f)
6833        try:
6834            opt_fn(torch.ones(2))
6835        except RuntimeError as e:
6836            self.assertIn("smoge", traceback.format_exc())
6837
6838    def test_unhandled_exception_in_dynamo2(self):
6839        # segfaults in Python 3.11 if the shadow frame is freed improperly
6840        from torch.testing import make_tensor
6841
6842        def fn():
6843            # test that the errors are the same for dense and sparse versions
6844            def test1(*, is_sparse):
6845                # shapes must be compatible for matrix multiplication
6846                a = make_tensor((2, 3), dtype=torch.float32, device="cpu")
6847                if is_sparse:
6848                    a_sparse = a.to_sparse_csr()
6849                    return torch.addmm(a, a_sparse, a)
6850                else:
6851                    return torch.addmm(a, a, a)
6852
6853            try:
6854                test1(is_sparse=False)
6855            except RuntimeError as msg:
6856                try:
6857                    test1(is_sparse=True)
6858                except RuntimeError as msg2:
6859                    raise RuntimeError("smoge")
6860
6861        opt_fn = torch._dynamo.optimize("eager")(fn)
6862        try:
6863            opt_fn()
6864        except RuntimeError:
6865            self.assertIn("smoge", traceback.format_exc())
6866
6867    def test_variable_access_in_exception(self):
6868        def fn():
6869            x = torch.ones(1)
6870            try:
6871                raise RuntimeError("bad")
6872            except RuntimeError:
6873                x += 1
6874            return x
6875
6876        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
6877        self.assertEqual(opt_fn(), torch.tensor([2.0]))
6878
6879    def test_nested_sequential_with(self):
6880        def fn(x):
6881            with torch.set_grad_enabled(True):
6882                with torch.set_grad_enabled(False):
6883                    x = x + 1
6884                with torch.set_grad_enabled(True):
6885                    x = x + 1
6886                return x
6887
6888        opt_fn = torch._dynamo.optimize("eager")(fn)
6889        self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0]))
6890
6891    def test_nested_sequential_try(self):
6892        def fn(x):
6893            try:
6894                try:
6895                    x = x + 1
6896                except:
6897                    pass
6898                try:
6899                    try:
6900                        x = x + 1
6901                    except:
6902                        pass
6903                except:
6904                    pass
6905            except:
6906                pass
6907            return x
6908
6909        opt_fn = torch._dynamo.optimize("eager")(fn)
6910        self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0]))
6911
6912    def test_nested_sequential_try_with(self):
6913        def fn(x):
6914            with torch.set_grad_enabled(True):
6915                try:
6916                    x = x + 1
6917                except:
6918                    pass
6919                try:
6920                    with torch.set_grad_enabled(False):
6921                        x = x + 1
6922                except:
6923                    pass
6924            return x
6925
6926        opt_fn = torch._dynamo.optimize("eager")(fn)
6927        self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0]))
6928
6929    def test_nested_sequential_try_with_graph_break(self):
6930        def fn(x, n):
6931            with torch.set_grad_enabled(True):
6932                with torch.set_grad_enabled(False):
6933                    x = x + 1
6934                    torch._dynamo.graph_break()
6935                try:
6936                    with torch.set_grad_enabled(False):
6937                        x = x + 1
6938                        if n == 0:
6939                            torch._dynamo.graph_break()
6940                except:
6941                    pass
6942                with torch.set_grad_enabled(False):
6943                    x = x + 1
6944                    torch._dynamo.graph_break()
6945                x = x + 1
6946            return x
6947
6948        counter = CompileCounter()
6949        opt_fn = torch._dynamo.optimize(counter)(fn)
6950        self.assertEqual(opt_fn(torch.ones(1), 0), torch.tensor([5.0]))
6951        self.assertEqual(counter.frame_count, 1)
6952
6953        torch._dynamo.reset()
6954        counter = CompileCounter()
6955        opt_fn = torch._dynamo.optimize(counter)(fn)
6956        self.assertEqual(opt_fn(torch.ones(1), 1), torch.tensor([5.0]))
6957        self.assertEqual(counter.frame_count, 3)
6958
6959    def test_ordered_dict_alias_reconstruct(self):
6960        od = collections.OrderedDict
6961
6962        def fn():
6963            d1 = dict()
6964            d1["a"] = 1
6965            d2 = od(d1)
6966            d2["b"] = 2
6967            torch._dynamo.graph_break()
6968            if isinstance(d2, od):
6969                return d2["a"] + d2["b"]
6970            else:
6971                return 0
6972
6973        dis.dis(fn)
6974        self.assertEqual(torch._dynamo.optimize("eager")(fn)(), 3)
6975
6976    # NOTE this test can be removed once multiline errors are in Python.
6977    # See https://github.com/python/cpython/issues/106922
6978    @skipIfNotPy311
6979    def test_get_instruction_source_311(self):
6980        def f():
6981            # flake8: noqa
6982            # fmt: off
6983            # test binary ops
6984            a = ( b   )   +   c
6985            a = (a + b) // (c - d)
6986            a = b    \
6987         +\
6988               c  # test
6989            a = (
6990                (b  # test +
6991                    )  \
6992                # +
6993            << (
6994
6995                c  # test
6996                \
6997            )  # test
6998            )
6999
7000            # test slice
7001            a = bbb   [  ccc    ]
7002            b = bbbbb \
7003                [  ccc # test
7004
7005                 + ddd  \
7006
7007                ] # test
7008            a = bbb[ccc][ddd][eee]
7009
7010            # test nested and multiline function calls
7011            a = g(g(g(b)))
7012            a = g(h(
7013                g(b),
7014                c
7015            ))
7016
7017            # test chained function calls
7018            a = (g(x).y)(
7019                z
7020            )(1)(2)
7021
7022            # test unicode (match traceback behavior)
7023            a = ("������" +
7024                + "����") + b
7025
7026        from torch._dynamo.utils import get_instruction_source_311
7027
7028        if sys.version_info >= (3, 12):
7029            # Offsets changed in 3.12, e.g. due to removal of PRECALL inst
7030            offsets = (3, 11, 15, 19, 23, 29, 35, 44, 53, 65)
7031        else:
7032            offsets = (3, 11, 15, 19, 23, 29, 35, 46, 58, 74)
7033        insts = list(dis.get_instructions(f))
7034        # uncomment to determine offsets
7035        # print(*enumerate(insts), sep="\n")
7036        all_sources = "\n".join(
7037            get_instruction_source_311(f.__code__, insts[offset]) for offset in offsets
7038        )
7039        self.assertExpectedInline(
7040            all_sources,
7041            """\
7042            a = ( b   )   +   c
7043                ~~~~~~~~~~^~~~~
7044
7045            a = (a + b) // (c - d)
7046                ~~~~~~~~^^~~~~~~~~
7047
7048            a = b    \\
7049                ~~~~~~
7050         +\\
7051         ^~
7052               c  # test
7053               ~
7054
7055                (b  # test +
7056                ~~~~~~~~~~~~
7057                    )  \\
7058                    ~~~~
7059                # +
7060                ~~~
7061            << (
7062            ^^~~
7063
7064
7065                c  # test
7066                ~~~~~~~~~
7067                \\
7068                ~
7069            )  # test
7070            ~
7071
7072            a = bbb   [  ccc    ]
7073                ~~~~~~^^^^^^^^^^^
7074
7075            b = bbbbb \\
7076                ~~~~~~~
7077                [  ccc # test
7078                ^^^^^^^^^^^^^
7079
7080
7081                 + ddd  \\
7082                 ^^^^^^^^
7083
7084
7085                ] # test
7086                ^
7087
7088            a = bbb[ccc][ddd][eee]
7089                ~~~~~~~~^^^^^
7090
7091            a = g(g(g(b)))
7092                  ~^^^^^^
7093
7094            a = g(h(
7095                  ~^
7096                g(b),
7097                ^^^^^
7098                c
7099                ^
7100            ))
7101            ^
7102
7103            a = (g(x).y)(
7104                ~~~~~~~~~
7105                z
7106                ~
7107            )(1)(2)
7108            ~^^^
7109""",
7110        )
7111        # test unicode (since assertExpectedInline doesn't support unicode)
7112        op_offset = 74 if sys.version_info >= (3, 12) else 84
7113        self.assertEqual(
7114            get_instruction_source_311(f.__code__, insts[op_offset]),
7115            """\
7116            a = ("������" +
7117                ~~~~~~~~
7118                + "����") + b
7119                ~~~~~~~~^~~
7120""",
7121        )
7122
7123    def test_raise_guard_full_constraint(self):
7124        y = torch.randn([3, 3, 3])
7125
7126        def my_dyn_fn(x):
7127            if x.shape[0] == 3:
7128                return x.sin()
7129            return x.cos()
7130
7131        torch._dynamo.mark_dynamic(y, 0)
7132        with self.assertRaises(ConstraintViolationError):
7133            torch._dynamo.optimize("eager")(my_dyn_fn)(y)
7134
7135    # Translation validation changes the exception type, so don't run this test with it
7136    @torch.fx.experimental._config.patch(translation_validation=False)
7137    def test_mark_dynamic_with_ranges(self):
7138        y = torch.randn([8, 3, 3])
7139
7140        def my_dyn_fn(x):
7141            if x.shape[0] == 3:
7142                return x.sin()
7143            return x.cos()
7144
7145        torch._dynamo.mark_dynamic(y, 0, min=2, max=5)
7146        with self.assertRaises(ConstraintViolationError):
7147            torch._dynamo.optimize("eager")(my_dyn_fn)(y)
7148
7149    def test_mark_static(self):
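        # mark_static pins dim 0, so calling with a different length should
        # recompile rather than go dynamic (frame_count ends up at 2).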
7150        counter = CompileCounter()
7151
7152        def my_dyn_fn(x):
7153            return x.cos()
7154
7155        y = torch.randn([3])
7156        torch._dynamo.mark_static(y, 0)
7157        torch._dynamo.optimize(counter)(my_dyn_fn)(y)
7158
7159        z = torch.randn([4])
7160        torch._dynamo.optimize(counter)(my_dyn_fn)(z)
7161
7162        self.assertEqual(counter.frame_count, 2)
7163
7164    def test_no_raise_guard_partial_constraint(self):
7165        y = torch.randn([3, 3, 3])
7166
7167        def my_dyn_fn(x):
7168            if x.shape[0] > 3:
7169                return x.sin()
7170            return x.cos()
7171
7172        torch._dynamo.optimize("eager")(my_dyn_fn)(y)
7173        torch._dynamo.mark_dynamic(y, 0)
7174        torch._dynamo.reset()
7175        torch._dynamo.optimize("eager")(my_dyn_fn)(y)
7176
7177    def test_no_raise_guard_partial_constraint_across_break(self):
7178        y = torch.randn([3, 3, 3])
7179
7180        def my_dyn_fn(x, y):
7181            z = x * y
7182
7183            torch._dynamo.graph_break()
7184            if z.shape[0] > 2:
7185                return z.cos()
7186
7187            return x.cos()
7188
7189        torch._dynamo.optimize("eager")(my_dyn_fn)(y, y)
7190        torch._dynamo.mark_dynamic(y, 0)
7191        torch._dynamo.reset()
7192        torch._dynamo.optimize("eager")(my_dyn_fn)(y, y)
7193
7194    # Sadly, this does not throw - we do not propagate the constraint correctly across the graph break
7195    @unittest.expectedFailure
7196    def test_raise_guard_partial_constraint_across_break(self):
7197        y = torch.randn([3, 3, 3])
7198
7199        def my_dyn_fn(x, y):
7200            z = x * y
7201
7202            torch._dynamo.graph_break()
7203            if z.shape[0] == 3:
7204                return z.cos()
7205
7206            return x.cos()
7207
7208        torch._dynamo.optimize("eager")(my_dyn_fn)(y, y)
7209        torch._dynamo.mark_dynamic(y, 0)
7210        torch._dynamo.reset()
7211        with self.assertRaisesRegex(
7212            Exception,
7213        ):
7214            torch._dynamo.optimize("eager")(my_dyn_fn)(y, y)
7215
7216    def test_raise_guard_partial_constraint_no_graph_break(self):
7217        y = torch.randn([3, 3, 3])
7218
7219        def my_dyn_fn(x, y):
7220            z = x * y
7221
7222            if z.shape[0] == 3:
7223                return z.cos()
7224
7225            return x.cos()
7226
7227        torch._dynamo.mark_dynamic(y, 0)
7228        with self.assertRaises(ConstraintViolationError):
7229            torch._dynamo.optimize("eager")(my_dyn_fn)(y, y)
7230
7231    def test_cannot_trace_mark_dynamic(self):
7232        y = torch.randn([3, 3, 3])
7233
7234        def my_dyn_fn(x):
7235            torch._dynamo.mark_dynamic(x, 0)
7236            return x * x
7237
7238        with self.assertRaisesRegex(
7239            AssertionError, "Attempt to trace forbidden callable"
7240        ):
7241            torch._dynamo.optimize("eager")(my_dyn_fn)(y)
7242
7243    def test_cannot_trace_mark_dynamic_safe_unreached(self):
7244        y = torch.randn([3, 3, 3])
7245
7246        def my_dyn_fn(x):
7247            if x.shape[0] == 3:
7248                return x
7249            print("Running", torch._dynamo.mark_dynamic(x, 0))
7250            return x * x
7251
7252        torch._dynamo.optimize("eager")(my_dyn_fn)(y)
7253
7254    def test_anomaly_aot_autograd(self):
7255        def fail():
7256            raise AssertionError("fail")
7257
7258        @allow_in_graph
7259        def h(a):
7260            r = a.sum()
7261            # Trigger an exception in backwards
7262            r.register_hook(lambda x: fail())
7263            return r
7264
7265        @torch.compile(backend="aot_eager")
7266        def f(a):
7267            return h(a)
7268
7269        with warnings.catch_warnings(record=True) as w, self.assertRaises(
7270            torch._dynamo.exc.BackendCompilerFailed
7271        ):
7272            f(torch.randn(2, 2, requires_grad=True))
7273
7274        # Suppress unrelated pkg_resources warnings
7275        self.assertIn("forward call that caused the error", str(w[-1].message))
7276
7277    def test_py_guards_mark_dynamic(self):
7278        def my_dyn_fn(a):
7279            if a.shape[0] > 2:
7280                return a.cos()
7281            return a.sin()
7282
7283        counter = CompileCounter()
7284
7285        # Run with dynamic
7286        x0 = torch.randn([3, 3, 3])
7287        torch._dynamo.mark_dynamic(x0, 0)
7288        torch._dynamo.optimize(counter)(my_dyn_fn)(x0)
7289        self.assertEqual(counter.frame_count, 1)
7290
7291        # Run without dynamic, no recompile
7292        x = torch.randn([3, 3, 3])
7293        torch._dynamo.optimize(counter)(my_dyn_fn)(x)
7294        self.assertEqual(counter.frame_count, 1)
7295
7296        # Mark a new dim, 1, as dynamic
7297        x1 = torch.randn([3, 3, 3])
7298        torch._dynamo.mark_dynamic(x1, 1)
7299        torch._dynamo.optimize(counter)(my_dyn_fn)(x1)
7300        # Recompile triggered because we marked a new dim as dynamic
7301        self.assertEqual(counter.frame_count, 2)
7302
7303        # Reset
7304        torch._dynamo.reset()
7305        # Reset counter
7306        counter = CompileCounter()
7307
7308        # Run with dynamic 1
7309        torch._dynamo.optimize(counter)(my_dyn_fn)(x1)
7310        self.assertEqual(counter.frame_count, 1)
7311
7312        # Run with dynamic 0, not a subset of the previously marked dims, so recompile
7313        torch._dynamo.optimize(counter)(my_dyn_fn)(x0)
7314        self.assertEqual(counter.frame_count, 2)
7315
7316        # Run with dynamic 0, 1, 2, again not a subset, so recompile
7317        x012 = torch.randn([3, 3, 3])
7318        torch._dynamo.mark_dynamic(x012, 0)
7319        torch._dynamo.mark_dynamic(x012, 1)
7320        torch._dynamo.mark_dynamic(x012, 2)
7321        torch._dynamo.optimize(counter)(my_dyn_fn)(x012)
7322        self.assertEqual(counter.frame_count, 3)
7323
7324    def test_recompile_on_global_state_change(self):
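        # Flipping any piece of the global state read below (grad mode,
        # deterministic algorithms, cuBLAS TF32) should trigger a recompile on
        # first sight and then reuse the cached entries.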
7325        last_state = []
7326        cnt = 0
7327
7328        def my_compiler(gm, _):
7329            nonlocal cnt
7330            cnt += 1
7331            state = read_state()
7332
7333            def inner(*args):
7334                last_state[:] = state
7335                return gm(*args)
7336
7337            return inner
7338
7339        def read_state():
7340            return [
7341                torch.is_grad_enabled(),
7342                torch.are_deterministic_algorithms_enabled(),
7343                torch._C._get_cublas_allow_tf32(),
7344            ]
7345
7346        def write_state(state):
7347            torch.set_grad_enabled(state[0])
7348            torch.use_deterministic_algorithms(state[1])
7349            torch._C._set_cublas_allow_tf32(state[2])
7350
7351        @torch.compile(backend=my_compiler)
7352        def fn(x):
7353            return x + 1
7354
7355        initial_state = read_state()
7356        y = torch.randn(10)
7357        try:
7358            for round in range(3):
7359                for i in range(len(initial_state)):
7360                    new_state = [False] * len(initial_state)
7361                    new_state[i] = True
7362                    write_state(new_state)
7363                    assert read_state() == new_state
7364                    last_state.clear()
7365                    fn(y)
7366                    assert last_state == new_state
7367                    if round == 0:
7368                        assert cnt == i + 1
7369                    else:
7370                        assert cnt == len(initial_state)
7371        finally:
7372            write_state(initial_state)
7373
7374    def test_grad_state_mutated(self):
7375        prior = torch.is_grad_enabled()
7376        value = None
7377        cnt = CompileCounter()
7378
7379        @torch._dynamo.allow_in_graph
7380        def check_state():
7381            nonlocal value
7382            value = torch.is_grad_enabled()
7383
7384        @torch.compile(backend=cnt, fullgraph=True)
7385        def fn(x):
7386            check_state()
7387            torch.set_grad_enabled(False)
7388            return x + 1
7389
7390        try:
7391            torch.set_grad_enabled(True)
7392            fn(torch.randn(10))
7393            assert value is True
7394            assert torch.is_grad_enabled() is False
7395
7396            value = None
7397            torch.set_grad_enabled(True)
7398            fn(torch.randn(10))
7399            assert value is True
7400            assert torch.is_grad_enabled() is False
7401
7402            assert cnt.frame_count == 1
7403        finally:
7404            torch.set_grad_enabled(prior)
7405
7406    def test_deterministic_algorithms_mutated(self):
7407        prior = torch.are_deterministic_algorithms_enabled()
7408        prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
7409        value = None
7410        warn_only = None
7411        cnt = CompileCounter()
7412
7413        @torch._dynamo.allow_in_graph
7414        def check_state():
7415            nonlocal value
7416            nonlocal warn_only
7417            value = torch.are_deterministic_algorithms_enabled()
7418            warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
7419
7420        @torch.compile(backend=cnt, fullgraph=True)
7421        def fn(x):
7422            check_state()
7423            torch.use_deterministic_algorithms(False, warn_only=False)
7424            return x + 1
7425
7426        def run_fn():
7427            torch.use_deterministic_algorithms(True, warn_only=True)
7428            fn(torch.randn(10))
7429            assert value is True
7430            assert warn_only is True
7431            assert torch.are_deterministic_algorithms_enabled() is False
7432            assert torch.is_deterministic_algorithms_warn_only_enabled() is False
7433
7434        try:
7435            run_fn()
7436            value, warn_only = None, None
7437            run_fn()
7438            assert cnt.frame_count == 1
7439        finally:
7440            torch.use_deterministic_algorithms(prior, warn_only=prior_warn_only)
7441
7442    def test_torch_compile_ctx_on_forward_and_training_step(self):
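        # The dynamo_ctx of a compiled module can be reused to wrap individual
        # bound methods (forward and training_step) without error.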
7443        class MyModel(torch.nn.Module):
7444            def forward(self):
7445                ...
7446
7447            def training_step(self):
7448                self()
7449
7450        model = MyModel()
7451        compiled_model = torch.compile(model)
7452
7453        model.forward = compiled_model.dynamo_ctx(model.forward)
7454        model.training_step = compiled_model.dynamo_ctx(model.training_step)
7455
7456        model.training_step()
7457
7458    def test_torch_guards_stack_frame_register_inlining(self):
7459        x = torch.tensor([0.5, 0.5])
7460        y = torch.tensor([0.75, 0.75, 0.75, 0.75])
7461        z = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25])
7462
7463        def uwu_inline_me(x, y, z):
7464            r = torch.cat((x, x)) + y
7465            r2 = torch.cat((y, y)) + z
7466            return r, r2
7467
7468        def fn(x, y, z):
7469            r, r2 = uwu_inline_me(x, y, z)
7470            return torch.mul(r, r), torch.mul(r2, r2)
7471
7472        seen_frames = []
7473        import contextlib
7474
7475        @contextlib.contextmanager
7476        def global_context_capture_fn(frame_summary):
7477            if frame_summary is not None:
7478                seen_frames.append(frame_summary)
7479            yield
7480
7481        with mock.patch(
7482            "torch._guards.TracingContext.current_frame",
7483            side_effect=global_context_capture_fn,
7484        ):
7485            torch._dynamo.optimize("eager")(fn)(x, y, z)
7486
7487        self.assertEqual(len(seen_frames), 1)
7488        self.assertEqual(seen_frames[0].name, "fn")
7489        self.assertEqual(seen_frames[0].line, "r, r2 = uwu_inline_me(x, y, z)")
7490
7491    def test_torch_guards_stack_frame_register_inlining_deep(self):
7492        x = torch.tensor([0.5, 0.5])
7493        y = torch.tensor([0.75, 0.75, 0.75, 0.75])
7494        z = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25])
7495
7496        def uwu_inline_me_deep(x, y):
7497            return torch.cat((x, x)) + y
7498
7499        def uwu_inline_me(x, y, z):
7500            r = uwu_inline_me_deep(x, y)
7501            r2 = uwu_inline_me_deep(y, z)
7502            return r, r2
7503
7504        def fn(x, y, z):
7505            r, r2 = uwu_inline_me(x, y, z)
7506            return torch.mul(r, r), torch.mul(r2, r2)
7507
7508        seen_frames = []
7509        import contextlib
7510
7511        @contextlib.contextmanager
7512        def global_context_capture_fn(frame_summary):
7513            if frame_summary is not None:
7514                seen_frames.append(frame_summary)
7515            yield
7516
7517        with mock.patch(
7518            "torch._guards.TracingContext.current_frame",
7519            side_effect=global_context_capture_fn,
7520        ):
7521            torch._dynamo.optimize("eager")(fn)(x, y, z)
7522
7523        self.assertEqual(len(seen_frames), 3)
7524        self.assertEqual(seen_frames[0].name, "fn")
7525        self.assertEqual(seen_frames[1].name, "uwu_inline_me")
7526        self.assertEqual(seen_frames[2].line, "r2 = uwu_inline_me_deep(y, z)")
7527
7528    def test_error_on_recompile(self):
7529        @torch._dynamo.optimize("eager")
7530        def fn(a, b):
7531            return a + b
7532
7533        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
7534            with self.assertRaises(torch._dynamo.exc.RecompileError):
7535                fn(torch.rand(2, 3), torch.rand(2, 3))
7536                fn(torch.rand(2, 3), (1, 2, 3))
7537
7538    @expectedFailureDynamic
7539    @torch._dynamo.config.patch(automatic_dynamic_shapes=False)
7540    def test_compile_profiler(self):
7541        class Model(torch.nn.Module):
7542            def forward(self, input):
7543                return input + input
7544
7545        model = Model()
7546        prof = CompileProfiler()
7547        compiled = torch.compile(model, backend=prof)
7548        base_checker = (
7549            lambda: FileCheck()
7550            .check("Torchdynamo Profiler Report")
7551            .check("Graph Breaks")
7552            .check("No graph breaks detected.")
7553            .check("Recompilation")
7554        )
7555        input = torch.rand((2, 3, 4))
7556        _ = compiled(input)
7557        base_checker().check("No recompilation detected.").run(prof.report())
7558
7559        new_shape_input = torch.rand((3, 3, 4))
7560        _ = compiled(new_shape_input)
7561
7562        # Not an exhaustive test of dynamic shapes behavior, but a sanity check
7563        if torch._dynamo.config.assume_static_by_default:
7564            base_checker().check("Recompile Reasons").check("'forward'").check(
7565                "cache_size_limit to 1"
7566            ).run(prof.report())
7567        else:
7568            base_checker().check("No recompilation detected.").run(prof.report())
7569
7570        new_shape_input = torch.rand((4, 3, 4))
7571        _ = compiled(new_shape_input)
7572
7573        base_checker().check("Recompile Reasons").check("'forward'").check(
7574            "tensor 'L['input']' size mismatch at index 0. expected 2, actual 3"
7575        ).check(
7576            "tensor 'L['input']' size mismatch at index 0. expected 3, actual 4"
7577        ).run(
7578            prof.report()
7579        )
7580
7581    def test_guards_strip_function_call(self):
7582        from torch._dynamo.guards import strip_function_call
7583
7584        test_case = [
7585            ("___odict_getitem(a, 1)", "a"),
7586            ("a.layers[slice(2)][0]._xyz", "a"),
7587            ("getattr(a.layers[slice(2)][0]._abc, '0')", "a"),
7588            ("getattr(getattr(a.x[3], '0'), '3')", "a"),
7589            ("a.layers[slice(None, -1, None)][0]._xyz", "a"),
7590            ("a.layers[func('offset', -1, None)][0]._xyz", "a"),
7591        ]
7592        # strip_function_call should extract the object from the string.
7593        for name, expect_obj in test_case:
7594            self.assertEqual(strip_function_call(name), expect_obj)
7595
7596    def test_int_neg(self):
7597        def int_neg(a, b):
7598            x = a.shape[0]
7599            y = b.shape[0]
7600            return -x * -y * a * b
7601
7602        torch._dynamo.testing.standard_test(self, int_neg, 2)
7603
7604    def test_hash_getitem_slice(self):
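        # GetItemSource keys with slice indices must be hashable; sources built
        # from equal slices should compare and hash equal, different slices
        # should not.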
7605        s = GetItemSource(LocalSource("foo"), slice(None, -1, None))
7606        s2 = GetItemSource(LocalSource("foo"), slice(None, -1, None))
7607        s3 = GetItemSource(LocalSource("foo"), slice(None, -1, 2))
7608        some_set = set()
7609
7610        self.assertTrue(s not in some_set)
7611        self.assertTrue(s2 not in some_set)
7612        self.assertTrue(s3 not in some_set)
7613
7614        some_set.add(s)
7615
7616        self.assertTrue(s in some_set)
7617        # s and s2 should hash the  same
7618        self.assertTrue(s2 in some_set)
7619        # s3 should be different
7620        self.assertTrue(s3 not in some_set)
7621
7622        self.assertTrue(s == s2)
7623        self.assertTrue(s != s3)
7624
7625    def test_inline_dict_function(self):
7626        def _result_type_dict(dtype):
7627            return {bool: torch.float32}[dtype]
7628
7629        @torch.compile
7630        def f():
7631            return torch.ones(3, dtype=_result_type_dict(bool))
7632
7633        self.assertEqual(f(), torch.ones(3, dtype=torch.float32))
7634
7635    def test_inline_dict_function_passed_as_arg(self):
7636        @torch.compile
7637        def fn(d, x, y):
7638            if d[x] is torch.float32:
7639                return y.cos()
7640            else:
7641                return y.sin()
7642
7643        dd = {bool: torch.float32, int: torch.int64}
7644        self.assertEqual(fn(dd, bool, torch.ones(4)), torch.ones(4).cos())
7645        self.assertEqual(fn(dd, int, torch.ones(4)), torch.ones(4).sin())
7646
7647    def test_add_sizes(self):
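        # Adding two torch.Size values concatenates them; the compiled result
        # should still be a torch.Size, matching eager.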
7648        def func(x):
7649            y = x.size()
7650            return y + y
7651
7652        eager_out = func(torch.ones(10, 10, 3))
7653        compile_out = torch._dynamo.optimize("eager")(func)(torch.ones(10, 10, 3))
7654        self.assertTrue(isinstance(compile_out, torch.Size))
7655        self.assertEqual(eager_out, compile_out)
7656
7657    @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPUs")
7658    def test_cuda_set_device(self):
7659        def fn():
7660            a = torch.ones(2, device="cuda")
7661            torch.cuda.set_device(1)
7662            return a + 1
7663
7664        with torch.cuda.device(0):
7665            counter = CompileCounter()
7666            opt_fn = torch._dynamo.optimize(counter)(fn)
7667            res = opt_fn()
7668            self.assertEqual(res.device.type, "cuda")
7669            self.assertEqual(res.device.index, 0)
7670            self.assertEqual(counter.frame_count, 2)
7671
7672    def test_nested_function_resuming_with_correct_globals(self):
7673        # https://github.com/pytorch/pytorch/issues/99665
7674        try:
7675            from .utils import outer_func
7676        except ImportError:
7677            from utils import outer_func
7678
7679        def gn(x, y):
7680            return x + y
7681
7682        def fn(x, y):
7683            return outer_func(gn)(x, y)
7684
7685        x = torch.rand([3])
7686        y = torch.rand([3])
7687        opt_fn = torch.compile(backend="eager")(fn)
7688        ref = fn(x, y)
7689        res = opt_fn(x, y)
7690        self.assertTrue(same(ref, res))
7691
7692    @dataclasses.dataclass
7693    class CSETestCase:
7694        expr: str
7695        preface: typing.List[str] = dataclasses.field(default_factory=list)
7696        expected: typing.Optional[str] = None
7697        expected_py38: typing.Optional[str] = None
7698
7699    def _is_py38(self) -> bool:
7700        return sys.version_info[:2] <= (3, 8)
7701
7702    def _has_ast_unparse(self) -> bool:
7703        from torch._dynamo.guards import HAS_UNPARSE_FUNCTIONS
7704
7705        return HAS_UNPARSE_FUNCTIONS
7706
7707    def test_guards_cse_pass_single(self):
7708        if not self._has_ast_unparse():
7709            if IS_FBCODE:
7710                raise RuntimeError("Needs astunparse or Python-3.9+")
7711            raise unittest.SkipTest("Needs astunparse or Python-3.9+")
7712        from torch._dynamo.guards import PyExprCSEPass
7713
7714        testcase = self.CSETestCase
7715        testcases = [
7716            # Nothing gets CSE-d, since the only repeated sub-expression is 'x'.
7717            # i.e. not a node type we are interested on.
7718            testcase(expr="x[0].a"),
7719            testcase(expr="x[1].a"),
7720            testcase(expr="x[2].a"),
7721            # 'a.b.c' gets CSE-d, since it's a sub-expression used more than 'PyExprCSEPass.USE_THRESHOLD' times.
7722            testcase(
7723                expr="a.b.c[0].d.e",
7724                preface=["_var0 = a.b", "_var1 = _var0.c"],
7725                expected="_var1[0].d.e",
7726            ),
7727            testcase(expr="a.b.c[1].d.e", expected="_var1[1].d.e"),
7728            testcase(expr="a.b.c[2].d.e", expected="_var1[2].d.e"),
7729            # 'm.n[0]' gets CSE-d, since it is a sub-expression used more than 'PyExprCSEPass.USE_THRESHOLD' times.
7730            testcase(
7731                expr="f(m.n[0], '0').x.y.z",
7732                preface=["_var2 = m.n", "_var3 = _var2[0]"],
7733                expected="f(_var3, '0').x.y.z",
7734            ),
7735            testcase(expr="f(m.n[0], '1').x.y.z", expected="f(_var3, '1').x.y.z"),
7736            testcase(expr="f(m.n[0], '2').x.y.z", expected="f(_var3, '2').x.y.z"),
7737            # The whole expressiong gets CSE-d, as well as all of its sub-expressions.
7738            testcase(
7739                expr="self.g(a, b).k",
7740                preface=["_var4 = self.g", "_var5 = _var4(a, b)", "_var6 = _var5.k"],
7741                expected="_var6",
7742            ),
7743            testcase(expr="self.g(a, b).k", expected="_var6"),
7744            testcase(expr="self.g(a, b).k", expected="_var6"),
7745        ]
7746        csepass = PyExprCSEPass()
7747        csepass.count([t.expr for t in testcases])
7748
7749        for t in testcases:
7750            preface, expr = csepass.replace(t.expr)
7751            self.assertEqual(preface, t.preface)
7752            expected = t.expected if t.expected is not None else t.expr
7753            self.assertEqual(expr, expected)
7754
7755    def test_guards_cse_pass_multiple(self):
7756        if not self._has_ast_unparse():
7757            raise unittest.SkipTest("Needs astunparse or Python-3.9+")
7758        from torch._dynamo.guards import PyExprCSEPass
7759
7760        testcase = self.CSETestCase
7761        testcases = [
7762            testcase(
7763                expr="x[0].a < x[1].a * (3 - x[2].a)",
7764                expected="x[0].a < x[1].a * (3 - x[2].a)",
7765                expected_py38="(x[0].a < (x[1].a * (3 - x[2].a)))",
7766            ),
7767            testcase(
7768                expr="a.b.c[0].d.e + a.b.c[1].d.e * a.b.c[2].d.e > 0",
7769                preface=["_var0 = a.b", "_var1 = _var0.c"],
7770                expected="_var1[0].d.e + _var1[1].d.e * _var1[2].d.e > 0",
7771                expected_py38="((_var1[0].d.e + (_var1[1].d.e * _var1[2].d.e)) > 0)",
7772            ),
7773            testcase(
7774                expr="f(m.n[0], '0').x.y.z * f(m.n[0], '1').x.y.z * f(m.n[0], '2').x.y.z < 512",
7775                preface=["_var2 = m.n", "_var3 = _var2[0]"],
7776                expected="f(_var3, '0').x.y.z * f(_var3, '1').x.y.z * f(_var3, '2').x.y.z < 512",
7777                expected_py38="(((f(_var3, '0').x.y.z * f(_var3, '1').x.y.z) * f(_var3, '2').x.y.z) < 512)",
7778            ),
7779            testcase(
7780                expr="self.g(a, b).k + (1 - self.g(a, b).k) <= m[0].a + self.g(a, b).k",
7781                preface=["_var4 = self.g", "_var5 = _var4(a, b)", "_var6 = _var5.k"],
7782                expected="_var6 + (1 - _var6) <= m[0].a + _var6",
7783                expected_py38="((_var6 + (1 - _var6)) <= (m[0].a + _var6))",
7784            ),
7785        ]
7786
7787        csepass = PyExprCSEPass()
7788        csepass.count([t.expr for t in testcases])
7789
7790        for t in testcases:
7791            preface, expr = csepass.replace(t.expr)
7792            self.assertEqual(preface, t.preface)
7793            expected = t.expected_py38 if self._is_py38() else t.expected
7794            expected = expected if expected is not None else t.expr
7795            self.assertEqual(expr, expected)
7796
7797    def test_guard_function_builder_with_cse(self):
7798        from torch._dynamo.guards import build_guard_function
7799
7800        exprs = [
7801            "x[0].a < x[1].a * (3 - x[2].a)",
7802            "a.b.c[0].d.e + a.b.c[1].d.e * a.b.c[2].d.e > 0",
7803            "f(m.n[0], '0').x.y.z * f(m.n[0], '1').x.y.z * f(m.n[0], '2').x.y.z < 512",
7804            "self.g(a, b).k + (1 - self.g(a, b).k) <= m[0].a + self.g(a, b).k",
7805        ]
7806
7807        _, pycode = build_guard_function(exprs, "")
7808        expected = """\
7809def ___make_guard_fn():
7810    def guard(L):
7811        if not (x[0].a < x[1].a * (3 - x[2].a)):
7812            return False
7813        _var0 = a.b
7814        _var1 = _var0.c
7815        if not (_var1[0].d.e + _var1[1].d.e * _var1[2].d.e > 0):
7816            return False
7817        _var2 = m.n
7818        _var3 = _var2[0]
7819        if not (f(_var3, '0').x.y.z * f(_var3, '1').x.y.z * f(_var3, '2').x.y.z < 512):
7820            return False
7821        _var4 = self.g
7822        _var5 = _var4(a, b)
7823        _var6 = _var5.k
7824        if not (_var6 + (1 - _var6) <= m[0].a + _var6):
7825            return False
7826        return True
7827    return guard
7828"""
7829        expected_38 = """\
7830def ___make_guard_fn():
7831    def guard(L):
7832        if not ((x[0].a < (x[1].a * (3 - x[2].a)))):
7833            return False
7834        _var0 = a.b
7835        _var1 = _var0.c
7836        if not (((_var1[0].d.e + (_var1[1].d.e * _var1[2].d.e)) > 0)):
7837            return False
7838        _var2 = m.n
7839        _var3 = _var2[0]
7840        if not ((((f(_var3, '0').x.y.z * f(_var3, '1').x.y.z) * f(_var3, '2').x.y.z) < 512)):
7841            return False
7842        _var4 = self.g
7843        _var5 = _var4(a, b)
7844        _var6 = _var5.k
7845        if not (((_var6 + (1 - _var6)) <= (m[0].a + _var6))):
7846            return False
7847        return True
7848    return guard
7849"""
7850        expected_38_no_astunparse = """\
7851def ___make_guard_fn():
7852    def guard(L):
7853        if not (x[0].a < x[1].a * (3 - x[2].a)):
7854            return False
7855        if not (a.b.c[0].d.e + a.b.c[1].d.e * a.b.c[2].d.e > 0):
7856            return False
7857        if not (f(m.n[0], '0').x.y.z * f(m.n[0], '1').x.y.z * f(m.n[0], '2').x.y.z < 512):
7858            return False
7859        if not (self.g(a, b).k + (1 - self.g(a, b).k) <= m[0].a + self.g(a, b).k):
7860            return False
7861        return True
7862    return guard
7863"""
7864
7865        if self._is_py38():
7866            expected = (
7867                expected_38 if self._has_ast_unparse() else expected_38_no_astunparse
7868            )
7869        self.assertEqual(expected, pycode)
7870
7871    def test_dynamo_compiling_fake_tensor_to_vararg_int(self):
7872        class MyModule(torch.nn.Module):
7873            def __init__(self):
7874                super().__init__()
7875
7876            def forward(self, x):
7877                # use a numpy int so it's wrapped as a fake tensor in dynamo
7878                shape = np.int_(16)
7879                # test passing shape as a fake tensor, where the parameter type is
7880                # Sequence[Union[_int, SymInt]]
7881                return x.reshape(shape)
7882
7883        x = torch.rand([4, 4])
7884        model = MyModule()
7885        orig_out = model(x)
7886        opt_model = torch._dynamo.optimize("eager")(MyModule())
7887        opt_out = opt_model(x)
7888        self.assertTrue(same(orig_out, opt_out))
7889
7890    def test_scalar_tensor_is_equivalent_to_symint_argument(self):
7891        class GumbelTopKSampler(torch.nn.Module):
7892            def __init__(self, T, k):
7893                super().__init__()
7894                self.T = torch.nn.Parameter(
7895                    torch.tensor(T, dtype=torch.float32), requires_grad=False
7896                )
7897                self.k = torch.nn.Parameter(
7898                    torch.tensor(k, dtype=torch.int32), requires_grad=False
7899                )
7900
7901            def sample_discrete(self, logits):
7902                threshold = torch.topk(logits, self.k, sorted=True)[0][..., -1]
7903                samples = torch.ge(logits.squeeze(1), threshold).float()
7904                return samples
7905
7906            def forward(self, logits):
7907                dsamples = self.sample_discrete(logits)
7908                return dsamples
7909
7910        x = torch.rand([4, 4, 4, 4])
7911        m = GumbelTopKSampler(T=4, k=4)
7912        orig_out = m(x)
7913        opt_m = torch.compile(backend="eager")(m)
7914        opt_out = opt_m(x)
7915        self.assertTrue(same(orig_out, opt_out))
7916
7917    def test_scalar_tensor_is_equivalent_to_symint_list_argument(self):
7918        class Jitter(torch.nn.Module):
7919            def __init__(self, jitter_val):
7920                super().__init__()
7921                self.jitter_val = jitter_val
7922
7923            def roll_tensor(self, input):
7924                h_shift = self.jitter_val - 1
7925                w_shift = self.jitter_val + 1
7926                return torch.roll(
7927                    torch.roll(input, shifts=h_shift, dims=2), shifts=w_shift, dims=3
7928                )
7929
7930            def forward(self, input):
7931                return self.roll_tensor(input)
7932
7933        x = torch.rand([4, 4, 4, 4])
7934        m = Jitter(jitter_val=4)
7935        orig_out = m(x)
7936        opt_m = torch.compile(backend="eager")(m)
7937        opt_out = opt_m(x)
7938        self.assertTrue(same(orig_out, opt_out))
7939
7940    def test_scalar_tensor_is_equivalent_to_int_list_argument(self):
7941        class MyModel(torch.nn.Module):
7942            def forward(self, input):
7943                permute = torch.tensor([0, 2, 1])
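                # iterating the 1-D tensor yields 0-dim tensors, which permute
                # accepts in place of plain ints when unpacked with *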
7944                x = input.permute(*permute)
7945                return x
7946
7947        x = torch.randn(2, 3, 4)
7948        m = MyModel()
7949        orig_out = m(x)
7950        opt_m = torch.compile(backend="eager")(m)
7951        opt_out = opt_m(x)
7952        self.assertTrue(same(orig_out, opt_out))
7953
7954    def test_torch_variable_hasattr(self):
7955        def fn(x):
7956            if hasattr(torch.nn, "Module"):
7957                return x * x
7958            return x + 1
7959
7960        compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn)
7961
7962        x = torch.rand([4, 4])
7963        fn_out = fn(x)
7964        compiled_out = compiled_fn(x)
7965        self.assertTrue(same(fn_out, compiled_out))
7966
7967    def test_list_hasattr1(self):
7968        def fn(x):
7969            if hasattr(x, "foo"):
7970                return x[0] + 1
7971            return x[0] - 1
7972
7973        compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn)
7974
7975        x = [torch.randn(3)]
7976        fn_out = fn(x)
7977        compiled_out = compiled_fn(x)
7978        self.assertTrue(same(fn_out, compiled_out))
7979
7980    def test_list_hasattr2(self):
7981        def fn():
7982            x = [torch.zeros(3)]
7983            if hasattr(x, "__len__"):
7984                return x[0] + 1
7985            return x[0] - 1
7986
7987        compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn)
7988
7989        fn_out = fn()
7990        compiled_out = compiled_fn()
7991        self.assertTrue(same(fn_out, compiled_out))
7992
7993    def test_tuple_hasattr(self):
7994        def fn(x):
7995            if hasattr(x, "foo"):
7996                return x[0] + 1
7997            return x[1] - 1
7998
7999        compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn)
8000
8001        x = (torch.randn(3), torch.randn(3))
8002        fn_out = fn(x)
8003        compiled_out = compiled_fn(x)
8004        self.assertTrue(same(fn_out, compiled_out))
8005
8006    def test_fn_hasattr__name__1(self):
8007        def fn():
8008            foo = lambda x: x + 1
8009            return hasattr(foo, "__name__")
8010
8011        compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn)
8012
8013        fn_out = fn()
8014        compiled_out = compiled_fn()
8015        self.assertEqual(fn_out, compiled_out)
8016        self.assertTrue(fn_out)
8017
8018    def test_fn_hasattr__name__2(self):
8019        def bar(x):
8020            return torch.sin(x)
8021
8022        def fn():
8023            return hasattr(bar, "__name__")
8024
8025        compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn)
8026
8027        fn_out = fn()
8028        compiled_out = compiled_fn()
8029        self.assertEqual(fn_out, compiled_out)
8030        self.assertTrue(fn_out)
8031
8032    def test_fn_hasattr__name__3(self):
8033        def bar(x, y):
8034            return torch.sin(x) + torch.cos(y)
8035
8036        baz = functools.partial(bar, y=4)
8037
8038        def fn():
8039            return hasattr(baz, "__name__")
8040
8041        compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn)
8042
8043        fn_out = fn()
8044        compiled_out = compiled_fn()
8045        self.assertEqual(fn_out, compiled_out)
8046        self.assertFalse(fn_out)
8047
8048    def test_torch_objects_as_keys(self):
8049        remap = {torch.float16: torch.float32}
8050
8051        def fn():
8052            return torch.randn(3, dtype=remap[torch.float16])
8053
8054        opt = torch._dynamo.optimize("eager")(fn)
8055        opt()
8056
8057    def test_tracing_py_tree(self):
8058        def fn(xs):
8059            flat_xs, spec = pytree.tree_flatten(xs)
8060            res = [x.clone() for x in flat_xs]
8061            return pytree.tree_unflatten(res, spec)
8062
8063        xs = [torch.tensor(i) for i in range(3)]
8064
8065        counter = CompileCounter()
8066        torch._dynamo.optimize(counter, nopython=True)(fn)(xs)
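        # one compiled frame; op_count of 3 corresponds to one clone per pytree leaf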
8067        self.assertEqual(counter.frame_count, 1)
8068        self.assertEqual(counter.op_count, 3)
8069
8070    def test_tracing_nested_py_tree(self):
8071        import torch.utils._pytree as pytree
8072
8073        def fn(xs):
8074            flat_xs, spec = pytree.tree_flatten(xs)
8075            res = [x.clone() for x in flat_xs]
8076            return pytree.tree_unflatten(res, spec)
8077
8078        xs = [torch.tensor(i) for i in range(3)]
8079        xsl = [xs, xs, xs, xs]
8080
8081        counter = CompileCounter()
8082        comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsl)
8083        real_out = fn(xsl)
8084        self.assertEqual(comp_out, real_out)
8085        self.assertEqual(counter.frame_count, 1)
8086        self.assertEqual(counter.op_count, 12)
8087
8088    def test_tracing_nested_py_tree_tuples(self):
8089        import torch.utils._pytree as pytree
8090
8091        def fn(xs):
8092            flat_xs, spec = pytree.tree_flatten(xs)
8093            res = [x.clone() for x in flat_xs]
8094            return pytree.tree_unflatten(res, spec)
8095
8096        xs = [torch.tensor(i) for i in range(3)]
8097        xsl = (xs, xs, xs, xs)
8098
8099        counter = CompileCounter()
8100        comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsl)
8101        real_out = fn(xsl)
8102        self.assertEqual(comp_out, real_out)
8103        self.assertEqual(counter.frame_count, 1)
8104        self.assertEqual(counter.op_count, 12)
8105
8106    def test_tracing_nested_py_tree_dicts(self):
8107        import torch.utils._pytree as pytree
8108
8109        def fn(xs):
8110            flat_xs, spec = pytree.tree_flatten(xs)
8111            res = [x.clone() for x in flat_xs]
8112            return pytree.tree_unflatten(res, spec)
8113
8114        xs = [torch.tensor(i) for i in range(3)]
8115        xsl = {
8116            "a": xs,
8117            "b": xs,
8118            "c": xs,
8119        }
8120
8121        counter = CompileCounter()
8122        comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsl)
8123        real_out = fn(xsl)
8124        self.assertEqual(comp_out, real_out)
8125        self.assertEqual(counter.frame_count, 1)
8126        self.assertEqual(counter.op_count, 9)
8127
8128    def test_dynamic_one_hot(self):
8129        def fn(x):
8130            x = x + 1
8131            # graph break from data-dependent output shape
8132            x = torch.nn.functional.one_hot(x)
8133            x = x + 1
8134            return x
8135
8136        inp = torch.arange(20) % 4
8137        counter = CompileCounter()
8138        real_out = fn(inp)
8139        comp_out = torch.compile(fn, backend=counter)(inp)
8140        self.assertEqual(comp_out, real_out)
8141        self.assertEqual(counter.frame_count, 2)
8142        self.assertEqual(counter.op_count, 2)
8143
8144    def test_tracing_nested_py_tree_mixed_all(self):
8145        import torch.utils._pytree as pytree
8146
8147        def fn(xs):
8148            flat_xs, spec = pytree.tree_flatten(xs)
8149            res = [x.clone() for x in flat_xs]
8150            return pytree.tree_unflatten(res, spec)
8151
8152        xs = [torch.tensor(i) for i in range(3)]
8153        xsa = (xs, xs)
8154        xsb = {"aa": xsa, "ab": xs}
8155        xsl = {
8156            "a": xs,
8157            "b": xsa,
8158            "c": xsb,
8159        }
8160
8161        counter = CompileCounter()
8162        comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsl)
8163        real_out = fn(xsl)
8164        self.assertEqual(comp_out, real_out)
8165        self.assertEqual(counter.frame_count, 1)
8166        self.assertEqual(counter.op_count, 18)
8167
8168    def test_any_all_symnode(self):
8169        cnt = CompileCounter()
8170
8171        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
8172        def fn(x):
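            # with dynamic=True and fullgraph=True, t and f are SymBools; any()/all()
            # over them must be evaluated symbolically without graph breaks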
8173            t = x.size(0) >= 10
8174            f = x.size(0) >= 100
8175            if any([]) or any([f]) or any([f, f]):
8176                return x - 1
8177            if all([f]) or all([t, f]) or all([f, t]) or all([f, f]):
8178                return x - 2
8179            if not (all([]) and all([t]) and all([t, t])):
8180                return x - 3
8181            if not (any([t]) and any([t, f]) and any([f, t])):
8182                return x - 4
8183            return x + 1
8184
8185        y1 = torch.randn(16)
8186        y2 = torch.randn(18)
8187        self.assertEqual(fn(y1), y1 + 1)
8188        self.assertEqual(fn(y2), y2 + 1)
8189        self.assertEqual(cnt.frame_count, 1)
8190        y3 = torch.randn(5)
8191        self.assertEqual(fn(y3), y3 - 3)
8192        self.assertEqual(cnt.frame_count, 2)
8193
8194    def test_tracing_py_tree_tensor_subclass(self):
8195        import torch.utils._pytree as pytree
8196        from torch.testing._internal.two_tensor import TwoTensor
8197        from torch.utils.checkpoint import checkpoint
8198
8199        def fn(xs):
8200            nested_xs = [[xs]]
8201            flat_xs, spec = pytree.tree_flatten(xs)
8202            return flat_xs[0].clone()
8203
8204        # use checkpoint to trigger a "sourceless" tensor subclass
8205        def checkpoint_fn(xs):
8206            return checkpoint(fn, xs, use_reentrant=True)
8207
8208        xs = TwoTensor(torch.ones(2, 2), torch.ones(2, 2))
8209
8210        counter = CompileCounter()
8211        torch._dynamo.optimize(counter, nopython=True)(checkpoint_fn)(xs)
8212        self.assertEqual(counter.frame_count, 1)
8213        self.assertEqual(counter.op_count, 2)
8214
8215    def test_tracing_tree_map_only(self):
8216        import torch.utils._pytree as pytree
8217
8218        def fn(xs):
8219            def mapper(x):
8220                return x.clone()
8221
8222            y = pytree.tree_map_only(torch.Tensor, mapper, xs)
8223            return y
8224
8225        xs = [torch.tensor(i) for i in range(3)] + ["hi"]
8226        xsa = (xs, xs)
8227        xsb = {"aa": xsa, "ab": xs}
8228
8229        counter = CompileCounter()
8230        comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsb)
8231        real_out = fn(xsb)
8232
8233        self.assertEqual(comp_out, real_out)
8234        self.assertEqual(counter.frame_count, 1)
8235        self.assertEqual(counter.op_count, 9)
8236
8237    @torch._dynamo.config.patch(
8238        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
8239    )
8240    def test_unbacked_symint(self):
8241        @torch.compile(backend="eager")
8242        def f(lengths, values):
8243            sizes = lengths.tolist()
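            # with capture_scalar_outputs, each element of sizes is an unbacked SymInt;
            # the _check calls below supply the range info needed for torch.split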
8244            for s in sizes:
8245                torch._check_is_size(s)
8246                torch._check(s >= 2)
8247                torch._check(s <= 100)
8248            return torch.split(values, sizes)
8249
8250        f(torch.tensor([2, 3, 4]), torch.randn(9))
8251
8252    @torch._dynamo.config.patch(
8253        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
8254    )
8255    def test_unbacked_auto_functionalize_op(self):
8256        @torch.library.custom_op(
8257            "mylib::mk_image", mutates_args=("decoder",), device_types=["cpu"]
8258        )
8259        def mk_image(decoder: Tensor) -> Tensor:
8260            return torch.randn(2, 3, 4, 5)
8261
8262        @torch.library.register_fake("mylib::mk_image")
8263        def _(decoder: Tensor) -> Tensor:
8264            image_size = [torch.library.get_ctx().new_dynamic_size() for _ in range(4)]
8265            return torch.empty(image_size)
8266
8267        @torch.compile(fullgraph=True)
8268        def f(x):
8269            return torch.ops.mylib.mk_image.default(x)
8270
8271        x = torch.zeros(100, dtype=torch.int64)
8272        f(x)
8273
8274    @torch._dynamo.config.patch(capture_scalar_outputs=True)
8275    def test_runtime_assert_replacement(self):
8276        @torch.compile(backend="aot_eager")
8277        def fn(x, y):
8278            z = y.item()
8279            torch._check(z == 3)
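            # z is an unbacked SymInt; the check becomes a runtime assert, so a call
            # with a value other than 3 must raise a RuntimeError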
8280            return x + z
8281
8282        fn(torch.randn(4), torch.tensor([3]))
8283        self.assertRaises(RuntimeError, lambda: fn(torch.randn(4), torch.tensor([4])))
8284
8285    @torch._dynamo.config.patch(capture_scalar_outputs=True)
8286    def test_cat_unbacked(self):
8287        @torch.compile(backend="eager")
8288        def fn(x, y):
8289            z = y.item()
8290            return torch.cat([x, torch.ones(z)])
8291
8292        fn(torch.randn(2, 3), torch.tensor([0]))
8293        self.assertRaises(
8294            RuntimeError, lambda: fn(torch.randn(2, 3), torch.tensor([1]))
8295        )
8296
8297    @torch._dynamo.config.patch(
8298        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
8299    )
8300    def test_aot_autograd_propagate_unbacked_symints_shape(self):
8301        @torch.compile(backend="aot_eager")
8302        def f(x):
8303            return torch.nonzero(x)
8304
8305        f(torch.tensor([1, 0, 3, 2, 0]))
8306
8307    def test_simple_set_usage(self):
8308        def foo(x, y):
8309            setty = {x, y}
8310            return setty.pop() * setty.pop()
8311
8312        counter = CompileCounter()
8313        foo = torch._dynamo.optimize(counter, nopython=True)(foo)
8314        x = torch.randn(10, 10)
8315        y = torch.randn(10, 10)
8316        foo(x, y)
8317        self.assertEqual(counter.frame_count, 1)
8318
8319    def test_add_to_set(self):
8320        def foo(x, y):
8321            setty = set()
8322            setty.add(x[0])
8323            setty.add(x[1])
8324            setty.add(x[2])
8325            setty.add(y)
8326            return y * len(setty)
8327
8328        x = torch.randn(10, 10)
8329        y = torch.randn(2, 2)
8330        eager_result = foo([x, x, x, x, y], y)
8331
8332        counter = CompileCounter()
8333        foo = torch._dynamo.optimize(counter, nopython=True)(foo)
8334        result = foo([x, x, x, x, y], y)
8335        self.assertEqual(counter.frame_count, 1)
8336        self.assertEqual(result, eager_result)
8337
8338    def test_iter_set(self):
8339        def foo(x, y):
8340            setty = set()
8341            for t in x:
8342                setty.add(t)
8343            return y * len(setty)
8344
8345        x = torch.randn(10, 10)
8346        y = torch.randn(2, 2)
8347        eager_result = foo([x, x, x, x, y], y)
8348
8349        counter = CompileCounter()
8350        foo = torch._dynamo.optimize(counter, nopython=True)(foo)
8351        result = foo([x, x, x, x, y], y)
8352        self.assertEqual(counter.frame_count, 1)
8353        self.assertEqual(result, eager_result)
8354
8355    def test_input_set_graph_break(self):
8356        def foo(x):
8357            return x.pop() * x.pop()
8358
8359        x = torch.randn(10, 10)
8360        y = torch.randn(10, 10)
8361
8362        counter = CompileCounter()
8363
8364        inp = {x, x, x, x, y, y}
8365        foo = torch._dynamo.optimize(counter, nopython=True)(foo)
8366
8367        # A lot of set functionality cannot work without a good deal of exertion on our part.
8368        # Specifically, taking a set as input won't ever work with how GetItemSource works (we
8369        # can't arbitrarily access set contents), so the guard story for objects passed in as inputs just isn't there at the moment.
8370        with self.assertRaisesRegex(
8371            torch._dynamo.exc.Unsupported,
8372            "^call_method UserDefinedObjectVariable\\(set\\).*",
8373        ):
8374            foo(inp)
8375
8376        foo = torch._dynamo.optimize(counter, nopython=False)(foo)
8377        foo(inp)
8378        self.assertEqual(counter.frame_count, 1)
8379
8380    def test_reconstruct_set_across_graph_break(self):
8381        def foo(x, y):
8382            setty = set()
8383            for t in x:
8384                setty.add(t)
8385            print("Break!")
8386            return y * len(setty)
8387
8388        x = torch.randn(10, 10)
8389        y = torch.randn(2, 2)
8390
8391        counter = CompileCounter()
8392        foo = torch._dynamo.optimize(counter)(foo)
8393        result = foo([x, x, x, x, y], y)
8394
8395    def test_set_aliasing_recompiles(self):
8396        g1 = torch.randn(10)
8397        g2 = torch.randn(10)
8398        g3 = torch.randn(10)
8399        g4 = torch.randn(10)
8400
8401        def foo(a, b, c):
8402            myset = {g1, a, b, c}
8403            return a + len(myset)
8404
8405        counter = CompileCounter()
8406        foo = torch._dynamo.optimize(counter)(foo)
8407        # first call with no aliasing
8408        foo(g2, g3, g4)
8409        self.assertEqual(counter.frame_count, 1)
8410
8411        # no aliasing again
8412        foo(g3, g2, g4)
8413        # assert no recompile
8414        self.assertEqual(counter.frame_count, 1)
8415
8416        # aliasing changes, we should recompile
8417        foo(g2, g2, g2)
8418        self.assertEqual(counter.frame_count, 2)
8419
8420        # same aliasing, different tensor
8421        foo(g3, g3, g3)
8422        self.assertEqual(counter.frame_count, 2)
8423
8424        # aliasing between global and arg, should recompile again
8425        foo(g1, g1, g1)
8426        self.assertEqual(counter.frame_count, 3)
8427
8428        # Reset
8429        torch._dynamo.reset()
8430
8431        # aliasing between global and arg, first call
8432        foo(g1, g1, g1)
8433        self.assertEqual(counter.frame_count, 4)
8434
8435        # same aliasing, different tensor, all local, recompile
8436        foo(g3, g3, g3)
8437        self.assertEqual(counter.frame_count, 5)
8438
8439        # aliasing same tensor, we shouldn't recompile
8440        foo(g2, g2, g2)
8441        self.assertEqual(counter.frame_count, 5)
8442
8443        # No aliasing
8444        foo(g2, g3, g4)
8445        self.assertEqual(counter.frame_count, 6)
8446
8447        # No aliasing again
8448        foo(g3, g2, g4)
8449        # assert no recompile
8450        self.assertEqual(counter.frame_count, 6)
8451
8452    def test_str_format_return1(self):
8453        @torch.compile(backend="eager", fullgraph=True)
8454        def fn(img):
8455            x = torch.sin(img)
8456            y = f"shape {img.shape[-2:]} batch size {img.shape[0]}"
8457            return img + x, y
8458
8459        img1 = torch.randn(1, 1, 8, 8)
8460        res, msg = fn(img1)
8461        self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1")
8462        self.assertEqual(res, img1 + torch.sin(img1))
8463
8464    def test_str_format_return2(self):
8465        @torch.compile(backend="eager", fullgraph=True)
8466        def fn(img):
8467            x = torch.sin(img)
8468            y = "shape {} batch size {y:.2f}".format(img.shape[-2:], y=img.shape[0])
8469            return img + x, y
8470
8471        img1 = torch.randn(1, 1, 8, 8)
8472        res, msg = fn(img1)
8473        self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1.00")
8474        self.assertEqual(res, img1 + torch.sin(img1))
8475
8476    @torch._dynamo.config.patch(capture_scalar_outputs=True)
8477    def test_validate_outputs_unbacked(self):
8478        class SillyCat(torch.autograd.Function):
8479            @staticmethod
8480            def forward(ctx, x0, x1, i):
8481                ctx.save_for_backward(i)
8482                return torch.cat([x0, x1])
8483
8484            @staticmethod
8485            def backward(ctx, grad_out):
8486                (i,) = ctx.saved_tensors
8487                i0, i1 = i.tolist()
8488                g_x0, g_x1 = grad_out.split([i0, i1])
8489                return g_x0, g_x1, None
8490
8491        @torch.compile(backend="aot_eager", fullgraph=True)
8492        def f(x, i):
8493            i0, i1 = i.tolist()
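            # i0/i1 are unbacked SymInts; autograd output validation in the backward
            # must cope with gradients whose sizes are unbacked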
8494            x0, x1 = x.split([i0, i1])
8495            return SillyCat.apply(x0, x1, i)
8496
8497        f(torch.randn(9, requires_grad=True), torch.tensor([3, 6]))
8498
8499    def test_str_format_assert1(self):
8500        @torch.compile(backend="eager", fullgraph=True)
8501        def fn(img):
8502            x = torch.sin(img)
8503            val = x.shape[-2:]
8504            torch._assert(len(val) == 2, f"shape {img.shape}")
8505            return img + x
8506
8507        img1 = torch.randn(1, 1, 8, 8)
8508        res = fn(img1)
8509        self.assertEqual(res, img1 + torch.sin(img1))
8510
8511    def test_str_format_assert2(self):
8512        cnt = CompileCounter()
8513
8514        @torch.compile(backend=cnt)
8515        def fn(img):
8516            x = torch.sin(img)
8517            torch._assert(
8518                img.shape[-2] == 8 and img.shape[-1] == 16, f"shape {img.shape}"
8519            )
8520            return img + x
8521
8522        img1 = torch.randn(1, 3, 8, 16)
8523        res = fn(img1)
8524        self.assertEqual(res, img1 + torch.sin(img1))
8525        self.assertEqual(cnt.frame_count, 1)
8526
8527        # trigger a recompile and graph break
8528        img2 = torch.randn(1, 3, 8, 15)
8529        self.assertRaises(AssertionError, lambda: fn(img2))
8530
8531    def test_tolist_scalar(self):
8532        def fn(x):
8533            new_list = []
8534            for i in x.tolist():
8535                new_list.append(i * 4)
8536            return new_list
8537
8538        x = torch.tensor([3])
8539        eager = fn(x)
8540        counter = CompileCounter()
8541        compiled = torch._dynamo.optimize(counter, nopython=True)(fn)(x)
8542        self.assertEqual(eager, compiled)
8543        self.assertEqual(counter.frame_count, 1)
8544
8545    def test_tolist_1d(self):
8546        def fn(x):
8547            new_list = []
8548            for i in x.tolist():
8549                new_list.append(i * 4)
8550            return new_list
8551
8552        x = torch.tensor([2, 1])
8553        eager = fn(x)
8554        counter = CompileCounter()
8555        compiled = torch._dynamo.optimize(counter, nopython=True)(fn)(x)
8556        self.assertEqual(eager, compiled)
8557        self.assertEqual(counter.frame_count, 1)
8558
8559    def test_tolist_kd(self):
8560        def fn(x):
8561            new_list = []
8562            for i in x.tolist():
8563                new_list.append(i * 4)
8564            return new_list
8565
8566        x = torch.tensor([[[2, 1], [2, 1], [2, 1]], [[2, 1], [2, 1], [2, 1]]])
8567        eager = fn(x)
8568        counter = CompileCounter()
8569        compiled = torch._dynamo.optimize(counter, nopython=True)(fn)(x)
8570        self.assertEqual(eager, compiled)
8571        self.assertEqual(counter.frame_count, 1)
8572
8573    @patch.object(torch._dynamo.config, "specialize_int", True)
8574    def test_tolist_0d(self):
8575        def fn(x):
8576            new_list = []
8577            i = x.tolist()
8578            new_list.append(i * 4)
8579            return new_list
8580
8581        x = torch.tensor(42)
8582        eager = fn(x)
8583        counter = CompileCounter()
8584        compiled = torch._dynamo.optimize(counter, nopython=True)(fn)(x)
8585        self.assertEqual(eager, compiled)
8586        self.assertEqual(counter.frame_count, 1)
8587
8588    @patch.object(torch._dynamo.config, "assume_static_by_default", False)
8589    @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
8590    def test_tolist_kd_dynamic(self):
8591        def fn(x):
8592            new_list = []
8593            i = x.tolist()
8594            new_list.append(i * 4)
8595            return new_list
8596
8597        x = torch.randint(3, 5, [5, 5])
8598        eager = fn(x)
8599        counter = CompileCounter()
8600        compiled_fn = torch._dynamo.optimize(counter, nopython=True)(fn)
8601        compiled = compiled_fn(x)
8602        self.assertEqual(eager, compiled)
8603        self.assertEqual(counter.frame_count, 1)
8604
8605        # Value change, no recompiles
8606        x = torch.randint(7, 9, [5, 5])
8607        compiled_fn(x)
8608        self.assertEqual(counter.frame_count, 1)
8609
8610        # Size change, forced recompiles
8611        x = torch.randint(3, 5, [3, 3])
8612        compiled_fn(x)
8613        self.assertEqual(counter.frame_count, 2)
8614
8615    def test_tolist_float(self):
8616        def fn(x):
8617            new_list = []
8618            for i in x.tolist():
8619                new_list.append(i * 4)
8620            return new_list
8621
8622        x = torch.tensor(
8623            [[[2.0, 1.0], [2.0, 1.0], [2.0, 1.0]], [[2.0, 1.0], [2.0, 1.0], [2.0, 1.0]]]
8624        )
8625        eager = fn(x)
8626        counter = CompileCounter()
8627        compiled = torch._dynamo.optimize(counter)(fn)(x)
8628        self.assertEqual(eager, compiled)
8629        # Nothing to compile here
8630        self.assertEqual(counter.frame_count, 0)
8631
8632    def test_inline_closure_not_loaded_by_parent(self):
8633        def outer(a):
8634            return a + 1
8635
8636        def indirect(x):
8637            return direct(x)
8638
8639        def direct(x):
8640            def deep2(c):
8641                return outer(c)
8642
8643            def deep(c):
8644                return deep2(c)
8645
8646            return deep(x)
8647
8648        x = torch.randn(3)
8649        eager = indirect(x)
8650        counter = CompileCounter()
8651        compiled = torch._dynamo.optimize(counter)(indirect)(x)
8652        self.assertEqual(eager, compiled)
8653        self.assertEqual(counter.frame_count, 1)
8654
8655    def test_deque_input(self):
8656        a = torch.randn([2, 3])
8657        b = torch.randn([2, 3])
8658        d1 = collections.deque([a, b])
8659        d1.insert(0, "foo")
8660
8661        d2 = collections.deque([a, b])
8662        d2.insert(0, "foo")
8663
8664        def fn(q):
8665            a = q.pop()
8666            b = q.pop()
8667            return a * b
8668
8669        eager = fn(d1)
8670        counter = CompileCounter()
8671        compiled = torch._dynamo.optimize(counter)(fn)(d2)
8672        self.assertEqual(eager, compiled)
8673        self.assertEqual(counter.frame_count, 1)
8674
8675    def test_deque_append_left(self):
8676        d1 = collections.deque([10, 10])
8677        d1.insert(0, "foo")
8678
8679        d2 = collections.deque([10, 10])
8680        d2.insert(0, "foo")
8681
8682        def fn(q, a, b):
8683            q.appendleft(a)
8684            q.appendleft(b)
8685            return q.popleft() * q.popleft()
8686
8687        a = torch.randn([3, 3])
8688        b = torch.randn([3, 3])
8689        eager = fn(d1, a, b)
8690        counter = CompileCounter()
8691        compiled = torch._dynamo.optimize(counter)(fn)(d2, a, b)
8692        self.assertEqual(eager, compiled)
8693        self.assertEqual(counter.frame_count, 1)
8694        self.assertTrue(isinstance(compiled, torch.Tensor))
8695
8696    def test_yield_from(self):
8697        def yield_from_fn(t_list, k):
8698            def yield_from_gen(l):
8699                l2 = [t * k for t in l]
8700                yield from l2
8701
8702            return [t * k for t in yield_from_gen(t_list)]
8703
8704        t_list = [torch.randn([2, 3]) for _ in range(3)]
8705        eager = yield_from_fn(t_list, 2)
8706        counter = CompileCounter()
8707        compiled = torch._dynamo.optimize(counter)(yield_from_fn)(t_list, 2)
8708        self.assertEqual(eager, compiled)
8709        self.assertEqual(counter.frame_count, 1)
8710
8711    def test_yield_from_in_a_loop(self):
8712        def gen2():
8713            yield 1
8714
8715        def gen1():
8716            for value in range(5):
8717                yield from gen2()
8718
8719        def fn(x):
8720            c = 0
8721            for i in gen1():
8722                c = c + i
8723            return x + c
8724
8725        opt_fn = torch.compile(fn, backend="eager")
8726        x = torch.zeros(4)
8727        self.assertEqual(fn(x), opt_fn(x))
8728
8729    def test_yield_gen_and_from(self):
8730        def populate_and_multiply_sequence(n, multiplier):
8731            # Inline generator
8732            def tensor_generator():
8733                for i in range(n):
8734                    yield torch.tensor([i])
8735
8736            # Use 'yield from' to iterate over tensors and multiply
8737            # Multiply each generated tensor, then use 'yield from' to iterate over the results
8738
8739            def yield_from_gen():
8740                yield from t_list
8741
8742            return [t for t in yield_from_gen()]
8743
8744        multiplier = torch.tensor([10])
8745        eager = populate_and_multiply_sequence(5, multiplier)
8746        counter = CompileCounter()
8747        compiled = torch._dynamo.optimize(counter)(populate_and_multiply_sequence)(
8748            5, multiplier
8749        )
8750        self.assertEqual(eager, compiled)
8751        self.assertEqual(counter.frame_count, 1)
8752
8753    def test_yield_from_user_stop_iteration(self):
8754        class MyIter:
8755            def __init__(self, seq):
8756                self.seq = seq
8757                self.index = 0
8758
8759            def __iter__(self):
8760                return self
8761
8762            def __next__(self):
8763                self.index += 1
8764                if self.index <= len(self.seq):
8765                    return self.seq[self.index - 1]
8766                raise StopIteration(self.index)
8767
8768        def yield_from_iter_fn(seq):
8769            def gen(seq):
8770                yield from MyIter(seq)
8771
8772            return [i for i in gen(seq)]
8773
8774        seq = [torch.randn([2, 3]) for _ in range(3)]
8775        eager = yield_from_iter_fn(seq)
8776        counter = CompileCounter()
8777        compiled = torch._dynamo.optimize(counter)(yield_from_iter_fn)(seq)
8778        self.assertEqual(eager, compiled)
8779        self.assertEqual(counter.frame_count, 0)
8780
8781    def test_yield_send_to_subgenerator_graph_break(self):
8782        def subgenerator(tensor):
8783            multiplier = yield
8784            yield tensor * multiplier
8785
8786        def main_generator(t_list):
8787            for tensor in t_list:
8788                subgen = subgenerator(tensor)
8789                next(subgen)
8790                yield from subgen.send(torch.tensor([10]))
8791
8792        t_list = [torch.tensor([i]) for i in range(5)]
8793        eager = list(main_generator(t_list))
8794
8795        counter = CompileCounter()
8796        compiled_fn = torch._dynamo.optimize(counter)(main_generator)
8797        compiled = list(compiled_fn(t_list))
8798
8799        self.assertEqual(eager, compiled)
8800        self.assertEqual(counter.frame_count, 0)
8801
8802    def test_derpy_nn_module_usage(self):
8803        def ff1(x):
8804            self = mod1
8805            return torch.sigmoid(self.mod2(x) + self.param1)
8806
8807        def ff2(x):
8808            self = mod2
8809            return torch.cos(torch.sin(x) * self.param2 + 10)
8810
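        # build the modules imperatively and monkey-patch forward onto plain
        # nn.Module instances; Dynamo must still trace through them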
8811        mod1 = torch.nn.Module()
8812        mod2 = torch.nn.Module()
8813        mod1.register_module("mod2", mod2)
8814        mod1.register_parameter("param1", torch.nn.Parameter(torch.randn(10)))
8815        mod1.forward = ff1
8816        mod2.register_parameter("param2", torch.nn.Parameter(torch.randn(10)))
8817        mod2.forward = ff2
8818        mod1.eval()
8819
8820        x = torch.randn(10)
8821        expected = mod1(x)
8822        counter = CompileCounter()
8823        actual = torch.compile(mod1, backend=counter, fullgraph=True)(x)
8824        self.assertEqual(actual, expected)
8825        self.assertEqual(counter.op_count, 6)
8826
8827    def test_default_args_device_dtype(self):
8828        class Foo:
8829            def __init__(
8830                self,
8831                dtype: torch.dtype = torch.float16,
8832                device: torch.device = torch.device("cpu"),
8833            ) -> None:
8834                self.value = torch.tensor(10, dtype=dtype, device=device)
8835
8836        def fn():
8837            return Foo().value + 1
8838
8839        opt_func = torch._dynamo.optimize("eager", nopython=True)(fn)
8840        ref = fn()
8841        res = opt_func()
8842        self.assertEqual(ref, res)
8843
8844    def test_torch_device_python_type(self):
8845        for device, device_type, index in [
8846            ("cpu", "cpu", None),
8847            ("cuda:0", "cuda", 0),
8848        ]:
8849            if device == "cuda:0" and not TEST_CUDA:
8850                continue
8851
8852            def fn(target):
8853                target_device = target.device
8854                a = torch.zeros(2, 3, device=target_device)
8855                # Constant assert at trace time
8856                assert isinstance(target_device, torch.device)
8857                assert target_device.type == device_type
8858                assert target_device.index == index
8859                b = torch.zeros(2, 3, device=target_device)
8860                c = torch.zeros(2, 3, device=target_device)
8861                return a + b + c
8862
8863            from torch._dynamo.variables import ConstantVariable
8864
8865            device = torch.device(device)
8866            expected_variable = ConstantVariable(device)
8867            self.assertEqual(expected_variable.python_type(), type(device))
8868
8869            opt_func = torch._dynamo.optimize("eager", nopython=True)(fn)
8870            a = torch.tensor([2, 3], device=device)
8871            res = opt_func(a)
8872            self.assertIsInstance(res, torch.Tensor)
8873
8874    def test_torch_dtype_python_type(self):
8875        def fn(target):
8876            target_dtype = target.dtype
8877            a = torch.zeros(2, 3, dtype=target_dtype)
8878            # Constant assert at trace time
8879            assert isinstance(target_dtype, torch.dtype)
8880            b = torch.zeros(2, 3, dtype=target_dtype)
8881            c = torch.zeros(2, 3, dtype=target_dtype)
8882            return a + b + c
8883
8884        from torch._dynamo.variables import ConstantVariable
8885
8886        dtype = torch.float16
8887        expected_variable = ConstantVariable(dtype)
8888        self.assertEqual(expected_variable.python_type(), type(dtype))
8889
8890        opt_func = torch._dynamo.optimize("eager", nopython=True)(fn)
8891        a = torch.tensor([2, 3], dtype=dtype)
8892        res = opt_func(a)
8893        self.assertIsInstance(res, torch.Tensor)
8894
8895    def test_itertools_repeat(self):
8896        counters.clear()
8897
8898        def fn(x):
8899            r = itertools.repeat(100.0, 5)
8900            for i in r:
8901                x += i
8902            return x
8903
8904        x = torch.randn([2, 5])
8905        eager = fn(x)
8906
8907        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
8908        compiled = compiled_fn(x)
8909
8910        self.assertEqual(list(eager), list(compiled))
8911        self.assertEqual(len(counters["graph_break"]), 0)
8912
8913    def test_itertools_infinite_repeat(self):
8914        counters.clear()
8915
8916        def fn(x):
8917            r = itertools.repeat(100.0)
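            # infinite iterator; the explicit break below bounds the loop so it can
            # be fully unrolled at trace time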
8918            idx = 0
8919            for i in r:
8920                x += i
8921                idx += 1
8922                if idx > 10:
8923                    break
8924            return x
8925
8926        x = torch.randn([2, 5])
8927        eager = fn(x)
8928
8929        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
8930        compiled = compiled_fn(x)
8931
8932        self.assertEqual(list(eager), list(compiled))
8933        self.assertEqual(len(counters["graph_break"]), 0)
8934
8935    def test_itertools_infinite_repeat_mutation(self):
8936        counters.clear()
8937
8938        def fn(x):
8939            r = itertools.repeat(x)
8940            idx = 0
8941            for i in r:
8942                x += i
8943                i += 1
8944                idx += 1
8945                if idx > 10:
8946                    break
8947            return x
8948
8949        x = torch.randn([2, 5])
8950        eager = fn(x)
8951
8952        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
8953        compiled = compiled_fn(x)
8954
8955        self.assertEqual(list(eager), list(compiled))
8956        self.assertEqual(len(counters["graph_break"]), 0)
8957
8958    def test_itertools_infinite_count(self):
8959        for args in ([], [10], [5, -1]):
8960            counters.clear()
8961
8962            def fn(x):
8963                r = itertools.count(*args)
8964                idx = 0
8965                for i in r:
8966                    x += i
8967                    idx += 1
8968                    if idx > 10:
8969                        break
8970                return x
8971
8972            x = torch.randn([2, 5])
8973            eager = fn(x)
8974
8975            compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
8976            compiled = compiled_fn(x)
8977
8978            self.assertEqual(list(eager), list(compiled))
8979            self.assertEqual(len(counters["graph_break"]), 0)
8980
8981    def test_itertools_infinite_cycle(self):
8982        counters.clear()
8983
8984        def fn(x):
8985            for iterator in (
8986                iter([]),
8987                iter([10, 11.0]),
8988                itertools.repeat(-1, 3),
8989                itertools.count(10),
8990            ):
8991                r = itertools.cycle(iterator)
8992                idx = 0
8993                x += 1
8994                for i in r:
8995                    x += i
8996                    idx += 1
8997                    if idx > 10:
8998                        break
8999            return x
9000
9001        x = torch.randn([2, 5])
9002        eager = fn(x)
9003
9004        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
9005        compiled = compiled_fn(x)
9006
9007        self.assertEqual(list(eager), list(compiled))
9008        self.assertEqual(len(counters["graph_break"]), 0)
9009
9010    def test_itertools_accumulate_symint_default_sum(self):
9011        # https://github.com/pytorch/pytorch/issues/110287
9012        counters.clear()
9013
9014        def fn(x):
9015            r = itertools.accumulate([x.size(0), x.size(1)])
9016            for i in r:
9017                x *= i
9018            return x
9019
9020        x = torch.randn(2, 3)
9021        eager = fn(x)
9022
9023        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
9024        compiled = compiled_fn(x)
9025
9026        self.assertEqual(list(eager), list(compiled))
9027        self.assertEqual(len(counters["graph_break"]), 0)
9028
9029    def test_itertools_accumulate_tensors_default_sum(self):
9030        counters.clear()
9031
9032        def fn(a, b, c, d, x):
9033            l = [a, b, c, d, x]
9034            for i, t in enumerate(l):
9035                l[i] = t * x
9036            return itertools.accumulate(l)
9037
9038        t_list = [torch.tensor([i + 1]) for i in range(4)]
9039        x = torch.tensor([[1, 2], [3, 4]])
9040        eager = fn(*t_list, x)
9041
9042        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
9043        compiled = compiled_fn(*t_list, x)
9044
9045        self.assertEqual(list(eager), list(compiled))
9046        self.assertEqual(len(counters["graph_break"]), 0)
9047
9048    def test_itertools_accumulate_tensors_builtins(self):
9049        for builtin_op in [operator.mul, operator.sub, operator.pow]:
9050            counters.clear()
9051
9052            def fn(a, b, c, d, x):
9053                l = [a, b, c, d, x]
9054                for i, t in enumerate(l):
9055                    l[i] = t * x
9056                return itertools.accumulate(l, builtin_op)
9057
9058            t_list = [torch.tensor([i + 1]) for i in range(4)]
9059            x = torch.tensor([[1, 2], [3, 4]])
9060            eager = fn(*t_list, x)
9061
9062            compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
9063            compiled = compiled_fn(*t_list, x)
9064
9065            self.assertEqual(list(eager), list(compiled))
9066            self.assertEqual(len(counters["graph_break"]), 0)
9067
9068    def test_itertools_accumulate_tensors_kwargs(self):
9069        from torch._dynamo.utils import counters
9070
9071        for kwargs in [
9072            {"func": operator.mul},
9073            {"initial": 100},
9074            {"func": operator.sub, "initial": -1},
9075        ]:
9076            counters.clear()
9077
9078            def fn(a, b, c, d, x):
9079                l = [a, b, c, d, x]
9080                for i, t in enumerate(l):
9081                    l[i] = t * x
9082                return itertools.accumulate(l, **kwargs)
9083
9084            t_list = [torch.tensor([i + 1]) for i in range(4)]
9085            x = torch.tensor([[1, 2], [3, 4]])
9086
9087            compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
9088            compiled = compiled_fn(*t_list, x)
9089            eager = fn(*t_list, x)
9090
9091            self.assertEqual(list(eager), list(compiled))
9092            self.assertEqual(len(counters["graph_break"]), 0)
9093
9094    def test_packaging_version_parse(self):
9095        from packaging import version
9096
9097        @torch.compile(backend="eager", fullgraph=True)
9098        def fn():
9099            x = torch.zeros(1)
9100            if version.parse(torch.__version__) >= version.parse("2.0.0"):
9101                return x + 1
9102            return x
9103
9104        self.assertEqual(fn().item(), 1)
9105
9106    def test_itertools_accumulate_tensors_user_defined(self):
9107        def udo_fn_0(a, b):
9108            return -1
9109
9110        rando = random.randint(0, 1)
9111
9112        def udo_fn_1(a, b):
9113            return a * rando + b * rando
9114
9115        seen = []
9116
9117        def udo_fn_2(a, b):
9118            seen.append(a)
9119            seen.append(b)
9120            return a * len(seen)
9121
9122        for udo_fn in [udo_fn_0, udo_fn_1, udo_fn_2]:
9123            counters.clear()
9124            torch._dynamo.reset()
9125
9126            def fn(a, b, c, d, x):
9127                l = [a, b, c, d, x]
9128                for i, t in enumerate(l):
9129                    l[i] = t * x
9130                return itertools.accumulate(l, udo_fn)
9131
9132            t_list = [torch.tensor([i]) for i in range(4)]
9133            x = torch.tensor([[1, 2], [3, 4]])
9134            eager = fn(*t_list, x)
9135
9136            compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
9137            compiled = compiled_fn(*t_list, x)
9138
9139            self.assertEqual(list(eager), list(compiled))
9140            self.assertEqual(len(counters["graph_break"]), 0)
9141
9142    def test_pure_python_accumulate(self):
9143        def accumulate(iterable, func=lambda x, y: x + y):
9144            it = iter(iterable)
9145            try:
9146                # Initialize the accumulator with the first value from the iterable
9147                accumulator = next(it)
9148            except StopIteration:
9149                # If the iterable is empty, return an empty generator
9150                return
9151            yield accumulator
9152
9153            for element in it:
9154                accumulator = func(accumulator, element)
9155                yield accumulator
9156
9157        def fn(it):
9158            return accumulate(it)
9159
9160        t_list = [torch.tensor([i]) for i in range(4)]
9161        eager = fn(t_list)
9162
9163        counter = CompileCounter()
9164        compiled_fn = torch._dynamo.optimize(counter)(fn)
9165        compiled = compiled_fn(t_list)
9166
9167        self.assertEqual(list(eager), list(compiled))
9168        self.assertEqual(counter.frame_count, 1)
9169
9170    def test_itertools_groupby_pure_python_default_identity_func(self):
9171        counters.clear()
9172
9173        def fn(l):
9174            return [(k, list(g)) for k, g in itertools.groupby(l)]
9175
9176        l = [1, 2, 2, 3, 4, 4, 4, 1, 2]
9177        eager = fn(l)
9178
9179        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
9180        compiled = compiled_fn(l)
9181
9182        self.assertEqual(eager, compiled)
9183        self.assertEqual(len(counters["graph_break"]), 0)
9184
9185    def test_itertools_groupby_pure_python_key_func(self):
9186        counters.clear()
9187
9188        def fn(l):
9189            return [(k, list(g)) for k, g in itertools.groupby(l, key=operator.neg)]
9190
9191        l = [1, 2, -2, 3, 4, 4, -4, 0, -2]
9192        eager = fn(l)
9193
9194        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
9195        compiled = compiled_fn(l)
9196
9197        self.assertEqual(eager, compiled)
9198        self.assertEqual(len(counters["graph_break"]), 0)
9199
9200    def test_list_iterator_contains(self):
9201        def fn(x):
9202            it = iter(["my_weight", "not_my_weight"])
9203            next(it)
9204            if "my_weight" in it:
9205                return x + 2
9206            return x + 1
9207
9208        x = torch.zeros(3)
9209        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
9210
9211        self.assertEqual(fn(x), compiled_fn(x))
9212
9213    def test_storage_return(self):
9214        @torch.compile(backend="eager", fullgraph=True)
9215        def fn(x):
9216            y = torch.sin(x + 1)
9217            storage = x.untyped_storage()
9218            storage.resize_(0)
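            # resizing the storage to 0 is a side effect the compiled function must
            # replay, and the very same storage object must be returned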
9219            y = torch.cos(y)
9220            return y, storage
9221
9222        x = torch.randn(10)
9223        expected = torch.cos(torch.sin(x + 1))
9224        y, s = fn(x)
9225        self.assertEqual(y, expected)
9226        self.assertEqual(x.untyped_storage().size(), 0)
9227        self.assertIs(s, x.untyped_storage())
9228
9229    def test_flat_name_to_original_fqn(self):
9230        class FooBarModule(torch.nn.Module):
9231            def __init__(self):
9232                super().__init__()
9233                self.register_parameter("0", torch.nn.Parameter(torch.randn(3, 4)))
9234                self.register_buffer("test_buf", torch.randn(3, 4))
9235                self.register_parameter(
9236                    "test_param", torch.nn.Parameter(torch.randn(3, 4))
9237                )
9238
9239            def forward(self, x):
9240                return ((x + self.test_buf) * getattr(self, "0")) / self.test_param
9241
9242        class TestModule(torch.nn.Module):
9243            def __init__(self):
9244                super().__init__()
9245                self.foo_bar = FooBarModule()
9246                self.register_parameter(
9247                    "test_param", torch.nn.Parameter(torch.randn(3, 4))
9248                )
9249                self.register_buffer("test_buf", torch.randn(3, 4))
9250
9251            def forward(self, x):
9252                return (self.foo_bar(x) + self.test_param) * self.test_buf
9253
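        # export should record a mapping from the flattened param/buffer names used
        # in the graph back to their original fully-qualified names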
9254        gm, _ = torch._dynamo.export(TestModule(), torch.randn(3, 4))
9255        self.assertIn("dynamo_flat_name_to_original_fqn", gm.meta)
9256        expected_fqn = {
9257            "L__self___test_param": "test_param",
9258            "L__self___test_buf": "test_buf",
9259            "getattr_L__self___foo_bar___0__": "foo_bar.0",
9260            "L__self___foo_bar_test_param": "foo_bar.test_param",
9261            "L__self___foo_bar_test_buf": "foo_bar.test_buf",
9262        }
9263        self.assertEqual(expected_fqn, gm.meta["dynamo_flat_name_to_original_fqn"])
9264
9265    def test_shape_env_no_recording(self):
9266        main = ShapeEnv(should_record_events=False)
9267
9268        # The main ShapeEnv should have no events recorded.
9269        self.assertEqual(len(main.events), 0)
9270
9271        # Call create_symbolic_sizes_strides_storage_offset on the main ShapeEnv.
9272        r = main.create_symbolic_sizes_strides_storage_offset(
9273            torch.randn(3, 2), ConstantSource("x")
9274        )
9275
9276        # Create a guard: size[0] == 3 (call evaluate_expr)
9277        #   - +1 guard entry
9278        #   - +1 replacement entry
9279        size = r[0]
9280        bool(size[0] == 3)
9281
9282        # The main ShapeEnv should still have no events recorded.
9283        self.assertEqual(len(main.events), 0)
9284
9285        if torch.fx.experimental.validator.translation_validation_enabled():
9286            from torch.fx.experimental.symbolic_shapes import (
9287                CURRENT_NODE_KEY,
9288                SHAPEENV_EVENT_KEY,
9289            )
9290
9291            # Check that we don't store any recording metadata on nodes
9292            # from the symbolic shape FX graph.
9293            for n in main.graph.nodes:
9294                self.assertFalse(SHAPEENV_EVENT_KEY in n.meta)
9295                self.assertFalse(CURRENT_NODE_KEY in n.meta)
9296
9297    def _replay_and_check(self, shape_env: ShapeEnv):
9298        if shape_env.should_record_events:
9299            replayed = replay_shape_env_events(shape_env.events)
9300            shape_env.check_equal(replayed)
9301
9302    def test_shape_env_equal_empty(self):
9303        main, other = ShapeEnv(), ShapeEnv()
9304        main.check_equal(other)
9305        self._replay_and_check(main)
9306
9307    @onlyIfTranslationValidation
9308    def test_shape_env_equal_constructor(self):
9309        main, other = ShapeEnv(allow_scalar_outputs=False), ShapeEnv()
9310        self.assertExpectedRaisesInline(
9311            NotEqualError,
9312            lambda: main.check_equal(other),
9313            """\
9314ShapeEnv not equal: field values don't match:
9315
9316==> settings: values don't match.
9317  >  Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, _allow_complex_guards_as_runtime_asserts=False)
9318  > Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, _allow_complex_guards_as_runtime_asserts=False)
9319""",
9320        )
9321        self._replay_and_check(main)
9322
9323    @onlyIfTranslationValidation
9324    def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self):
9325        main, other = ShapeEnv(), ShapeEnv()
9326        main.create_symbolic_sizes_strides_storage_offset(
9327            torch.randn(3, 2), ConstantSource("x")
9328        )
9329        self.assertExpectedRaisesInline(
9330            NotEqualError,
9331            lambda: main.check_equal(other),
9332            """\
9333ShapeEnv not equal: field values don't match:
9334
9335==> name_to_node: values don't match.
9336  >  Left: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
9337  > Right: {}
9338==> source_to_symbol: values don't match.
9339  >  Left: {x.size()[0]: x.size()[0], x.size()[1]: x.size()[1], x.storage_offset(): x.storage_offset(), x.stride()[0]: x.stride()[0], x.stride()[1]: x.stride()[1]}
9340  > Right: {}
9341==> val_to_var: values don't match.
9342  >  Left: {0: 0, 1: 1, 2: s1, 3: s0}
9343  > Right: {0: 0, 1: 1}
9344==> var_to_range: values don't match.
9345  >  Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
9346  > Right: {}
9347==> var_to_sources: values don't match.
9348  >  Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]}
9349  > Right: {}
9350==> var_to_val: values don't match.
9351  >  Left: {s0: 3, s1: 2}
9352  > Right: {}
9353""",
9354        )
9355        self._replay_and_check(main)
9356
9357    @onlyIfTranslationValidation
9358    def test_shape_env_equal_unbacked(self):
9359        main, other = ShapeEnv(), ShapeEnv()
9360        main.create_unbacked_symint()
9361        main.create_unbacked_symfloat()
9362        main.create_unbacked_symbool()
9363        self.assertExpectedRaisesInline(
9364            NotEqualError,
9365            lambda: main.check_equal(other),
9366            """\
9367ShapeEnv not equal: field values don't match:
9368
9369==> name_to_node: values don't match.
9370  >  Left: {u0, u1, zuf0}
9371  > Right: {}
9372==> unbacked_symfloat_counter: values don't match.
9373  >  Left: 1
9374  > Right: 0
9375==> unbacked_symint_counter: values don't match.
9376  >  Left: 2
9377  > Right: 0
9378==> var_to_range: values don't match.
9379  >  Left: {u0: VR[-int_oo, int_oo], u1: VR[0, 1], zuf0: VR[-oo, oo]}
9380  > Right: {}
9381""",
9382        )
9383        self._replay_and_check(main)
9384
9385    @onlyIfTranslationValidation
9386    def test_shape_env_equal_evaluate_expr_divisible(self):
9387        main, other = ShapeEnv(), ShapeEnv()
9388
9389        # Call create_symbolic_sizes_strides_storage_offset on both of them.
9390        r = main.create_symbolic_sizes_strides_storage_offset(
9391            torch.randn(3, 2), ConstantSource("x")
9392        )
9393        other.create_symbolic_sizes_strides_storage_offset(
9394            torch.randn(3, 2), ConstantSource("x")
9395        )
9396
9397        # Create a guard: size[0] % 3 == 0 (only in the main ShapeEnv)
9398        #   - +1 guard entry
9399        #   - +1 divisible entry
9400        size = r[0]
9401        bool(size[0] % 3 == 0)
9402
9403        self.assertExpectedRaisesInline(
9404            NotEqualError,
9405            lambda: main.check_equal(other),
9406            """\
9407ShapeEnv not equal: field values don't match:
9408
9409==> divisible: values don't match.
9410  >  Left: {Mod(s0, 3)}
9411  > Right: {}
9412==> guards: values don't match.
9413  >  Left: [Eq(Mod(s0, 3), 0)]
9414  > Right: []
9415==> name_to_node: values don't match.
9416  >  Left: {_assert, eq, mod, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
9417  > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
9418""",
9419        )
9420        self._replay_and_check(main)
9421
9422    @onlyIfTranslationValidation
9423    def test_shape_env_equal_evaluate_expr_replacement(self):
9424        main, other = ShapeEnv(), ShapeEnv()
9425
9426        # Call create_symbolic_sizes_strides_storage_offset on both of them.
9427        r = main.create_symbolic_sizes_strides_storage_offset(
9428            torch.randn(3, 2), ConstantSource("x")
9429        )
9430        other.create_symbolic_sizes_strides_storage_offset(
9431            torch.randn(3, 2), ConstantSource("x")
9432        )
9433
9434        # Create a guard: size[0] == 3 (only in the main ShapeEnv)
9435        #   - +1 guard entry
9436        #   - +1 replacement entry
9437        size = r[0]
9438        bool(size[0] == 3)
9439
9440        self.assertExpectedRaisesInline(
9441            NotEqualError,
9442            lambda: main.check_equal(other),
9443            """\
9444ShapeEnv not equal: field values don't match:
9445
9446==> guards: values don't match.
9447  >  Left: [Eq(s0, 3)]
9448  > Right: []
9449==> name_to_node: values don't match.
9450  >  Left: {_assert, eq, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
9451  > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
9452==> replacements: values don't match.
9453  >  Left: {s0: 3}
9454  > Right: {}
9455==> var_to_range: values don't match.
9456  >  Left: {s0: VR[3, 3], s1: VR[2, int_oo]}
9457  > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
9458""",
9459        )
9460        self._replay_and_check(main)
9461
9462    @onlyIfTranslationValidation
9463    def test_shape_env_equal_evaluate_expr_refinement(self):
9464        main, other = ShapeEnv(), ShapeEnv()
9465
9466        # Call create_symbolic_sizes_strides_storage_offset on both of them.
9467        r = main.create_symbolic_sizes_strides_storage_offset(
9468            torch.randn(3, 2), ConstantSource("x")
9469        )
9470        other.create_symbolic_sizes_strides_storage_offset(
9471            torch.randn(3, 2), ConstantSource("x")
9472        )
9473
9474        # Create a guard: size[0] >= 3 (only in the main ShapeEnv)
9475        #   - +1 guard entry
9476        #   - +1 var_to_guard entry
9477        #   - Change: var_to_range
9478        size = r[0]
9479        bool(size[0] >= 3)
9480
9481        self.assertExpectedRaisesInline(
9482            NotEqualError,
9483            lambda: main.check_equal(other),
9484            """\
9485ShapeEnv not equal: field values don't match:
9486
9487==> guards: values don't match.
9488  >  Left: [s0 >= 3]
9489  > Right: []
9490==> name_to_node: values don't match.
9491  >  Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
9492  > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
9493==> var_to_range: values don't match.
9494  >  Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]}
9495  > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
9496""",
9497        )
9498        self._replay_and_check(main)
9499
9500    @onlyIfTranslationValidation
9501    def test_shape_env_equal_runtime_assert(self):
9502        main, other = ShapeEnv(), ShapeEnv()
9503
9504        # Call create_unbacked_symint on both of them.
9505        r = main.create_unbacked_symint()
9506        other.create_unbacked_symint()
9507
9508        # Create a runtime assert: r % 3 == 0 (only in the main ShapeEnv)
9509        #   - +1 deferred_runtime_asserts entry
9510        #   - Change: num_deferred_runtime_asserts
9511        expect_true(r % 3 == 0)
9512
9513        self.assertExpectedRaisesInline(
9514            NotEqualError,
9515            lambda: main.check_equal(other),
9516            """\
9517ShapeEnv not equal: field values don't match:
9518
9519==> deferred_runtime_asserts: values don't match.
9520  >  Left: {u0: [Eq(PythonMod(u0, 3), 0)]}
9521  > Right: {}
9522==> name_to_node: values don't match.
9523  >  Left: {_assert, eq, mod, u0}
9524  > Right: {u0}
9525==> num_deferred_runtime_asserts: values don't match.
9526  >  Left: 1
9527  > Right: 0
9528""",
9529        )
9530        self._replay_and_check(main)
9531
9532    def test_shape_env_recorded_function_fallback(self):
9533        # Make sure the record/replay mechanism for ShapeEnv will fall back
9534        # if no ShapeEnv instance is found.
9535        constrain_range(5, min=2, max=10)
9536        constrain_unify(5, 5)
9537
9538        self.assertExpectedRaisesInline(
9539            AssertionError,
9540            lambda: _constrain_range_for_size(5, min=2, max=10),
9541            """can only constrain range for SymInt""",
9542        )
9543
9544    def test_default_dtype_change(self):
9545        @torch.compile
9546        def foo():
9547            def inner(a, b, res_dtype):
9548                print(a, b, res_dtype)
9549                self.assertEqual(torch.result_type(a, b), res_dtype)
9550
9551            inner(torch.tensor(1, device="cpu"), 1.0, torch.get_default_dtype())
9552
9553        with set_default_dtype(torch.float):
9554            foo()
9555        with set_default_dtype(torch.double):
9556            foo()
9557
9558    def test_numpy_ufunc_out(self):
9559        @torch.compile(backend="eager")
9560        def foo():
9561            x = np.arange(5)
9562            out = np.empty((x.shape[0], x.shape[0]))
9563            res_out = np.sin(x, out=out)
9564            assert res_out is out
9565
9566        foo()
9567
9568    # Unfortunately, we don't currently preserve the ids of
9569    # res_out and out correctly across the graph break
9570    @unittest.expectedFailure
9571    def test_numpy_ufunc_out_graph_break(self):
9572        @torch.compile(backend="eager")
9573        def foo():
9574            x = np.arange(5)
9575            out = np.empty((x.shape[0], x.shape[0]))
9576            res_out = np.sin(x, out=out)
9577            torch._dynamo.graph_break()
9578            assert res_out is out
9579
9580        foo()
9581
9582    def test_dict_subclass_cannot_be_initialized_in_graph(self):
9583        for super_class in (
9584            collections.OrderedDict,
9585            dict,
9586        ):
9587
9588            class CustomDict(super_class):
9589                def __init__(self, *args, **kwargs):
9590                    super().__init__(*args, **kwargs)
9591
9592            def fn(x):
9593                c = CustomDict()
9594                c["key"] = x
9595                assert "key" in c
9596                return c["key"] + 1
9597
9598            fn_opt = torch.compile(fn, backend="eager", fullgraph=True)
9599            with self.assertRaisesRegex(
9600                torch._dynamo.exc.Unsupported, "call_function UserDefinedClassVariable"
9601            ):
9602                print(fn_opt(torch.zeros(1)))
9603
9604    @wrapDeterministicFlagAPITest
9605    def test_backward_deterministic_mode_mismatch_warning(self):
9606        @torch.compile
9607        def func(a, b):
9608            return a + b
9609
9610        for forward_deterministic, backward_deterministic in itertools.product(
9611            [True, False], [True, False]
9612        ):
9613            torch.use_deterministic_algorithms(forward_deterministic)
9614            a = torch.randn(10, requires_grad=True)
9615            res = func(a, 1)
9616            grad = torch.ones_like(res)
9617            torch.use_deterministic_algorithms(backward_deterministic)
9618
9619            if not forward_deterministic and backward_deterministic:
9620                with self.assertRaisesRegex(
9621                    RuntimeError,
9622                    r"^This compiled backward function is being run with torch\.use_deterministic_algorithms",
9623                ):
9624                    res.backward(grad)
9625
9626            else:
9627                res.backward(grad)
9628
9629    def test_torch_dynamo_codegen_pow(self):
9630        def pow(x):
9631            return x**2
9632
9633        x = np.arange(8)
9634        pow_opt = torch.compile(pow)
9635
9636        actual, source_code = run_and_get_code(pow_opt, x)
9637        expect = pow(x)
9638
9639        self.assertEqual(expect, actual)
9640
9641        self.assertTrue(
9642            all("aten.pow" not in code for code in source_code),
9643            msg="Encountered an unexpected fallback to 'aten.pow' in dynamo compiled code",
9644        )
9645
9646    def test_graph_break_compilation_metrics(self):
9647        def fn(x):
9648            x.cos()
9649            torch._dynamo.graph_break()
9650            x.sin()
9651            torch._dynamo.graph_break()
9652            return x.cos()
9653
9654        torch._dynamo.utils.clear_compilation_metrics()
9655        x = torch.rand((4, 4))
9656        f = torch.compile(fn, backend="eager")
9657        f(x)
9658        metrics = torch._dynamo.utils.get_compilation_metrics()
9659        # Should only be one restart per event
9660        (restart_reason,) = metrics[0].restart_reasons
9661        self.assertTrue(
9662            "skip function graph_break" in restart_reason,
9663            "Should have logged graph break reason",
9664        )
9665        self.assertTrue(
9666            metrics[0].dynamo_time_before_restart_s
9667            <= metrics[0].entire_frame_compile_time_s
9668        )
9669
9670        (restart_reason,) = metrics[1].restart_reasons
9671        self.assertTrue(
9672            "skip function graph_break" in restart_reason,
9673            "Should have logged graph break reason",
9674        )
9675        self.assertTrue(
9676            metrics[1].dynamo_time_before_restart_s
9677            <= metrics[1].entire_frame_compile_time_s
9678        )
9679
9680        # No restarts
9681        self.assertTrue(
9682            len(metrics[2].restart_reasons) == 0, "Last compile has no graph break"
9683        )
9684        self.assertTrue(metrics[2].dynamo_time_before_restart_s == 0)
9685
9686    def test_graph_break_compilation_metrics_on_failure(self):
9687        def fn(x):
9688            return x.sin()
9689
9690        def broken_backend(gm, example_inputs):
9691            raise RuntimeError("broken backend")
9692
9693        x = torch.rand((4, 4))
9694        f = torch.compile(fn, backend=broken_backend)
9695        with unittest.mock.patch("torch._dynamo.config.suppress_errors", True):
9696            torch._dynamo.utils.clear_compilation_metrics()
9697            f(x)
9698            metrics = torch._dynamo.utils.get_compilation_metrics()
9699            for metric in metrics:
9700                self.assertTrue(metric.dynamo_time_before_restart_s > 0)
9701                self.assertTrue(
9702                    "RuntimeError: broken backend" in metric.fail_reason,
9703                    "Should have logged fail reason",
9704                )
9705
9706    def test_compilation_metrics_size_limit(self):
9707        def fn1(x):
9708            return x.relu()
9709
9710        def fn2(x):
9711            return x.cos()
9712
9713        def fn3(x):
9714            return x.sin()
9715
9716        def fn4(x):
9717            return x.exp()
9718
9719        import contextlib
9720
9721        @contextlib.contextmanager
9722        def metrics_limit_ctx():
9723            try:
9724                torch._dynamo.utils.set_compilation_metrics_limit(3)
9725                yield
9726            finally:
9727                torch._dynamo.utils.set_compilation_metrics_limit(
9728                    torch._dynamo.utils.DEFAULT_COMPILATION_METRICS_LIMIT
9729                )
9730
9731        x = torch.rand((4, 4))
9732        torch._dynamo.reset()
9733        torch.compile(fn1, backend="eager")(x)
9734        torch.compile(fn2, backend="eager")(x)
9735        torch.compile(fn3, backend="eager")(x)
9736        torch.compile(fn4, backend="eager")(x)
9737
9738        with metrics_limit_ctx():
9739            torch._dynamo.utils.clear_compilation_metrics()
9740            torch._dynamo.reset()
9741            self.assertEqual(0, len(torch._dynamo.utils.get_compilation_metrics()))
9742            torch.compile(fn1, backend="eager")(x)
9743            self.assertEqual(1, len(torch._dynamo.utils.get_compilation_metrics()))
9744            torch.compile(fn2, backend="eager")(x)
9745            self.assertEqual(2, len(torch._dynamo.utils.get_compilation_metrics()))
9746            torch.compile(fn3, backend="eager")(x)
9747            self.assertEqual(3, len(torch._dynamo.utils.get_compilation_metrics()))
9748            torch.compile(fn4, backend="eager")(x)
9749            self.assertEqual(3, len(torch._dynamo.utils.get_compilation_metrics()))
9750
9751    def test_funcname_cache(self):
9752        src = """\
9753import torch
9754if True:
9755    test = 3
9756
9757class AAA:
9758    class DUMMY:
9759        class DUMMY2:
9760            pass
9761
9762    def dummy(self):
9763        def dummy2():
9764            pass
9765    class BBB:
9766        @staticmethod
9767        def CCC():
9768            class DDD:
9769                if True:
9770                    @staticmethod
9771                    def EEE():
9772                        x = [torch.ones(3, 3) for _ in range(5)]
9773                        return x
9774            return DDD
9775def fn():
9776    return 3
9777"""
9778        with tempfile.NamedTemporaryFile(mode="w") as f:
9779            f.write(src)
9780            f.flush()
9781            from torch._dynamo.funcname_cache import get_funcname
9782
9783            names = [get_funcname(f.name, i + 1) for i in range(src.count("\n") + 1)]
9784
9785        self.assertExpectedInline(
9786            "\n".join(names),
9787            """\
9788
9789
9790
9791
9792AAA
9793AAA.DUMMY
9794AAA.DUMMY.DUMMY2
9795AAA.DUMMY.DUMMY2
9796AAA.DUMMY.DUMMY2
9797AAA.dummy
9798AAA.dummy.dummy2
9799AAA.dummy.dummy2
9800AAA.BBB
9801AAA.BBB
9802AAA.BBB.CCC
9803AAA.BBB.CCC.DDD
9804AAA.BBB.CCC.DDD
9805AAA.BBB.CCC.DDD
9806AAA.BBB.CCC.DDD.EEE
9807AAA.BBB.CCC.DDD.EEE
9808AAA.BBB.CCC.DDD.EEE
9809AAA.BBB.CCC
9810fn
9811fn
9812""",
9813        )
9814
9815    def test_return_dict_with_graph_break_and_update(self):
9816        def create():
9817            torch._dynamo.graph_break()
9818            return {0: torch.tensor(3)}
9819
9820        def fn():
9821            return {**create()}
9822
9823        opt_fn = torch.compile(backend="eager")(fn)
9824        result = opt_fn()
9825        self.assertIn(0, result)
9826        self.assertTrue(same(result[0], torch.tensor(3)))
9827
9828    def test_dynamo_reset_clears_cache(self):
9829        """Test that the dynamo bytecode cache is freed
9830        when torch._dynamo.reset() is called
9831        """
9832
9833        def fn(x):
9834            return torch.sin(x)
9835
9836        opt_fn = torch.compile(backend="eager")(fn)
9837        opt_fn(torch.randn(3, 3))
9838
9839        c1 = _debug_get_cache_entry_list(fn.__code__)
9840        self.assertEqual(len(c1), 1)
9841
9842        torch._dynamo.reset()
9843        c2 = _debug_get_cache_entry_list(fn.__code__)
9844        self.assertEqual(len(c2), 0)
9845
9846    @torch._dynamo.config.patch(capture_scalar_outputs=True)
9847    def test_guard_size_oblivious(self):
9848        # This code, in fact, does NOT work in eager
9849        @torch.compile(backend="eager", fullgraph=True)
9850        def fn(x):
9851            y = torch.zeros(x.item())
9852            if guard_size_oblivious(y.size(0) == 0):
9853                assert False
9854            return y
9855
9856        self.assertEqual(fn(torch.tensor([0])), torch.zeros(0))
9857
9858    def test_guard_size_oblivious_backed(self):
9859        @torch.compile(backend="eager", fullgraph=True)
9860        def f(x):
9861            y = x.size(0)
9862            # This doesn't actually do anything
9863            if guard_size_oblivious(y == 0):
9864                return torch.randn(1)
9865            else:
9866                return torch.randn(2)
9867
9868        # Should not fail in either case
9869        self.assertEqual(f(torch.randn(0)).shape, (1,))
9870        self.assertEqual(f(torch.randn(2)).shape, (2,))
9871
9872    def _test_compile_model_free(self, model_inp_ctr, weakref_watch):
9873        """
9874        Args:
9875        model_inp_ctr
9876            - constructor that returns a new model and inputs to that model
9877        weakref_watch
9878            - function that returns a layer of the model for weakref to
9879              finalize on, so we can check that the layer is freed after
9880              the model goes out of scope
9881        """
9882        cleared = False
9883
9884        def finalize():
9885            nonlocal cleared
9886            cleared = True
9887
9888        def run():
9889            mod, inp = model_inp_ctr()
9890            weakref.finalize(weakref_watch(mod), finalize)
9891            torch.compile(mod, backend="eager")(inp)
9892
9893        run()
9894        gc.collect()
9895        self.assertTrue(cleared)
9896
9897    def test_custom_module_free(self):
9898        """Test that a model is freed when it goes out of scope"""
9899
9900        class Mod(torch.nn.Module):
9901            def __init__(self):
9902                super().__init__()
9903                self.fc = torch.nn.Linear(100, 100)
9904
9905            def forward(self, out):
9906                return self.fc(out)
9907
9908        self._test_compile_model_free(
9909            lambda: (Mod(), torch.randn(100, 100)),
9910            lambda mod: mod.fc,
9911        )
9912
9913    def test_sequential_module_free(self):
9914        self._test_compile_model_free(
9915            lambda: (
9916                torch.nn.Sequential(
9917                    torch.nn.Linear(100, 100),
9918                    torch.nn.ReLU(),
9919                ),
9920                torch.randn(100, 100),
9921            ),
9922            lambda mod: mod[0],
9923        )
9924
9925    def test_linear_module_free(self):
9926        self._test_compile_model_free(
9927            lambda: (torch.nn.Linear(100, 100), torch.randn(100, 100)),
9928            lambda mod: mod,
9929        )
9930
9931    # The following 2 tests fail due to https://github.com/python/cpython/issues/118013.
9932    # Tracked by https://github.com/pytorch/pytorch/issues/124302.
9933    # The xfails can be removed once Python 3.12 is updated on CI.
9934    @xfailIfPy312
9935    @unittest.skipIf(True, "Skipping this test for release/2.4")
9936    def test_outside_linear_module_free(self):
9937        # Compared to test_linear_module_free, the linear
9938        # layer is not the module that is directly compiled.
9939
9940        # This test does not use _test_compile_model_free because of the
9941        # difficulty of handling the free variable fc.
9942
9943        cleared = False
9944
9945        def finalize():
9946            nonlocal cleared
9947            cleared = True
9948
9949        def run():
9950            fc = torch.nn.Linear(100, 100)
9951
9952            class Mod(torch.nn.Module):
9953                def __init__(self):
9954                    super().__init__()
9955                    self.fc_ref = fc
9956
9957                def forward(self, x):
9958                    return self.fc_ref(x)
9959
9960            mod = Mod()
9961            inp = torch.randn(100, 100)
9962            weakref.finalize(fc, finalize)
9963            torch.compile(mod, backend="eager")(inp)
9964
9965        run()
9966        # del fc  # This should delete all the references
9967        gc.collect()
9968        self.assertTrue(cleared)
9969
9970    @xfailIfPy312
9971    def test_parameter_free(self):
9972        def model_inp_ctr():
9973            param = torch.nn.Parameter(torch.randn(100, 100))
9974
9975            class Mod(torch.nn.Module):
9976                def __init__(self):
9977                    super().__init__()
9978                    self.param = param
9979
9980                def forward(self, x):
9981                    return self.param * x[0]
9982
9983            # return param to keep it alive in _test_compile_model_free
9984            return Mod(), (torch.randn(100, 100), param)
9985
9986        self._test_compile_model_free(model_inp_ctr, lambda mod: mod.param)
9987
9988    def test_conditional_list_comp_in_context(self):
9989        def fn(inp):
9990            try:
9991                return [torch.sin(x) for x in inp if x is not None]
9992            except Exception:
9993                pass
9994
9995        inp = [torch.randn(3, 3) for _ in range(3)] + [None]
9996        opt_fn = torch.compile(fn, backend="eager")
9997        opt_fn(inp)
9998
9999    def test_312_binary_slice_with_graph_break1(self):
10000        l1 = torch.nn.Linear(5, 5)
10001        l2 = torch.nn.Linear(5, 5)
10002
10003        def fn(x):
10004            # causes a graph break while items are on the stack
10005            n = torch.nn.Sequential(l1, l2)
10006            out = n[1:](x)
10007            return out
10008
10009        opt_fn = torch.compile(fn, backend="eager")
10010        opt_fn(torch.randn(5, 5))
10011
10012    def test_312_binary_slice_with_graph_break2(self):
10013        class Foo:
10014            def __setitem__(self, key, val):
10015                pass
10016
10017            def __getitem__(self, key):
10018                torch._dynamo.graph_break()
10019                return 1
10020
10021        foo = Foo()
10022
10023        def fn(x):
10024            # graph break in a STORE_SLICE instruction
10025            foo[:] = x
10026            # graph break in BINARY_SLICE with has_backedge check
10027            x = x + foo[:]
10028            if x is None:
10029                x = x + 1
10030            else:
10031                x = x + 1
10032            return x
10033
10034        opt_fn = torch.compile(fn, backend="eager")
10035        opt_fn(torch.randn(5, 5))
10036
10037    def test_super_after_graph_break(self):
10038        class Foo(torch.nn.Sequential):
10039            def __init__(self, layers):
10040                torch._dynamo.graph_break()
10041                super().__init__(*layers)
10042
10043        def fn(x):
10044            layers = [torch.nn.Linear(3, 3) for _ in range(3)]
10045            mod = Foo(layers)
10046            return mod(x)
10047
10048        opt_fn = torch.compile(fn, backend="eager")
10049        opt_fn(torch.randn(3, 3))
10050
10051    def test_load_fast_and_clear_graph_break(self):
10052        # Can result in a segfault in 3.12+ if LOAD_FAST_AND_CLEAR
10053        # is not handled properly in a graph break
10054        def fn():
10055            out = torch.cat([torch.randn(r, 5) for r in range(3)])
10056            torch._dynamo.graph_break()
10057            out = torch.cat([torch.randn(r, 5) for r in range(3)])
10058            return out
10059
10060        self.assertEqual(torch._dynamo.optimize("eager")(fn)().shape, (3, 5))
10061
10062    def test_raises_importerror1(self):
10063        @torch.compile(backend="eager")
10064        def fn(x):
10065            try:
10066                import some_module_that_surely_does_not_exist
10067
10068                return
10069            except ImportError:
10070                pass
10071            return x.sin()
10072
10073        x = torch.randn(8)
10074        self.assertEqual(fn(x), x.sin())
10075
10076    def test_raises_importerror2(self):
10077        @torch.compile(backend="eager")
10078        def fn(x):
10079            import some_module_that_surely_does_not_exist
10080
10081            return x + 1
10082
10083        x = torch.randn(8)
10084        with self.assertRaises(ImportError):
10085            fn(x)
10086
10087    def test_dynamo_cache_move_to_front(self):
10088        def fn(x, const):
10089            return x + const
10090
10091        # dynamic=False forces Dynamo to recompile for each new value of const
10092        opt_fn = torch.compile(fn, backend="eager", dynamic=False)
10093
10094        inp = torch.randn(3, 3)
10095
10096        # NOTE: assumes that each cache entry is guarded
10097        # on a unique value of const
10098        opt_fn(inp, 1)
10099        opt_fn(inp, 2)
10100        opt_fn(inp, 3)
10101
10102        c1 = _debug_get_cache_entry_list(fn.__code__)
10103        self.assertEqual(len(c1), 3)
10104
10105        # move cache entry to front
10106        opt_fn(inp, 2)
10107        c2 = _debug_get_cache_entry_list(fn.__code__)
10108        self.assertIs(c1[1], c2[0])
10109
10110    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
10111    def test_dynamo_cache_invalidate(self):
10112        class Mod(torch.nn.Module):
10113            def __init__(self):
10114                super().__init__()
10115                self.fc = torch.nn.Linear(3, 3)
10116
10117            def forward(self, out):
10118                return self.fc(out)
10119
10120        def fn(x, mod):
10121            return mod(x)
10122
10123        opt_fn = torch.compile(fn, backend="eager")
10124
10125        m1 = Mod()
10126        m2 = Mod()
10127        m3 = Mod()
10128        inp = torch.randn(3, 3)
10129
10130        # NOTE: assumes that each cache entry is guarded
10131        # on a unique Mod instance
10132        opt_fn(inp, m1)
10133        opt_fn(inp, m2)
10134        opt_fn(inp, m3)
10135
10136        c1 = _debug_get_cache_entry_list(fn.__code__)
10137        self.assertEqual(len(c1), 3)
10138
10139        # move cache entry to front
10140        opt_fn(inp, m2)
10141        c2 = _debug_get_cache_entry_list(fn.__code__)
10142        self.assertIs(c1[1], c2[0])
10143
10144        # delete center of cache
10145        del m3
10146        c3 = _debug_get_cache_entry_list(fn.__code__)
10147        self.assertEqual(len(c3), 2)
10148        self.assertIs(c3[0], c2[0])
10149        self.assertIs(c3[1], c2[2])
10150
10151        # delete end of cache
10152        del m1
10153        c4 = _debug_get_cache_entry_list(fn.__code__)
10154        self.assertEqual(len(c4), 1)
10155        self.assertIs(c4[0], c3[0])
10156
10157        del m2
10158        c5 = _debug_get_cache_entry_list(fn.__code__)
10159        self.assertEqual(len(c5), 0)
10160
10161    def test_grad_none(self):
10162        def fn(x, y):
10163            x.grad = torch.abs(y)
10164            x.grad.add_(y)
10165            return torch.abs(y)
10166
10167        y = torch.arange(4).reshape(2, 2).to(torch.float)
10168        x = torch.randn(2, 2)
10169        x.grad = None
10170
10171        z = fn(x, y)
10172        ref_y = torch.clone(z).detach()
10173        ref_x_grad = torch.clone(x.grad).detach()
10174
10175        y = torch.arange(4).reshape(2, 2).to(torch.float)
10176        x = torch.randn(2, 2)
10177        x.grad = None
10178
10179        opt_fn = torch.compile(fn, backend="eager")
10180        z = opt_fn(x, y)
10181        self.assertEqual(z, ref_y)
10182        self.assertEqual(x.grad, ref_x_grad)
10183
10184    def test_grad_non_none(self):
10185        def fn(x, y):
10186            x.grad.add_(y)
10187            return torch.abs(y)
10188
10189        y = torch.ones(2, 2)
10190        x = torch.randn(2, 2)
10191        x.grad = torch.arange(4).reshape(2, 2).to(torch.float)
10192
10193        z = fn(x, y)
10194        ref_y = torch.clone(z).detach()
10195        ref_x_grad = torch.clone(x.grad).detach()
10196
10197        y = torch.ones(2, 2)
10198        x = torch.randn(2, 2)
10199        x.grad = torch.arange(4).reshape(2, 2).to(torch.float)
10200
10201        cnt = torch._dynamo.testing.CompileCounterWithBackend("eager")
10202        opt_fn = torch.compile(fn, backend=cnt)
10203        z = opt_fn(x, y)
10204
10205        # Ensure that the generated graph returns only one output. We want the
10206        # add_ on the grad to be part of the graph itself, so that inductor can
10207        # theoretically move the add_ and resulting copy_ nodes to the right
10208        # place to free memory.
10209        self.assertEqual(len(list(cnt.graphs[0].graph.nodes)[-1].all_input_nodes), 1)
10210        self.assertEqual(z, ref_y)
10211        self.assertEqual(x.grad, ref_x_grad)
10212
10213    def test_new_with_int_list(self):
10214        # Make sure torch.Tensor.new(int argument list) behaves the same under dynamo.
10215        def fn(x):
10216            return x.new(*x.size()) + 5
10217
10218        optfn = torch.compile(backend="eager")(fn)
10219
10220        x = torch.arange(10).view(2, 5)
10221
10222        expected = fn(x)
10223        actual = optfn(x)
10224
10225        self.assertEqual(expected.dtype, actual.dtype)
10226        self.assertEqual(expected.shape, actual.shape)
10227        self.assertEqual(expected.stride(), actual.stride())
10228        self.assertEqual(expected.storage_offset(), actual.storage_offset())
10229
10230    @torch._dynamo.config.patch(guard_nn_modules=True)
10231    def test_hasattr_nn_module_guard(self):
10232        class M(torch.nn.Module):
10233            def __init__(self):
10234                super().__init__()
10235                self.a = torch.nn.Linear(3, 3)
10236
10237            def forward(self, x):
10238                if hasattr(self, "a"):
10239                    return self.a(x)
10240                else:
10241                    return x
10242
10243        m = M()
10244        x = torch.randn(3, 3)
10245        ref = m(x)
10246
10247        opt_m = torch.compile(backend="eager")(m)
10248        res = opt_m(x)
10249        self.assertEqual(ref, res)
10250
10251    def test_ordered_dict_move_to_end(self):
10252        d = {
10253            "foo": 1,
10254            "bar": 2,
10255        }
10256
10257        d = collections.OrderedDict(d)
10258        d.move_to_end("foo")
10259
10260        @torch.compile(backend="eager")
10261        def fn(x, d):
10262            return x * d["foo"] * d["bar"]
10263
10264        fn(torch.randn(4), d)
10265        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
10266            fn(torch.randn(4), d)
10267
10268    def test_defaultdict(self):
10269        d = collections.defaultdict()
10270        d["foo"] = 1
10271        d["bar"] = 2
10272
10273        @torch.compile(backend="eager")
10274        def fn(x, d):
10275            return x * d["foo"] * d["bar"]
10276
10277        fn(torch.randn(4), d)
10278        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
10279            fn(torch.randn(4), d)
10280
10281    def test_custom_dict(self):
10282        class MyDict(dict):
10283            pass
10284
10285        d = {
10286            "foo": 1,
10287            "bar": 2,
10288        }
10289
10290        d = MyDict(d)
10291
10292        @torch.compile(backend="eager")
10293        def fn(x, d):
10294            return x * d["foo"] * d["bar"]
10295
10296        fn(torch.randn(4), d)
10297        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
10298            fn(torch.randn(4), d)
10299
10300    @unittest.skipIf(not TEST_CUDA, "requires cuda")
10301    @torch._dynamo.config.patch(
10302        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
10303    )
10304    @torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True)
10305    def test_interpolate_propagate_real_tensors(self):
10306        @torch.compile(backend="eager", fullgraph=True)
10307        def f(mask, box):
10308            # u0, u1 = mask.tolist()
10309            mask = torch.randn(1, 1, 30, 30, device="cuda")
10310            h, w = box.tolist()
10311            return torch.nn.functional.interpolate(
10312                mask, (h, w), mode="bilinear", align_corners=False
10313            )
10314
10315        f(torch.tensor([30, 30], device="cuda"), torch.tensor([68, 32], device="cuda"))
10316
10317    def test_custom_iter_dict(self):
10318        class ReversedDict(dict):
10319            def __iter__(self):
10320                return reversed(list(self.keys()))
10321
10322        d = {
10323            "foo": 1,
10324            "bar": 2,
10325        }
10326
10327        d = ReversedDict(d)
10328
10329        @torch.compile(backend="eager")
10330        def fn(x, d):
10331            return x * d["foo"] * d["bar"]
10332
10333        fn(torch.randn(4), d)
10334        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
10335            fn(torch.randn(4), d)
10336
10337    def test_custom_keys_iter_dict(self):
10338        class ReversedDict(dict):
10339            def keys(self):
10340                return ["bar", "foo"]
10341
10342        d = {
10343            "foo": 1,
10344            "bar": 2,
10345        }
10346
10347        d = ReversedDict(d)
10348
10349        @torch.compile(backend="eager")
10350        def fn(x, d):
10351            return x * d["foo"] * d["bar"]
10352
10353        fn(torch.randn(4), d)
10354        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
10355            fn(torch.randn(4), d)
10356
10357    def test_dict_guard_on_keys_order(self):
10358        d = {
10359            2: 4,
10360            3: 5,
10361        }
10362
10363        cnts = torch._dynamo.testing.CompileCounter()
10364
10365        def fn(x, d):
10366            for key, value in d.items():
10367                x = x * key + value
10368            return x
10369
10370        opt_fn = torch.compile(fn, backend=cnts)
10371        opt_fn(torch.randn(4), d)
10372        opt_fn(torch.randn(4), d)
10373        # No recompilation
10374        self.assertEqual(cnts.frame_count, 1)
10375
10376        # move 2 to the end
10377        d[2] = d.pop(2)
10378
10379        x = torch.randn(4)
10380        res = opt_fn(x, d)
10381        # Check recompilation
10382        self.assertEqual(cnts.frame_count, 2)
10383        self.assertEqual(res, fn(x, d))
10384
10385    def test_dict_guard_on_keys_order2(self):
10386        d = {
10387            2: 4,
10388            3: 5,
10389        }
10390
10391        cnts = torch._dynamo.testing.CompileCounter()
10392
10393        def fn(x, d):
10394            for key in d:
10395                value = d[key]
10396                x = x * key + value
10397            return x
10398
10399        opt_fn = torch.compile(fn, backend=cnts)
10400        opt_fn(torch.randn(4), d)
10401        opt_fn(torch.randn(4), d)
10402        # No recompilation
10403        self.assertEqual(cnts.frame_count, 1)
10404
10405        # move 2 to the end
10406        d[2] = d.pop(2)
10407
10408        x = torch.randn(4)
10409        res = opt_fn(x, d)
10410        # Check recompilation
10411        self.assertEqual(cnts.frame_count, 2)
10412        self.assertEqual(res, fn(x, d))
10413
10414    def test_contains_dunder_dict(self):
10415        class UserDefined:
10416            def __init__(self):
10417                self.a = 3
10418                self.b = 5
10419
10420            def run(self, x):
10421                if "a" in self.__dict__:
10422                    x = x * self.a
10423                if "b" in self.__dict__:
10424                    x = x * self.b
10425                self.c = 7
10426                if "c" in self.__dict__:
10427                    x = x * self.c
10428                return x * self.__dict__.get("a") * self.__dict__.get("z", 2)
10429
10430        obj = UserDefined()
10431
10432        def fn(x):
10433            return obj.run(x)
10434
10435        x = torch.randn(4)
10436        ref = fn(x)
10437        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
10438        res = opt_fn(x)
10439        self.assertEqual(ref, res)
10440
10441    def test_module_dunder_dict(self):
10442        class MyModule(torch.nn.Module):
10443            def __init__(self):
10444                super().__init__()
10445                self.foo = 1
10446                self.bar = 2
10447                self.baz = 3
10448
10449            def forward(self, x):
10450                if "foo" in self.__dict__:
10451                    return x * self.bar
10452                return x * self.baz
10453
10454        mod = MyModule()
10455        x = torch.randn(10)
10456        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
10457        self.assertEqual(mod(x), opt_mod(x))
10458
10459
10460class TestTracer(JitTestCase):
10461    def test_jit_save(self):
10462        def fn():
10463            class Foo(torch.nn.Module):
10464                def __init__(self):
10465                    super().__init__()
10466                    self.a = 3
10467
10468                @torch.jit.export
10469                def __getstate__(self):
10470                    return (3, self.training)
10471
10472                @torch.jit.export
10473                def __setstate__(self, state):
10474                    self.a = state[0]
10475                    self.training = state[1]
10476
10477                def forward(self, x):
10478                    return x + self.a
10479
10480            f = Foo()
10481
10482            return torch.jit.trace(f, (torch.rand(3, 4),))
10483
10484        fn()
10485        opt_fn = torch._dynamo.optimize("eager")(fn)
10486        opt_fn()
10487
10488
10489if __name__ == "__main__":
10490    from torch._dynamo.test_case import run_tests
10491
10492    run_tests()
10493