xref: /aosp_15_r20/external/pytorch/test/dynamo/test_misc.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Workerimport abc
3*da0073e9SAndroid Build Coastguard Workerimport collections
4*da0073e9SAndroid Build Coastguard Workerimport copy
5*da0073e9SAndroid Build Coastguard Workerimport dataclasses
6*da0073e9SAndroid Build Coastguard Workerimport dis
7*da0073e9SAndroid Build Coastguard Workerimport enum
8*da0073e9SAndroid Build Coastguard Workerimport functools
9*da0073e9SAndroid Build Coastguard Workerimport gc
10*da0073e9SAndroid Build Coastguard Workerimport itertools
11*da0073e9SAndroid Build Coastguard Workerimport logging
12*da0073e9SAndroid Build Coastguard Workerimport math
13*da0073e9SAndroid Build Coastguard Workerimport operator
14*da0073e9SAndroid Build Coastguard Workerimport os
15*da0073e9SAndroid Build Coastguard Workerimport random
16*da0073e9SAndroid Build Coastguard Workerimport sys
17*da0073e9SAndroid Build Coastguard Workerimport tempfile
18*da0073e9SAndroid Build Coastguard Workerimport threading
19*da0073e9SAndroid Build Coastguard Workerimport traceback
20*da0073e9SAndroid Build Coastguard Workerimport typing
21*da0073e9SAndroid Build Coastguard Workerimport unittest
22*da0073e9SAndroid Build Coastguard Workerimport unittest.mock as mock
23*da0073e9SAndroid Build Coastguard Workerimport warnings
24*da0073e9SAndroid Build Coastguard Workerimport weakref
25*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Workerimport numpy as np
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Workerimport torch
30*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Workerimport torch._inductor.test_case
33*da0073e9SAndroid Build Coastguard Workerimport torch.onnx.operators
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Workerimport torch.utils._pytree as pytree
36*da0073e9SAndroid Build Coastguard Workerimport torch.utils.cpp_extension
37*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor
38*da0073e9SAndroid Build Coastguard Workerfrom torch._C import FileCheck
39*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo import allow_in_graph
40*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.eval_frame import _debug_get_cache_entry_list
41*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.exc import Unsupported
42*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.source import ConstantSource, GetItemSource, LocalSource
43*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import (
44*da0073e9SAndroid Build Coastguard Worker    CompileCounter,
45*da0073e9SAndroid Build Coastguard Worker    CompileCounterWithBackend,
46*da0073e9SAndroid Build Coastguard Worker    expectedFailureDynamic,
47*da0073e9SAndroid Build Coastguard Worker    same,
48*da0073e9SAndroid Build Coastguard Worker    skipIfNotPy311,
49*da0073e9SAndroid Build Coastguard Worker    unsupported,
50*da0073e9SAndroid Build Coastguard Worker    xfailIfPy312,
51*da0073e9SAndroid Build Coastguard Worker)
52*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.utils import CompileProfiler, counters, ifdynstaticdefault
53*da0073e9SAndroid Build Coastguard Workerfrom torch._inductor.utils import run_and_get_code
54*da0073e9SAndroid Build Coastguard Workerfrom torch.ao.quantization import MinMaxObserver
55*da0073e9SAndroid Build Coastguard Workerfrom torch.ao.quantization.fake_quantize import FakeQuantize
56*da0073e9SAndroid Build Coastguard Workerfrom torch.ao.quantization.qconfig import QConfig
57*da0073e9SAndroid Build Coastguard Workerfrom torch.ao.quantization.quantize_fx import prepare_qat_fx
58*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.recording import NotEqualError, replay_shape_env_events
59*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.symbolic_shapes import (
60*da0073e9SAndroid Build Coastguard Worker    _constrain_range_for_size,
61*da0073e9SAndroid Build Coastguard Worker    constrain_range,
62*da0073e9SAndroid Build Coastguard Worker    constrain_unify,
63*da0073e9SAndroid Build Coastguard Worker    ConstraintViolationError,
64*da0073e9SAndroid Build Coastguard Worker    expect_true,
65*da0073e9SAndroid Build Coastguard Worker    guard_size_oblivious,
66*da0073e9SAndroid Build Coastguard Worker    ShapeEnv,
67*da0073e9SAndroid Build Coastguard Worker)
68*da0073e9SAndroid Build Coastguard Workerfrom torch.nn import functional as F
69*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor
70*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import (
71*da0073e9SAndroid Build Coastguard Worker    PLATFORM_SUPPORTS_FLASH_ATTENTION,
72*da0073e9SAndroid Build Coastguard Worker    SM80OrLater,
73*da0073e9SAndroid Build Coastguard Worker    TEST_CUDA,
74*da0073e9SAndroid Build Coastguard Worker    TEST_MULTIGPU,
75*da0073e9SAndroid Build Coastguard Worker)
76*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import (
77*da0073e9SAndroid Build Coastguard Worker    sample_inputs_take_along_dim,
78*da0073e9SAndroid Build Coastguard Worker)
79*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
80*da0073e9SAndroid Build Coastguard Worker    freeze_rng_state,
81*da0073e9SAndroid Build Coastguard Worker    IS_FBCODE,
82*da0073e9SAndroid Build Coastguard Worker    set_default_dtype,
83*da0073e9SAndroid Build Coastguard Worker    wrapDeterministicFlagAPITest,
84*da0073e9SAndroid Build Coastguard Worker)
85*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase
86*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.logging_utils import logs_to_string
87*da0073e9SAndroid Build Coastguard Worker
# Module-level namedtuple type with three fields: "a", "b" and "ab".
mytuple = collections.namedtuple("mytuple", ["a", "b", "ab"])
# Unconstrained generic type variable.
T = typing.TypeVar("T")
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Worker# Specializes a test to run only if translation validation is set.
93*da0073e9SAndroid Build Coastguard Workerdef onlyIfTranslationValidation(fn: typing.Callable) -> typing.Callable:
94*da0073e9SAndroid Build Coastguard Worker    @functools.wraps(fn)
95*da0073e9SAndroid Build Coastguard Worker    def wrapper(*args, **kwargs):
96*da0073e9SAndroid Build Coastguard Worker        import torch.fx.experimental.validator
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker        if torch.fx.experimental.validator.translation_validation_enabled():
99*da0073e9SAndroid Build Coastguard Worker            return fn(*args, **kwargs)
100*da0073e9SAndroid Build Coastguard Worker        raise unittest.SkipTest(f"only works when TV is True.")
101*da0073e9SAndroid Build Coastguard Worker
102*da0073e9SAndroid Build Coastguard Worker    return wrapper
103*da0073e9SAndroid Build Coastguard Worker
104*da0073e9SAndroid Build Coastguard Worker
def cleanup_op(opname):
    """Delete the attribute for ``opname`` from its ``torch.ops`` namespace.

    ``opname`` must have the form ``"namespace::name"``. If ``torch.ops`` has
    no such namespace, or the namespace has no such attribute, nothing
    happens.
    """
    namespace_name, op_name = opname.split("::")
    if hasattr(torch.ops, namespace_name):
        namespace = getattr(torch.ops, namespace_name)
        if hasattr(namespace, op_name):
            delattr(namespace, op_name)
113*da0073e9SAndroid Build Coastguard Worker
114*da0073e9SAndroid Build Coastguard Worker
class MyPickledModule(torch.nn.Module):
    """Module that stores ``z`` and computes ``x * x * x + y + z``."""

    def __init__(self, z):
        super().__init__()
        self.z = z

    def forward(self, x, y):
        """Return ``x`` multiplied by itself twice, plus ``y``, plus ``self.z``."""
        cubed = x * x * x
        return cubed + y + self.z
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker
# These are used for test_{cond/map}_with_quantization
# Activation fake-quant factory: MinMaxObserver, per-tensor symmetric, quint8.
default_symmetric_fake_quant = FakeQuantize.with_args(
    observer=MinMaxObserver, qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
)
# Weight fake-quant factory: same observer and qscheme, but dtype qint8.
default_weight_symmetric_fake_quant = FakeQuantize.with_args(
    observer=MinMaxObserver, qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
)
# NOTE(review): `weight` is given the `with_args` attribute itself (not called),
# whereas `activation` is given the factory object directly -- confirm that
# this asymmetry is intentional.
uniform_qconfig_8bit = QConfig(
    activation=default_symmetric_fake_quant,
    weight=default_weight_symmetric_fake_quant.with_args,
)
# Maps torch.nn.Linear to the 8-bit qconfig under the "object_type" key.
qconfig_dict = {"object_type": [(torch.nn.Linear, uniform_qconfig_8bit)]}
136*da0073e9SAndroid Build Coastguard Worker
137*da0073e9SAndroid Build Coastguard Worker
def closure_adder(val):
    """Return a function computing ``torch.sin(x + val)`` with ``val`` captured."""

    def inner(x):
        shifted = x + val
        return torch.sin(shifted)

    return inner
143*da0073e9SAndroid Build Coastguard Worker
144*da0073e9SAndroid Build Coastguard Worker
145*da0073e9SAndroid Build Coastguard Workerclass UserDefineSetAttr:
146*da0073e9SAndroid Build Coastguard Worker    setup = False
147*da0073e9SAndroid Build Coastguard Worker
148*da0073e9SAndroid Build Coastguard Worker    def __setattr__(self, key, value):
149*da0073e9SAndroid Build Coastguard Worker        assert torch.compiler.is_dynamo_compiling() or UserDefineSetAttr.setup
150*da0073e9SAndroid Build Coastguard Worker        super().__setattr__(f"pfx_{key}", value)
151*da0073e9SAndroid Build Coastguard Worker
152*da0073e9SAndroid Build Coastguard Worker    def __getattr__(self, key, c=1):
153*da0073e9SAndroid Build Coastguard Worker        assert torch.compiler.is_dynamo_compiling() or UserDefineSetAttr.setup
154*da0073e9SAndroid Build Coastguard Worker        # c is added to force a guard on __defaults__ and checks the source for __getattr__
155*da0073e9SAndroid Build Coastguard Worker        if c:
156*da0073e9SAndroid Build Coastguard Worker            return self.__dict__[f"pfx_{key}"]
157*da0073e9SAndroid Build Coastguard Worker        else:
158*da0073e9SAndroid Build Coastguard Worker            return None
159*da0073e9SAndroid Build Coastguard Worker
160*da0073e9SAndroid Build Coastguard Worker
161*da0073e9SAndroid Build Coastguard Workerclass MiscTests(torch._inductor.test_case.TestCase):
162*da0073e9SAndroid Build Coastguard Worker    def test_get_cache_entry(self):
163*da0073e9SAndroid Build Coastguard Worker        def f(x):
164*da0073e9SAndroid Build Coastguard Worker            return x + 1
165*da0073e9SAndroid Build Coastguard Worker
166*da0073e9SAndroid Build Coastguard Worker        torch.compile(f)(torch.randn(5, 5, 5))
167*da0073e9SAndroid Build Coastguard Worker        entries = _debug_get_cache_entry_list(f)
168*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(entries) > 0)
169*da0073e9SAndroid Build Coastguard Worker
170*da0073e9SAndroid Build Coastguard Worker        def g(x):
171*da0073e9SAndroid Build Coastguard Worker            return x + 2
172*da0073e9SAndroid Build Coastguard Worker
173*da0073e9SAndroid Build Coastguard Worker        entries = _debug_get_cache_entry_list(g)
174*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(entries) == 0)
175*da0073e9SAndroid Build Coastguard Worker
176*da0073e9SAndroid Build Coastguard Worker        try:
177*da0073e9SAndroid Build Coastguard Worker            _debug_get_cache_entry_list(1)
178*da0073e9SAndroid Build Coastguard Worker        except TypeError as e:
179*da0073e9SAndroid Build Coastguard Worker            self.assertIn("expected a code object!", str(e))
180*da0073e9SAndroid Build Coastguard Worker
181*da0073e9SAndroid Build Coastguard Worker        # test get cache entry on skipped code object
182*da0073e9SAndroid Build Coastguard Worker        def h(x):
183*da0073e9SAndroid Build Coastguard Worker            x = x + 1
184*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
185*da0073e9SAndroid Build Coastguard Worker            return x + 1
186*da0073e9SAndroid Build Coastguard Worker
187*da0073e9SAndroid Build Coastguard Worker        torch.compile(h)(torch.randn(3, 3))
188*da0073e9SAndroid Build Coastguard Worker
189*da0073e9SAndroid Build Coastguard Worker        entries = _debug_get_cache_entry_list(torch._dynamo.graph_break)
190*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(entries), 0)
191*da0073e9SAndroid Build Coastguard Worker
192*da0073e9SAndroid Build Coastguard Worker    def test_boolarg(self):
193*da0073e9SAndroid Build Coastguard Worker        def boolarg(aa, bb, flag):
194*da0073e9SAndroid Build Coastguard Worker            if flag:
195*da0073e9SAndroid Build Coastguard Worker                return aa - bb
196*da0073e9SAndroid Build Coastguard Worker            else:
197*da0073e9SAndroid Build Coastguard Worker                return bb - aa
198*da0073e9SAndroid Build Coastguard Worker
199*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(10, 10)
200*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(10, 10)
201*da0073e9SAndroid Build Coastguard Worker        correct1 = boolarg(a, b, True)
202*da0073e9SAndroid Build Coastguard Worker        correct2 = boolarg(a, b, False)
203*da0073e9SAndroid Build Coastguard Worker        correct3 = boolarg(a, b, None)
204*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
205*da0073e9SAndroid Build Coastguard Worker        opt_boolarg = torch._dynamo.optimize_assert(counter)(boolarg)
206*da0073e9SAndroid Build Coastguard Worker        val1 = opt_boolarg(a, b, True)
207*da0073e9SAndroid Build Coastguard Worker        val2 = opt_boolarg(a, b, False)
208*da0073e9SAndroid Build Coastguard Worker        val3 = opt_boolarg(a, b, None)
209*da0073e9SAndroid Build Coastguard Worker        val4 = opt_boolarg(a, b, True)
210*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(val1, correct1))
211*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(val2, correct2))
212*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(val3, correct3))
213*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(val4, correct1))
214*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 3)
215*da0073e9SAndroid Build Coastguard Worker
216*da0073e9SAndroid Build Coastguard Worker    def test_invalid_args_builtin(self):
217*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager")
218*da0073e9SAndroid Build Coastguard Worker        def fn(x):
219*da0073e9SAndroid Build Coastguard Worker            x = x.sin()
220*da0073e9SAndroid Build Coastguard Worker            if isinstance(x, torch.Tensor, invalid=True):
221*da0073e9SAndroid Build Coastguard Worker                x = x.sin()
222*da0073e9SAndroid Build Coastguard Worker            return x
223*da0073e9SAndroid Build Coastguard Worker
224*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
225*da0073e9SAndroid Build Coastguard Worker            fn(torch.randn(16))
226*da0073e9SAndroid Build Coastguard Worker
227*da0073e9SAndroid Build Coastguard Worker    def test_cpp_extension_recommends_custom_ops(self):
228*da0073e9SAndroid Build Coastguard Worker        cpp_source = """
229*da0073e9SAndroid Build Coastguard Worker        #include <torch/extension.h>
230*da0073e9SAndroid Build Coastguard Worker        at::Tensor foobar(const at::Tensor& x) {
231*da0073e9SAndroid Build Coastguard Worker            return x.clone();
232*da0073e9SAndroid Build Coastguard Worker        }
233*da0073e9SAndroid Build Coastguard Worker        """
234*da0073e9SAndroid Build Coastguard Worker        module = torch.utils.cpp_extension.load_inline(
235*da0073e9SAndroid Build Coastguard Worker            name="mylib",
236*da0073e9SAndroid Build Coastguard Worker            cpp_sources=cpp_source,
237*da0073e9SAndroid Build Coastguard Worker            functions="foobar",
238*da0073e9SAndroid Build Coastguard Worker            verbose=True,
239*da0073e9SAndroid Build Coastguard Worker        )
240*da0073e9SAndroid Build Coastguard Worker
241*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2, 2, requires_grad=True)
242*da0073e9SAndroid Build Coastguard Worker        counters.clear()
243*da0073e9SAndroid Build Coastguard Worker
244*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager")
245*da0073e9SAndroid Build Coastguard Worker        def f(x):
246*da0073e9SAndroid Build Coastguard Worker            return module.foobar(x)
247*da0073e9SAndroid Build Coastguard Worker
248*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(
249*da0073e9SAndroid Build Coastguard Worker            UserWarning,
250*da0073e9SAndroid Build Coastguard Worker            ".*https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html.*",
251*da0073e9SAndroid Build Coastguard Worker        ):
252*da0073e9SAndroid Build Coastguard Worker            f(x)
253*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(counters["graph_break"]), 1)
254*da0073e9SAndroid Build Coastguard Worker        first_graph_break = list(counters["graph_break"].keys())[0]
255*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
256*da0073e9SAndroid Build Coastguard Worker            first_graph_break,
257*da0073e9SAndroid Build Coastguard Worker            """Graph break due to unsupported builtin mylib.PyCapsule.foobar. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.""",
258*da0073e9SAndroid Build Coastguard Worker        )
259*da0073e9SAndroid Build Coastguard Worker
260*da0073e9SAndroid Build Coastguard Worker        cpp_source = """
261*da0073e9SAndroid Build Coastguard Worker        #include <torch/extension.h>
262*da0073e9SAndroid Build Coastguard Worker        at::Tensor baz(const at::Tensor& x) {
263*da0073e9SAndroid Build Coastguard Worker            return x.clone();
264*da0073e9SAndroid Build Coastguard Worker        }
265*da0073e9SAndroid Build Coastguard Worker        """
266*da0073e9SAndroid Build Coastguard Worker        module2 = torch.utils.cpp_extension.load_inline(
267*da0073e9SAndroid Build Coastguard Worker            name="mylib2",
268*da0073e9SAndroid Build Coastguard Worker            cpp_sources=cpp_source,
269*da0073e9SAndroid Build Coastguard Worker            functions="baz",
270*da0073e9SAndroid Build Coastguard Worker            verbose=True,
271*da0073e9SAndroid Build Coastguard Worker        )
272*da0073e9SAndroid Build Coastguard Worker
273*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
274*da0073e9SAndroid Build Coastguard Worker
275*da0073e9SAndroid Build Coastguard Worker        # Test that each warning only happens once
276*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager")
277*da0073e9SAndroid Build Coastguard Worker        def f(x):
278*da0073e9SAndroid Build Coastguard Worker            module2.baz(x)
279*da0073e9SAndroid Build Coastguard Worker            module.foobar(x)
280*da0073e9SAndroid Build Coastguard Worker            module.foobar(x)
281*da0073e9SAndroid Build Coastguard Worker            module2.baz(x)
282*da0073e9SAndroid Build Coastguard Worker            module.foobar(x)
283*da0073e9SAndroid Build Coastguard Worker            module2.baz(x)
284*da0073e9SAndroid Build Coastguard Worker            return x.clone()
285*da0073e9SAndroid Build Coastguard Worker
286*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as ws:
287*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")
288*da0073e9SAndroid Build Coastguard Worker            f(x)
289*da0073e9SAndroid Build Coastguard Worker            f(x)
290*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(ws), 2)
291*da0073e9SAndroid Build Coastguard Worker
292*da0073e9SAndroid Build Coastguard Worker    def test_callpacked(self):
293*da0073e9SAndroid Build Coastguard Worker        def call_packed(args):
294*da0073e9SAndroid Build Coastguard Worker            a, b, c = args
295*da0073e9SAndroid Build Coastguard Worker            return a - b * c
296*da0073e9SAndroid Build Coastguard Worker
297*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
298*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(10, 10)
299*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(10, 10)
300*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(10, 10)
301*da0073e9SAndroid Build Coastguard Worker        correct = call_packed([a, b, c])
302*da0073e9SAndroid Build Coastguard Worker        opt_call_packed = torch._dynamo.optimize_assert(counter)(call_packed)
303*da0073e9SAndroid Build Coastguard Worker        val1 = opt_call_packed([a, b, c])
304*da0073e9SAndroid Build Coastguard Worker        val2 = opt_call_packed((a, b, c))
305*da0073e9SAndroid Build Coastguard Worker        val3 = opt_call_packed([a, b, c])
306*da0073e9SAndroid Build Coastguard Worker        val4 = opt_call_packed((a, b, c))
307*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(val1, correct))
308*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(val2, correct))
309*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(val3, correct))
310*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(val4, correct))
311*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 2)
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker    def test_raises(self):
314*da0073e9SAndroid Build Coastguard Worker        def fn(a, b, c, cls):
315*da0073e9SAndroid Build Coastguard Worker            x = a + b - c * 10
316*da0073e9SAndroid Build Coastguard Worker            raise cls(str(x))
317*da0073e9SAndroid Build Coastguard Worker
318*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
319*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(10, 10)
320*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(10, 10)
321*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(10, 10)
322*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(counter)(fn)
323*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(AssertionError, lambda: opt_fn(a, b, c, AssertionError))
324*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 1)
325*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.op_count, 3)
326*da0073e9SAndroid Build Coastguard Worker
327*da0073e9SAndroid Build Coastguard Worker    def test_module_not_callable(self):
328*da0073e9SAndroid Build Coastguard Worker        def fn(x):
329*da0073e9SAndroid Build Coastguard Worker            return torch.fft(x)
330*da0073e9SAndroid Build Coastguard Worker
331*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
332*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(10, 10)
333*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(counter)(fn)
334*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
335*da0073e9SAndroid Build Coastguard Worker            TypeError, "'module' object is not callable", lambda: opt_fn(a)
336*da0073e9SAndroid Build Coastguard Worker        )
337*da0073e9SAndroid Build Coastguard Worker
338*da0073e9SAndroid Build Coastguard Worker    def test_inplace(self):
339*da0073e9SAndroid Build Coastguard Worker        def inplace1(a, b):
340*da0073e9SAndroid Build Coastguard Worker            o = torch.empty((10, 10))
341*da0073e9SAndroid Build Coastguard Worker            o.copy_(a)
342*da0073e9SAndroid Build Coastguard Worker            o -= b
343*da0073e9SAndroid Build Coastguard Worker            return o
344*da0073e9SAndroid Build Coastguard Worker
345*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.testing.standard_test(self, inplace1, 2, expected_ops=3)
346*da0073e9SAndroid Build Coastguard Worker
347*da0073e9SAndroid Build Coastguard Worker    def test_inplace_desugaring(self):
348*da0073e9SAndroid Build Coastguard Worker        def inplace_on_literals(y):
349*da0073e9SAndroid Build Coastguard Worker            x0 = 1
350*da0073e9SAndroid Build Coastguard Worker            x0 += y
351*da0073e9SAndroid Build Coastguard Worker            x1 = 1
352*da0073e9SAndroid Build Coastguard Worker            x1 -= y
353*da0073e9SAndroid Build Coastguard Worker            return x0, x1
354*da0073e9SAndroid Build Coastguard Worker
355*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.testing.standard_test(
356*da0073e9SAndroid Build Coastguard Worker            self, inplace_on_literals, 1, expected_ops=2
357*da0073e9SAndroid Build Coastguard Worker        )
358*da0073e9SAndroid Build Coastguard Worker
359*da0073e9SAndroid Build Coastguard Worker    def test_unpack4(self):
360*da0073e9SAndroid Build Coastguard Worker        def unpack4(a, b):
361*da0073e9SAndroid Build Coastguard Worker            a = a[:5, :]
362*da0073e9SAndroid Build Coastguard Worker            b = b[:5, :]
363*da0073e9SAndroid Build Coastguard Worker            x, y = a.size()
364*da0073e9SAndroid Build Coastguard Worker            o = torch.empty((x, y))
365*da0073e9SAndroid Build Coastguard Worker            o.copy_(a / b)
366*da0073e9SAndroid Build Coastguard Worker            return o
367*da0073e9SAndroid Build Coastguard Worker
368*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.testing.standard_test(
369*da0073e9SAndroid Build Coastguard Worker            self,
370*da0073e9SAndroid Build Coastguard Worker            unpack4,
371*da0073e9SAndroid Build Coastguard Worker            2,
372*da0073e9SAndroid Build Coastguard Worker            expected_ops=5,
373*da0073e9SAndroid Build Coastguard Worker            expected_ops_dynamic=ifdynstaticdefault(5, 7),
374*da0073e9SAndroid Build Coastguard Worker        )
375*da0073e9SAndroid Build Coastguard Worker
376*da0073e9SAndroid Build Coastguard Worker    def test_unpack5(self):
377*da0073e9SAndroid Build Coastguard Worker        def unpack5(a, b):
378*da0073e9SAndroid Build Coastguard Worker            a = a[:5, :]
379*da0073e9SAndroid Build Coastguard Worker            b = b[:5, :]
380*da0073e9SAndroid Build Coastguard Worker            x, y = a.shape
381*da0073e9SAndroid Build Coastguard Worker            o = torch.empty((x, y))
382*da0073e9SAndroid Build Coastguard Worker            o.copy_(a / b)
383*da0073e9SAndroid Build Coastguard Worker            return o
384*da0073e9SAndroid Build Coastguard Worker
385*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.testing.standard_test(
386*da0073e9SAndroid Build Coastguard Worker            self,
387*da0073e9SAndroid Build Coastguard Worker            unpack5,
388*da0073e9SAndroid Build Coastguard Worker            2,
389*da0073e9SAndroid Build Coastguard Worker            expected_ops=5,
390*da0073e9SAndroid Build Coastguard Worker            expected_ops_dynamic=ifdynstaticdefault(5, 7),
391*da0073e9SAndroid Build Coastguard Worker        )
392*da0073e9SAndroid Build Coastguard Worker
393*da0073e9SAndroid Build Coastguard Worker    def test_matmul1(self):
394*da0073e9SAndroid Build Coastguard Worker        def matmul_op1(a, b):
395*da0073e9SAndroid Build Coastguard Worker            return a @ b
396*da0073e9SAndroid Build Coastguard Worker
397*da0073e9SAndroid Build Coastguard Worker        # TODO(jansel): FX doesn't support this, should add upstream support
398*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.testing.standard_test(self, matmul_op1, 2, expected_ops=1)
399*da0073e9SAndroid Build Coastguard Worker
400*da0073e9SAndroid Build Coastguard Worker    def test_int_shape_binops(self):
401*da0073e9SAndroid Build Coastguard Worker        def fn(x):
402*da0073e9SAndroid Build Coastguard Worker            # Test reversal by putting int arg first.
403*da0073e9SAndroid Build Coastguard Worker            y = 15 - x.shape[0]
404*da0073e9SAndroid Build Coastguard Worker            y = 4 + y
405*da0073e9SAndroid Build Coastguard Worker            y = 5 * y
406*da0073e9SAndroid Build Coastguard Worker            y = 2 % y
407*da0073e9SAndroid Build Coastguard Worker            y = 3**y
408*da0073e9SAndroid Build Coastguard Worker            y = 10 // y
409*da0073e9SAndroid Build Coastguard Worker            y = pow(2, y)
410*da0073e9SAndroid Build Coastguard Worker            y = 10 / y
411*da0073e9SAndroid Build Coastguard Worker            return x + y
412*da0073e9SAndroid Build Coastguard Worker
413*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.testing.standard_test(
414*da0073e9SAndroid Build Coastguard Worker            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 11)
415*da0073e9SAndroid Build Coastguard Worker        )
416*da0073e9SAndroid Build Coastguard Worker
@torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True)
def test_pt2_compliant_ops_are_allowed(self):
    # With only_allow_pt2_compliant_ops enabled, an op tagged with
    # pt2_compliant_tag must compile without error under nopython=True.
    # Two call spellings are covered: through the overload packet
    # (torch.ops.mylib.bar) and through the resolved overload (.default).
    lib = torch.library.Library("mylib", "FRAGMENT")
    try:
        torch.library.define(
            "mylib::bar",
            "(Tensor x) -> Tensor",
            lib=lib,
            tags=(torch.Tag.pt2_compliant_tag,),
        )
        torch.library.impl(
            "mylib::bar", "CompositeImplicitAutograd", torch.sin, lib=lib
        )
        self.assertIn(
            torch.Tag.pt2_compliant_tag, torch.ops.mylib.bar.default.tags
        )

        def f(x):
            # Call through the overload packet.
            return torch.ops.mylib.bar(x)

        overload = torch.ops.mylib.bar.default

        def g(x):
            # Call through the already-resolved overload.
            return overload(x)

        x = torch.randn(3)

        counts = torch._dynamo.testing.CompileCounter()
        optimized_f = torch._dynamo.optimize(counts, nopython=True)(f)
        _ = optimized_f(x)

        # Compile g here (the original compiled f a second time, which left
        # the resolved-overload path untested).
        optimized_g = torch._dynamo.optimize(counts, nopython=True)(g)
        _ = optimized_g(x)
    finally:
        # Always unregister so later tests can reuse the "mylib" namespace.
        cleanup_op("mylib::bar")
        del lib
451*da0073e9SAndroid Build Coastguard Worker
@torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True)
def test_non_pt2_compliant_ops_graph_break(self):
    # With only_allow_pt2_compliant_ops enabled, an op defined WITHOUT
    # pt2_compliant_tag must make nopython compilation raise Unsupported,
    # both when called through the overload packet (f) and when called
    # through the resolved overload (g).
    lib = torch.library.Library("mylib", "FRAGMENT")
    try:
        torch.library.define("mylib::bar2", "(Tensor x) -> Tensor", lib=lib)
        torch.library.impl(
            "mylib::bar2", "CompositeImplicitAutograd", torch.sin, lib=lib
        )
        self.assertNotIn(
            torch.Tag.pt2_compliant_tag, torch.ops.mylib.bar2.default.tags
        )

        def f(x):
            # Call through the overload packet.
            return torch.ops.mylib.bar2(x)

        overload = torch.ops.mylib.bar2.default

        def g(x):
            # Call through the already-resolved overload.
            return overload(x)

        x = torch.randn(3)

        counts = torch._dynamo.testing.CompileCounter()
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported, "not PT2 compliant"
        ):
            optimized_f = torch._dynamo.optimize(counts, nopython=True)(f)
            optimized_f(x)

        # Compile g here (the original compiled f a second time, so the
        # resolved-overload path was never checked).
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported, "not PT2 compliant"
        ):
            optimized_g = torch._dynamo.optimize(counts, nopython=True)(g)
            optimized_g(x)
    finally:
        # Always unregister so later tests can reuse the "mylib" namespace.
        cleanup_op("mylib::bar2")
        del lib
487*da0073e9SAndroid Build Coastguard Worker
@torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True)
def test_pt2_compliant_overload(self):
    # mylib::bar3 has two overloads: ".tensor" is tagged pt2_compliant_tag,
    # ".int" is not. Calling the packet must succeed when the arguments pick
    # the tagged overload, raise "not PT2 compliant" when they pick the
    # untagged one, and raise "failed to" when they match no overload.
    lib = torch.library.Library("mylib", "FRAGMENT")
    try:
        torch.library.define(
            "mylib::bar3.tensor",
            "(Tensor x) -> Tensor",
            tags=torch.Tag.pt2_compliant_tag,
            lib=lib,
        )
        torch.library.define(
            "mylib::bar3.int", "(Tensor x, int dim) -> Tensor", lib=lib
        )

        # Register one kernel per overload, in the same order as above.
        for qualname, kernel in (
            ("mylib::bar3.tensor", torch.sin),
            ("mylib::bar3.int", torch.sum),
        ):
            torch.library.impl(
                qualname, "CompositeImplicitAutograd", kernel, lib=lib
            )

        def call_tensor_overload(t):
            return torch.ops.mylib.bar3(t)

        def call_int_overload(t):
            return torch.ops.mylib.bar3(t, 1)

        def call_no_overload(t):
            return torch.ops.mylib.bar3(t, t, t)

        sample = torch.randn(3)

        counts = torch._dynamo.testing.CompileCounter()
        compiled_tensor, compiled_int, compiled_none = (
            torch._dynamo.optimize(counts, nopython=True)(fn)
            for fn in (call_tensor_overload, call_int_overload, call_no_overload)
        )

        # No error: the overload is PT2 compliant
        compiled_tensor(sample)

        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported, "not PT2 compliant"
        ):
            compiled_int(sample)

        # graph break on incorrect parsing
        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "failed to"):
            compiled_none(sample)

    finally:
        cleanup_op("mylib::bar3")
        del lib
543*da0073e9SAndroid Build Coastguard Worker
def test_auto_functionalize_can_with_default(self):
    # Compiles a call to a custom op whose schema has defaulted arguments
    # (c, d, e), passing only the two required ones. The test passes if
    # torch.compile runs the call without raising.
    lib = torch.library.Library("mylib", "FRAGMENT")
    try:
        torch.library.define(
            "mylib::foo",
            "(Tensor a, int b, Tensor(d!)? c=None, Tensor? d=None, int e=-1) -> ()",
            tags=torch.Tag.pt2_compliant_tag,
            lib=lib,
        )

        @torch.library.impl("mylib::foo", "cpu", lib=lib)
        def foo_impl(a, b, c=None, d=None, e=-1):
            a + b
            return

        def f(a, mode):
            return torch.ops.mylib.foo(
                a,
                0,
            )

        a = torch.tensor([10, 10, 10], dtype=torch.int64)

        torch.compile(f)(a, 0)
    finally:
        # Unregister even when compilation fails; otherwise "mylib::foo"
        # stays defined and the sibling tests that define it again break.
        cleanup_op("mylib::foo")
        del lib
570*da0073e9SAndroid Build Coastguard Worker
def test_user_defined_setattr1(self):
    # Attribute assignment on an existing UserDefineSetAttr instance inside a
    # fullgraph-compiled function must end up stored under the "pfx_"-prefixed
    # keys checked below (the class itself is defined elsewhere in this file).
    def assign_y(target):
        target.y = target.x + 1

    compiled_assign = torch.compile(assign_y, backend="eager", fullgraph=True)

    instance = UserDefineSetAttr()
    # "setup" is switched on only while the test itself touches attributes.
    with patch.object(UserDefineSetAttr, "setup", True):
        instance.x = torch.randn(8)
    compiled_assign(instance)
    with patch.object(UserDefineSetAttr, "setup", True):
        self.assertEqual(instance.y, instance.x + 1)
    self.assertEqual(instance.__dict__.keys(), {"pfx_x", "pfx_y"})
583*da0073e9SAndroid Build Coastguard Worker
def test_user_defined_setattr2(self):
    # Same as the setattr1 case, but the UserDefineSetAttr instance is created
    # inside the fullgraph-compiled function and returned from it.
    def build(inp):
        made = UserDefineSetAttr()
        made.x = inp
        made.y = made.x + 1
        return made

    compiled_build = torch.compile(build, backend="eager", fullgraph=True)

    source = torch.randn(8)
    result = compiled_build(source)
    # "setup" is switched on only while the test itself reads attributes.
    with patch.object(UserDefineSetAttr, "setup", True):
        self.assertIs(result.x, source)
        self.assertEqual(result.y, source + 1)
    self.assertEqual(result.__dict__.keys(), {"pfx_x", "pfx_y"})
598*da0073e9SAndroid Build Coastguard Worker
def test_closure_recompiles(self):
    # Calls one compiled function with four closures from closure_adder
    # (two built from the int 5, two from random tensors) and expects the
    # results to match eager and exactly two frames to be compiled.
    counter = CompileCounter()

    def fn(x, other_fn):
        return other_fn(x + 1) - 1

    compiled = torch.compile(fn, backend=counter, fullgraph=True)

    inp = torch.randn(8)
    adders = (
        closure_adder(5),
        closure_adder(5),
        closure_adder(torch.randn(8)),
        closure_adder(torch.randn(8)),
    )
    for adder in adders:
        self.assertEqual(compiled(inp, adder), fn(inp, adder))

    self.assertEqual(counter.frame_count, 2)
617*da0073e9SAndroid Build Coastguard Worker
def test_generate_trivial_abstract_impl(self):
    # Defines a "() ->"-returning custom op with only a CPU kernel (no
    # abstract/fake impl is registered here) and checks that a fullgraph
    # compile with the eager backend runs it and returns None.
    try:
        lib = torch.library.Library("mylib", "FRAGMENT")
        torch.library.define(
            "mylib::foo",
            "(Tensor x, Tensor[] y, Tensor(a!)? z, SymInt w) -> ()",
            tags=torch.Tag.pt2_compliant_tag,
            lib=lib,
        )

        @torch.library.impl("mylib::foo", "cpu", lib=lib)
        @torch._dynamo.disable
        def foo_impl(x, y, z, w):
            x + y[0] + w
            return

        def f(x, y, z, w):
            # The tensor passed as w is ignored; the op always receives 2.
            return torch.ops.mylib.foo(x, y, z, 2)

        # Same draw order as separate x, y, z, w assignments.
        args = (
            torch.randn(3),
            (torch.randn(3), torch.randn(3)),
            torch.randn(3),
            torch.randn(3),
        )

        compiled_f = torch.compile(f, backend="eager", fullgraph=True)
        output = compiled_f(*args)
        self.assertEqual(output, None)
    finally:
        cleanup_op("mylib::foo")
        del lib
648*da0073e9SAndroid Build Coastguard Worker
def test_can_auto_functionalize(self):
    # For each schema, define mylib::a and check can_auto_functionalize on
    # the ".default" overload against the expected answer. The overload
    # packet itself is always expected to be rejected.
    from torch._higher_order_ops.auto_functionalize import can_auto_functionalize

    accepted_schemas = [
        "(Tensor(a!) x) -> ()",
        "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
        "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
        "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor",
        "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor)",
    ]
    rejected_schemas = [
        "(Tensor x) -> ()",
        "(Tensor(a) x) -> Tensor(a)",
        "(Tensor(a!) x) -> Tensor(a!)",
        "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()",
        "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor(a)",
        "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
        "(Tensor(a) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
        "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor[])",
    ]

    # Pair each schema list with the assertion its ".default" overload
    # must satisfy; accepted schemas run first, then rejected ones.
    groups = (
        (accepted_schemas, self.assertTrue),
        (rejected_schemas, self.assertFalse),
    )
    for schemas, check_overload in groups:
        for schema in schemas:
            try:
                lib = torch.library.Library("mylib", "FRAGMENT")
                torch.library.define("mylib::a", schema, lib=lib)
                check_overload(
                    can_auto_functionalize(torch.ops.mylib.a.default), msg=schema
                )
                self.assertFalse(can_auto_functionalize(torch.ops.mylib.a))
            finally:
                # The op is re-defined with the next schema, so drop it now.
                cleanup_op("mylib::a")
                del lib
691*da0073e9SAndroid Build Coastguard Worker
def test_auto_functionalize(self):
    # Compiles a call to a custom op that mutates two of its tensor inputs
    # with inductor, checks the logged post-grad graph (static shapes only),
    # and checks the mutated arguments match an eager run on fresh clones.
    try:
        lib = torch.library.Library("mylib", "FRAGMENT")
        torch.library.define(
            "mylib::foo",
            "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()",
            tags=torch.Tag.pt2_compliant_tag,
            lib=lib,
        )

        @torch.library.impl("mylib::foo", "cpu", lib=lib)
        @torch._dynamo.disable
        def foo_impl(x, y, z, w, n):
            x.add_(y[0] + w)
            z.add_(y[1] + n)

        def f(x, y, z, n):
            torch.ops.mylib.foo(x, y, z, 2, n)

        # Same draw order as separate x, y, z, n assignments.
        orig_args = (
            torch.randn(3),
            (torch.randn(3), torch.randn(3)),
            torch.randn(3),
            torch.randn(3),
        )

        # The op mutates its inputs, so each run gets its own clones.
        compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)

        captured, capture_ctx = logs_to_string(
            "torch._inductor.compile_fx", "post_grad_graphs"
        )
        compiled_f = torch.compile(f, backend="inductor", fullgraph=True)
        with capture_ctx():
            compiled_f(*compiled_args)

        # Drop the first three log lines before comparing.
        log_lines = captured.getvalue().strip().split("\n")
        post_grad_graphs = "\n".join(log_lines[3:]).strip()

        # Check the graph under static shapes
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(
                post_grad_graphs,
                """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
        return ()""",
            )

        eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
        f(*eager_args)
        self.assertEqual(compiled_args, eager_args)
    finally:
        cleanup_op("mylib::foo")
        del lib
746*da0073e9SAndroid Build Coastguard Worker
def test_auto_functionalize_with_returns(self):
    # Like the no-return auto-functionalize test, but the custom op also
    # returns two tensors. Checks the logged post-grad graph (static shapes
    # only), the mutated arguments and the returned values against eager.
    try:
        lib = torch.library.Library("mylib", "FRAGMENT")
        torch.library.define(
            "mylib::foo",
            "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)",
            tags=torch.Tag.pt2_compliant_tag,
            lib=lib,
        )

        @torch.library.impl("mylib::foo", "cpu", lib=lib)
        @torch._dynamo.disable
        def foo_impl(x, y, z, w, n):
            x.add_(y[0] + w)
            z.add_(y[1] + n)
            return y[0] + w, y[1] + n

        @torch.library.impl_abstract("mylib::foo", lib=lib)
        def foo_abstract(x, y, z, w, n):
            return y[0] + w, y[1] + n

        def f(x, y, z, n):
            return torch.ops.mylib.foo(x, y, z, 2, n)

        # Same draw order as separate x, y, z, n assignments.
        orig_args = (
            torch.randn(3),
            (torch.randn(3), torch.randn(3)),
            torch.randn(3),
            torch.randn(3),
        )

        # The op mutates its inputs, so each run gets its own clones.
        compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
        captured, capture_ctx = logs_to_string(
            "torch._inductor.compile_fx", "post_grad_graphs"
        )
        compiled_f = torch.compile(f, backend="inductor", fullgraph=True)
        with capture_ctx():
            compiled_out = compiled_f(*compiled_args)

        if torch._dynamo.config.assume_static_by_default:
            # Drop the first three log lines before comparing.
            log_lines = captured.getvalue().strip().split("\n")
            post_grad_graphs = "\n".join(log_lines[3:]).strip()
            self.assertExpectedInline(
                post_grad_graphs,
                """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
        getitem_4: "f32[3][1]cpu" = foo_default[0]
        getitem_5: "f32[3][1]cpu" = foo_default[1];  foo_default = None
        return (getitem_4, getitem_5)""",
            )

        eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
        eager_out = f(*eager_args)
        self.assertEqual(compiled_args, eager_args)
        self.assertEqual(compiled_out, eager_out)
    finally:
        cleanup_op("mylib::foo")
        del lib
808*da0073e9SAndroid Build Coastguard Worker
def test_auto_functionalize_on_view(self):
    # The custom op applies sin in place to its argument. Inside a compiled
    # function it is called on a slice of a cloned tensor, and the clone
    # itself is returned, so the mutation must be visible through the slice.
    try:
        lib = torch.library.Library("mylib", "FRAGMENT")
        torch.library.define(
            "mylib::foo",
            "(Tensor(a!) x) -> ()",
            tags=torch.Tag.pt2_compliant_tag,
            lib=lib,
        )

        @torch.library.impl("mylib::foo", "cpu", lib=lib)
        @torch._dynamo.disable
        def foo_impl(x):
            x_np = x.detach().numpy()  # view
            np.sin(x_np, out=x_np)
            return

        # Eager sanity check: the op mutates its argument in place.
        inp = torch.randn(3)
        expected = inp.sin()
        torch.ops.mylib.foo(inp)
        assert torch.allclose(inp, expected)

        def f(x):
            x = x.clone()
            y = x[:]
            torch.ops.mylib.foo(y)
            return x

        compiled_f = torch.compile(
            f, backend="aot_eager_decomp_partition", fullgraph=True
        )

        result = compiled_f(inp)
        self.assertEqual(result, inp.sin())
    finally:
        cleanup_op("mylib::foo")
        del lib
843*da0073e9SAndroid Build Coastguard Worker
844*da0073e9SAndroid Build Coastguard Worker    def test_auto_functionalize_optional(self):
845*da0073e9SAndroid Build Coastguard Worker        try:
846*da0073e9SAndroid Build Coastguard Worker            lib = torch.library.Library("mylib", "FRAGMENT")
847*da0073e9SAndroid Build Coastguard Worker            torch.library.define(
848*da0073e9SAndroid Build Coastguard Worker                "mylib::foo",
849*da0073e9SAndroid Build Coastguard Worker                "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()",
850*da0073e9SAndroid Build Coastguard Worker                tags=torch.Tag.pt2_compliant_tag,
851*da0073e9SAndroid Build Coastguard Worker                lib=lib,
852*da0073e9SAndroid Build Coastguard Worker            )
853*da0073e9SAndroid Build Coastguard Worker
854*da0073e9SAndroid Build Coastguard Worker            @torch.library.impl("mylib::foo", "cpu", lib=lib)
855*da0073e9SAndroid Build Coastguard Worker            @torch._dynamo.disable
856*da0073e9SAndroid Build Coastguard Worker            def foo_impl(x, y, z, w, n):
857*da0073e9SAndroid Build Coastguard Worker                if x is not None:
858*da0073e9SAndroid Build Coastguard Worker                    x.add_(y[0] + w)
859*da0073e9SAndroid Build Coastguard Worker                if z is not None:
860*da0073e9SAndroid Build Coastguard Worker                    z.add_(y[1] + n)
861*da0073e9SAndroid Build Coastguard Worker
862*da0073e9SAndroid Build Coastguard Worker            def f(x, y, z, n):
863*da0073e9SAndroid Build Coastguard Worker                torch.ops.mylib.foo(x, y, z, 2, n)
864*da0073e9SAndroid Build Coastguard Worker
865*da0073e9SAndroid Build Coastguard Worker            x = None
866*da0073e9SAndroid Build Coastguard Worker            y = (torch.randn(3), torch.randn(3))
867*da0073e9SAndroid Build Coastguard Worker            z = torch.randn(3)
868*da0073e9SAndroid Build Coastguard Worker            n = torch.randn(3)
869*da0073e9SAndroid Build Coastguard Worker            orig_args = (x, y, z, n)
870*da0073e9SAndroid Build Coastguard Worker
871*da0073e9SAndroid Build Coastguard Worker            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
872*da0073e9SAndroid Build Coastguard Worker            log_stream, ctx = logs_to_string(
873*da0073e9SAndroid Build Coastguard Worker                "torch._inductor.compile_fx", "post_grad_graphs"
874*da0073e9SAndroid Build Coastguard Worker            )
875*da0073e9SAndroid Build Coastguard Worker            with ctx():
876*da0073e9SAndroid Build Coastguard Worker                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)
877*da0073e9SAndroid Build Coastguard Worker
878*da0073e9SAndroid Build Coastguard Worker            if torch._dynamo.config.assume_static_by_default:
879*da0073e9SAndroid Build Coastguard Worker                post_grad_graphs = "\n".join(
880*da0073e9SAndroid Build Coastguard Worker                    log_stream.getvalue().strip().split("\n")[3:]
881*da0073e9SAndroid Build Coastguard Worker                ).strip()
882*da0073e9SAndroid Build Coastguard Worker                self.assertExpectedInline(
883*da0073e9SAndroid Build Coastguard Worker                    post_grad_graphs,
884*da0073e9SAndroid Build Coastguard Worker                    """\
885*da0073e9SAndroid Build Coastguard Workerdef forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
886*da0073e9SAndroid Build Coastguard Worker        # No stacktrace found for following nodes
887*da0073e9SAndroid Build Coastguard Worker        foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
888*da0073e9SAndroid Build Coastguard Worker        return ()""",
889*da0073e9SAndroid Build Coastguard Worker                )
890*da0073e9SAndroid Build Coastguard Worker
891*da0073e9SAndroid Build Coastguard Worker            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
892*da0073e9SAndroid Build Coastguard Worker            f(*eager_args)
893*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(compiled_args, eager_args)
894*da0073e9SAndroid Build Coastguard Worker        finally:
895*da0073e9SAndroid Build Coastguard Worker            cleanup_op("mylib::foo")
896*da0073e9SAndroid Build Coastguard Worker            del lib
897*da0073e9SAndroid Build Coastguard Worker
    def test_shape_int_inplace_binops(self):
        # Runs every augmented-assignment arithmetic operator on an int taken
        # from x.shape[0], with the constant on the right-hand side, then adds
        # the result to the input tensor.
        def fn(x):
            p = x.shape[0]
            p += 2
            p -= 2
            p **= 2
            p /= 2
            p *= 2
            p //= 2
            p %= 2
            return x + p

        # fn takes 1 argument; 1 op is expected by default, and the
        # dynamic-shape expectation is picked by ifdynstaticdefault (defined
        # elsewhere in this file).
        torch._dynamo.testing.standard_test(
            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 10)
        )
913*da0073e9SAndroid Build Coastguard Worker
    def test_int_shape_inplace_binops(self):
        # Mirror of test_shape_int_inplace_binops: the Python constant is the
        # left-hand operand of each augmented assignment and the shape-derived
        # int is the right-hand operand. `y` is reset to 2 before every
        # operator, so only the last one (`%=`) feeds the returned value.
        def fn(x):
            p = x.shape[0]
            # Test reversal by putting constant first
            y = 2
            y += p
            y = 2
            y -= p
            y = 2
            y **= p
            y = 2
            y /= p
            y = 2
            y *= p
            y = 2
            y //= p
            y = 2
            y %= p
            return x + y

        # fn takes 1 argument; 1 op is expected by default, and the
        # dynamic-shape expectation is picked by ifdynstaticdefault (defined
        # elsewhere in this file).
        torch._dynamo.testing.standard_test(
            self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 4)
        )
937*da0073e9SAndroid Build Coastguard Worker
    def test_int_int_comparisons(self):
        # Exercises all six comparison operators between two int literals.
        # Every branch before `2 == 2` is false, so `out` ends up as 2 and the
        # function performs a single tensor add.
        def fn(x):
            if 2 != 2:
                out = 1
            elif 2 < 1:
                out = 1
            elif 1 > 2:
                out = 1
            elif 1 >= 2:
                out = 1
            elif 2 <= 1:
                out = 1
            elif 2 == 2:
                out = 2
            else:
                out = 1
            return x + out

        torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)
957*da0073e9SAndroid Build Coastguard Worker
    def test_shape_int_comparisons(self):
        # Exercises all six comparison operators with a shape-derived int on
        # the left and an int constant on the right. Whichever branch is taken,
        # the function ends in a single tensor add.
        # NOTE(review): the `a == 10` branch is only reached if standard_test
        # supplies an input whose first dimension is 10 - confirm against
        # torch._dynamo.testing.
        def fn(x):
            a = x.shape[0]
            # Ensure support for constant on right side
            if a != 10:
                out = 1
            elif a < 2:
                out = 1
            elif a > 12:
                out = 1
            elif a >= 12:
                out = 1
            elif a <= 2:
                out = 1
            elif a == 10:
                out = 2
            else:
                out = 1
            return x + out

        # TODO: Test the guards maybe?
        torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)
980*da0073e9SAndroid Build Coastguard Worker
    def test_int_shape_comparisons(self):
        # Mirror of test_shape_int_comparisons with the operands swapped: the
        # int constant is on the left and the shape-derived int on the right.
        # Whichever branch is taken, the function ends in a single tensor add.
        def fn(x):
            a = x.shape[0]
            # Ensure support for constant on left side
            if 10 != a:
                out = 1
            elif 12 < a:
                out = 1
            elif 2 > a:
                out = 1
            elif 2 >= a:
                out = 1
            elif 12 <= a:
                out = 1
            elif 10 == a:
                out = 2
            else:
                out = 1
            return x + out

        # TODO: Test the guards maybe?
        torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)
1003*da0073e9SAndroid Build Coastguard Worker
    def test_param_shape_binops(self):
        # Chains binary arithmetic operators (and the pow() builtin) where the
        # left operand is always an int read from a module parameter's shape,
        # then adds the final value to the input tensor. The module is compiled
        # with nopython=True and its output is compared with the eager output.
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.randn(15))

            def forward(self, x):
                # Test reversal by putting param shape arg first.
                p = self.param.shape[0]
                y = p - x.shape[0]
                y = p + y
                y = p * y
                y = p % y
                y = p**y
                y = p // y
                y = pow(p, y)
                y = p / y
                return x + y

        counts = torch._dynamo.testing.CompileCounter()
        mod = MyModule()
        optimized_mod = torch._dynamo.optimize(counts, nopython=True)(mod)

        x = torch.randn(3)
        ref = mod(x)
        res = optimized_mod(x)

        self.assertTrue(same(ref, res))
        self.assertEqual(counts.frame_count, 1)

        # The expected op count depends on the assume_static_by_default
        # config: 1 when it is set, 11 otherwise.
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(counts.op_count, """1""")
        else:
            self.assertExpectedInline(counts.op_count, """11""")
1038*da0073e9SAndroid Build Coastguard Worker
    def test_user_defined_binop(self):
        # Adds a shape-derived int to a user-defined object that only
        # implements __radd__, so `x.shape[0] + c` has to go through the
        # reflected operator. The compiled result is compared with eager.
        class MyClass:
            def __init__(self, value):
                self.value = value

            def __radd__(self, other):
                return self.value + other

        def fn(x, c):
            y = x.shape[0] + c
            return x + y

        counts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(counts)(fn)

        x = torch.randn(3)
        c = MyClass(4)
        ref = fn(x, c)
        res = opt_fn(x, c)

        self.assertTrue(same(ref, res))
        self.assertEqual(counts.frame_count, 1)
        # The expected op count depends on the assume_static_by_default
        # config: 1 when it is set, 4 otherwise.
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(counts.op_count, """1""")
        else:
            self.assertExpectedInline(counts.op_count, """4""")
1065*da0073e9SAndroid Build Coastguard Worker
1066*da0073e9SAndroid Build Coastguard Worker    def test_user_defined_iter(self):
1067*da0073e9SAndroid Build Coastguard Worker        class Mod:
1068*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
1069*da0073e9SAndroid Build Coastguard Worker                self.a = [torch.randn(2, 2), torch.randn(2, 2)]
1070*da0073e9SAndroid Build Coastguard Worker
1071*da0073e9SAndroid Build Coastguard Worker            def __iter__(self):
1072*da0073e9SAndroid Build Coastguard Worker                return iter(self.a)
1073*da0073e9SAndroid Build Coastguard Worker
1074*da0073e9SAndroid Build Coastguard Worker        def f(mod):
1075*da0073e9SAndroid Build Coastguard Worker            ret = []
1076*da0073e9SAndroid Build Coastguard Worker            for x in mod:
1077*da0073e9SAndroid Build Coastguard Worker                ret.append(x + 1)
1078*da0073e9SAndroid Build Coastguard Worker            return ret
1079*da0073e9SAndroid Build Coastguard Worker
1080*da0073e9SAndroid Build Coastguard Worker        mod = Mod()
1081*da0073e9SAndroid Build Coastguard Worker        counts = torch._dynamo.testing.CompileCounter()
1082*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(counts, nopython=True)(f)
1083*da0073e9SAndroid Build Coastguard Worker        ref = f(mod)
1084*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(mod)
1085*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(mod)
1086*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(mod)
1087*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(mod)
1088*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
1089*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counts.frame_count, 1)
1090*da0073e9SAndroid Build Coastguard Worker
1091*da0073e9SAndroid Build Coastguard Worker        mod.a.append(torch.randn(2, 2))
1092*da0073e9SAndroid Build Coastguard Worker        # `for x in mod` is inlined, where iter(m.a) creates a guard on the list length of m.a
1093*da0073e9SAndroid Build Coastguard Worker        # Mutating length of mod.a causes a re-compilation.
1094*da0073e9SAndroid Build Coastguard Worker        ref2 = f(mod)
1095*da0073e9SAndroid Build Coastguard Worker        res2 = opt_fn(mod)
1096*da0073e9SAndroid Build Coastguard Worker        res2 = opt_fn(mod)
1097*da0073e9SAndroid Build Coastguard Worker        res2 = opt_fn(mod)
1098*da0073e9SAndroid Build Coastguard Worker        res2 = opt_fn(mod)
1099*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref2, res2))
1100*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counts.frame_count, 2)
1101*da0073e9SAndroid Build Coastguard Worker
    def test_compare_shapes_eq(self):
        # Branches on `==` between two shapes, either as Python lists built
        # from the unsqueezed tensors' shapes (to_list=True) or as the raw
        # `.shape` objects (to_list=False).
        def compare_shapes(a, b, to_list):
            x = list(a.unsqueeze(-1).shape) if to_list else a.shape
            y = list(b.unsqueeze(-1).shape) if to_list else b.shape
            if x == y:
                return a + 1
            else:
                return a + 2

        # Test both ListVariable and ShapeVariable
        torch._dynamo.testing.standard_test(
            self, lambda a, b: compare_shapes(a, b, to_list=True), 2
        )
        torch._dynamo.testing.standard_test(
            self, lambda a, b: compare_shapes(a, b, to_list=False), 2
        )
1118*da0073e9SAndroid Build Coastguard Worker
    def test_compare_shapes_tuple_eq(self):
        # Branches on `==` between two shapes that were converted to tuples.
        def compare_shapes(a, b):
            x = tuple(a.unsqueeze(-1).shape)
            y = tuple(b.unsqueeze(-1).shape)
            if x == y:
                return a + 1
            else:
                return a + 2

        # The function is passed through a two-argument lambda wrapper.
        torch._dynamo.testing.standard_test(self, lambda a, b: compare_shapes(a, b), 2)
1129*da0073e9SAndroid Build Coastguard Worker
    def test_compare_shapes_tuple_neq(self):
        # Branches on `!=` between two shapes that were converted to tuples.
        def compare_shapes(a, b):
            x = tuple(a.unsqueeze(-1).shape)
            y = tuple(b.unsqueeze(-1).shape)
            if x != y:
                return a + 1
            else:
                return a + 2

        # The function is passed through a two-argument lambda wrapper.
        torch._dynamo.testing.standard_test(self, lambda a, b: compare_shapes(a, b), 2)
1140*da0073e9SAndroid Build Coastguard Worker
    def test_compare_shapes_neq(self):
        # Branches on `!=` between two shapes, either as Python lists built
        # from the unsqueezed tensors' shapes (to_list=True) or as the raw
        # `.shape` objects (to_list=False).
        def compare_shapes(a, b, to_list):
            x = list(a.unsqueeze(-1).shape) if to_list else a.shape
            y = list(b.unsqueeze(-1).shape) if to_list else b.shape
            if x != y:
                return a + 1
            else:
                return a + 2

        # Test both ListVariable and ShapeVariable
        torch._dynamo.testing.standard_test(
            self, lambda a, b: compare_shapes(a, b, to_list=True), 2
        )
        torch._dynamo.testing.standard_test(
            self, lambda a, b: compare_shapes(a, b, to_list=False), 2
        )
1157*da0073e9SAndroid Build Coastguard Worker
1158*da0073e9SAndroid Build Coastguard Worker    def test_compare_shapes_with_constant(self):
1159*da0073e9SAndroid Build Coastguard Worker        def compare_shapes(a):
1160*da0073e9SAndroid Build Coastguard Worker            x = a.shape
1161*da0073e9SAndroid Build Coastguard Worker            if x[0] != 3:
1162*da0073e9SAndroid Build Coastguard Worker                return a * 4
1163*da0073e9SAndroid Build Coastguard Worker            return a * 3
1164*da0073e9SAndroid Build Coastguard Worker
1165*da0073e9SAndroid Build Coastguard Worker        guard_failure = None
1166*da0073e9SAndroid Build Coastguard Worker
1167*da0073e9SAndroid Build Coastguard Worker        def guard_failures(failure):
1168*da0073e9SAndroid Build Coastguard Worker            nonlocal guard_failure
1169*da0073e9SAndroid Build Coastguard Worker            guard_failure = failure
1170*da0073e9SAndroid Build Coastguard Worker
1171*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(
1172*da0073e9SAndroid Build Coastguard Worker            "eager", nopython=True, guard_fail_fn=guard_failures
1173*da0073e9SAndroid Build Coastguard Worker        )(compare_shapes)
1174*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.randn([3, 4]))
1175*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.randn([4, 3]))
1176*da0073e9SAndroid Build Coastguard Worker        self.assertIn(
1177*da0073e9SAndroid Build Coastguard Worker            """tensor 'L['a']' size mismatch at index 0. expected 3, actual 4""",
1178*da0073e9SAndroid Build Coastguard Worker            guard_failure.reason,
1179*da0073e9SAndroid Build Coastguard Worker        )
1180*da0073e9SAndroid Build Coastguard Worker
1181*da0073e9SAndroid Build Coastguard Worker    def test_builtin_abs(self):
1182*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
1183*da0073e9SAndroid Build Coastguard Worker            return abs(x) + abs(y)
1184*da0073e9SAndroid Build Coastguard Worker
1185*da0073e9SAndroid Build Coastguard Worker        sample = torch.randn(10, 10)
1186*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
1187*da0073e9SAndroid Build Coastguard Worker
1188*da0073e9SAndroid Build Coastguard Worker        for sample in [
1189*da0073e9SAndroid Build Coastguard Worker            (torch.randn(10, 10), torch.randn(10, 10)),
1190*da0073e9SAndroid Build Coastguard Worker            (-10, make_tensor(10, dtype=torch.int64, device="cpu")),
1191*da0073e9SAndroid Build Coastguard Worker            (-0.1, torch.randn(10)),
1192*da0073e9SAndroid Build Coastguard Worker        ]:
1193*da0073e9SAndroid Build Coastguard Worker            expect = fn(*sample)
1194*da0073e9SAndroid Build Coastguard Worker            actual = opt_fn(*sample)
1195*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expect, actual)
1196*da0073e9SAndroid Build Coastguard Worker
    def test_builtin_isinstance(self):
        # Calls isinstance() on the tensor argument, on a tensor created inside
        # the function, and on int/list/dict literals, and returns the booleans
        # as a list. The only tensor operation in fn is torch.arange.
        def fn(x):
            t = torch.arange(1, 3)
            a = isinstance(x, torch.Tensor)
            b = isinstance(t, torch.Tensor)
            c = isinstance(x, int)
            d = isinstance(3, int)
            e = isinstance([1, 2, 3], list)
            f = isinstance({"foo": 1, "bar": 2}, dict)
            res = [a, b, c, d, e, f]
            # Can't run yet due to other unimplemented instructions
            # res += [isinstance(torch.nn.LazyLinear(2, 3), torch.nn.Linear)]
            return res

        torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)
1212*da0073e9SAndroid Build Coastguard Worker
1213*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(sys.version_info[:2] <= (3, 8), "Requires astunparse")
1214*da0073e9SAndroid Build Coastguard Worker    def test_cse_dict_guards(self):
1215*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1216*da0073e9SAndroid Build Coastguard Worker            ret = torch.zeros(3)
1217*da0073e9SAndroid Build Coastguard Worker            for v in x.values():
1218*da0073e9SAndroid Build Coastguard Worker                ret = ret + v
1219*da0073e9SAndroid Build Coastguard Worker            return ret
1220*da0073e9SAndroid Build Coastguard Worker
1221*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.guards import build_guard_function, CLOSURE_VARS
1222*da0073e9SAndroid Build Coastguard Worker
1223*da0073e9SAndroid Build Coastguard Worker        x = {3: torch.randn(3), 2: torch.randn(3), 4: torch.randn(3)}
1224*da0073e9SAndroid Build Coastguard Worker        _, guards = torch._dynamo.export(fn, x)
1225*da0073e9SAndroid Build Coastguard Worker
1226*da0073e9SAndroid Build Coastguard Worker        code_lists = [c for g in guards for c in g.code_list or []]
1227*da0073e9SAndroid Build Coastguard Worker        _, pycode = build_guard_function(code_lists, [])
1228*da0073e9SAndroid Build Coastguard Worker        # Make sure we just call "list(dict.keys())" once
1229*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(pycode.count("keys"), 1)
1230*da0073e9SAndroid Build Coastguard Worker
1231*da0073e9SAndroid Build Coastguard Worker    def test_sys_modules(self):
1232*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
1233*da0073e9SAndroid Build Coastguard Worker            mod_a = sys.modules.get("aaaaaaaa")
1234*da0073e9SAndroid Build Coastguard Worker            assert mod_a is None
1235*da0073e9SAndroid Build Coastguard Worker            assert "bbbbbbbb" not in sys.modules
1236*da0073e9SAndroid Build Coastguard Worker
1237*da0073e9SAndroid Build Coastguard Worker            assert "operator" in sys.modules
1238*da0073e9SAndroid Build Coastguard Worker            operator = sys.modules["operator"]
1239*da0073e9SAndroid Build Coastguard Worker            builtins = sys.modules.get("builtins")
1240*da0073e9SAndroid Build Coastguard Worker            operator2 = sys.modules.get("cccccccc", operator)
1241*da0073e9SAndroid Build Coastguard Worker
1242*da0073e9SAndroid Build Coastguard Worker            return operator.add(x, y), operator2.neg(builtins.abs(x))
1243*da0073e9SAndroid Build Coastguard Worker
1244*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.testing.standard_test(self, fn, 2, expected_ops=3)
1245*da0073e9SAndroid Build Coastguard Worker
1246*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10)
1247*da0073e9SAndroid Build Coastguard Worker        _, guards = torch._dynamo.export(fn, x, x)
1248*da0073e9SAndroid Build Coastguard Worker        guard_code = []
1249*da0073e9SAndroid Build Coastguard Worker        for guard in guards:
1250*da0073e9SAndroid Build Coastguard Worker            if guard.code_list:
1251*da0073e9SAndroid Build Coastguard Worker                guard_code += guard.code_list
1252*da0073e9SAndroid Build Coastguard Worker
1253*da0073e9SAndroid Build Coastguard Worker        # Filter out id-matches that won't reproduce run to run
1254*da0073e9SAndroid Build Coastguard Worker        guard_code = filter(
1255*da0073e9SAndroid Build Coastguard Worker            lambda line: "id" not in line and "lookup_backend" not in line,
1256*da0073e9SAndroid Build Coastguard Worker            sorted(guard_code),
1257*da0073e9SAndroid Build Coastguard Worker        )
1258*da0073e9SAndroid Build Coastguard Worker        guard_code_str = "\n".join(guard_code)
1259*da0073e9SAndroid Build Coastguard Worker
1260*da0073e9SAndroid Build Coastguard Worker        for line in """\
1261*da0073e9SAndroid Build Coastguard Worker2 <= L['x'].size()[0]
1262*da0073e9SAndroid Build Coastguard WorkerL['x'] is L['y']
1263*da0073e9SAndroid Build Coastguard WorkerL['x'].ndimension() == 2
1264*da0073e9SAndroid Build Coastguard WorkerL['x'].requires_grad == False
1265*da0073e9SAndroid Build Coastguard WorkerL['x'].size()[1] == L['x'].size()[0]
1266*da0073e9SAndroid Build Coastguard WorkerL['x'].storage_offset() == 0
1267*da0073e9SAndroid Build Coastguard Worker___dict_contains('builtins', G['sys'].modules)
1268*da0073e9SAndroid Build Coastguard Worker___dict_contains('operator', G['sys'].modules)
1269*da0073e9SAndroid Build Coastguard Worker___dict_contains('operator', G['sys'].modules)
1270*da0073e9SAndroid Build Coastguard Workerhasattr(L['x'], '_dynamo_dynamic_indices') == False
1271*da0073e9SAndroid Build Coastguard Workernot ___dict_contains('aaaaaaaa', G['sys'].modules)
1272*da0073e9SAndroid Build Coastguard Workernot ___dict_contains('bbbbbbbb', G['sys'].modules)
1273*da0073e9SAndroid Build Coastguard Workernot ___dict_contains('cccccccc', G['sys'].modules)
1274*da0073e9SAndroid Build Coastguard Workerstr(L['x'].device) == 'cpu'
1275*da0073e9SAndroid Build Coastguard Workerstr(L['x'].dtype) == 'torch.float32'
1276*da0073e9SAndroid Build Coastguard Workerutils_device.CURRENT_DEVICE == None""".split(
1277*da0073e9SAndroid Build Coastguard Worker            "\n"
1278*da0073e9SAndroid Build Coastguard Worker        ):
1279*da0073e9SAndroid Build Coastguard Worker            self.assertIn(
1280*da0073e9SAndroid Build Coastguard Worker                line,
1281*da0073e9SAndroid Build Coastguard Worker                guard_code_str,
1282*da0073e9SAndroid Build Coastguard Worker            )
1283*da0073e9SAndroid Build Coastguard Worker
1284*da0073e9SAndroid Build Coastguard Worker    def test_fold(self):
1285*da0073e9SAndroid Build Coastguard Worker        def fn(a):
1286*da0073e9SAndroid Build Coastguard Worker            return a + math.sqrt(63)
1287*da0073e9SAndroid Build Coastguard Worker
1288*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)
1289*da0073e9SAndroid Build Coastguard Worker
1290*da0073e9SAndroid Build Coastguard Worker    def test_getattr_dict(self):
1291*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1292*da0073e9SAndroid Build Coastguard Worker            from torch.masked.maskedtensor._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE
1293*da0073e9SAndroid Build Coastguard Worker
1294*da0073e9SAndroid Build Coastguard Worker            return x * len(_MASKEDTENSOR_FUNCTION_TABLE)
1295*da0073e9SAndroid Build Coastguard Worker
1296*da0073e9SAndroid Build Coastguard Worker        i = torch.randn(5)
1297*da0073e9SAndroid Build Coastguard Worker        r1 = fn(i)
1298*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1299*da0073e9SAndroid Build Coastguard Worker        r2 = opt_fn(i)
1300*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r1, r2)
1301*da0073e9SAndroid Build Coastguard Worker
1302*da0073e9SAndroid Build Coastguard Worker    def test_shape_unpack(self):
1303*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1304*da0073e9SAndroid Build Coastguard Worker            a, b = x.size()
1305*da0073e9SAndroid Build Coastguard Worker            return x * b
1306*da0073e9SAndroid Build Coastguard Worker
1307*da0073e9SAndroid Build Coastguard Worker        i = torch.randn(5, 10)
1308*da0073e9SAndroid Build Coastguard Worker        r1 = fn(i)
1309*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
1310*da0073e9SAndroid Build Coastguard Worker        r2 = opt_fn(i)
1311*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(r1, r2))
1312*da0073e9SAndroid Build Coastguard Worker
1313*da0073e9SAndroid Build Coastguard Worker    def test_typing_dict(self):
1314*da0073e9SAndroid Build Coastguard Worker        def fn(d):
1315*da0073e9SAndroid Build Coastguard Worker            return d[T]
1316*da0073e9SAndroid Build Coastguard Worker
1317*da0073e9SAndroid Build Coastguard Worker        d = {T: torch.randn(3)}
1318*da0073e9SAndroid Build Coastguard Worker        r1 = fn(d)
1319*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1320*da0073e9SAndroid Build Coastguard Worker        r2 = opt_fn(d)
1321*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r1, r2)
1322*da0073e9SAndroid Build Coastguard Worker
1323*da0073e9SAndroid Build Coastguard Worker    def test_tensor_iter(self):
1324*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1325*da0073e9SAndroid Build Coastguard Worker            for y in x:
1326*da0073e9SAndroid Build Coastguard Worker                y.add_(1.0)
1327*da0073e9SAndroid Build Coastguard Worker            return y
1328*da0073e9SAndroid Build Coastguard Worker
1329*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.testing.standard_test(
1330*da0073e9SAndroid Build Coastguard Worker            self,
1331*da0073e9SAndroid Build Coastguard Worker            fn,
1332*da0073e9SAndroid Build Coastguard Worker            1,
1333*da0073e9SAndroid Build Coastguard Worker            expected_ops=20,
1334*da0073e9SAndroid Build Coastguard Worker        )
1335*da0073e9SAndroid Build Coastguard Worker
1336*da0073e9SAndroid Build Coastguard Worker    def test_empty_list(self):
1337*da0073e9SAndroid Build Coastguard Worker        def fn(x, ll):
1338*da0073e9SAndroid Build Coastguard Worker            if len(ll) == 0 and not ll and ll is not None:
1339*da0073e9SAndroid Build Coastguard Worker                return x + 1
1340*da0073e9SAndroid Build Coastguard Worker
1341*da0073e9SAndroid Build Coastguard Worker        i = torch.randn(5, 10)
1342*da0073e9SAndroid Build Coastguard Worker        r1 = fn(i, [])
1343*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
1344*da0073e9SAndroid Build Coastguard Worker        r2 = opt_fn(i, [])
1345*da0073e9SAndroid Build Coastguard Worker        r3 = opt_fn(i, tuple())
1346*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(r1, r2))
1347*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(r1, r3))
1348*da0073e9SAndroid Build Coastguard Worker
1349*da0073e9SAndroid Build Coastguard Worker    def test_min_max_over_iterable(self):
1350*da0073e9SAndroid Build Coastguard Worker        def get_test_fn(func):
1351*da0073e9SAndroid Build Coastguard Worker            def _fn(a, b, func=func):
1352*da0073e9SAndroid Build Coastguard Worker                # try all of list, iterator, tuple, vararg.
1353*da0073e9SAndroid Build Coastguard Worker                lst = [a.shape[0] + 1, 8, a.shape[0]]
1354*da0073e9SAndroid Build Coastguard Worker                x = func(lst)
1355*da0073e9SAndroid Build Coastguard Worker                y = func(iter(lst))
1356*da0073e9SAndroid Build Coastguard Worker                z = func(tuple(lst))
1357*da0073e9SAndroid Build Coastguard Worker                w = func(*lst)
1358*da0073e9SAndroid Build Coastguard Worker                return a + (x + y + z + w)
1359*da0073e9SAndroid Build Coastguard Worker
1360*da0073e9SAndroid Build Coastguard Worker            return _fn
1361*da0073e9SAndroid Build Coastguard Worker
1362*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.testing.standard_test(
1363*da0073e9SAndroid Build Coastguard Worker            self,
1364*da0073e9SAndroid Build Coastguard Worker            get_test_fn(func=min),
1365*da0073e9SAndroid Build Coastguard Worker            2,
1366*da0073e9SAndroid Build Coastguard Worker            expected_ops=1,
1367*da0073e9SAndroid Build Coastguard Worker            expected_ops_dynamic=ifdynstaticdefault(1, 14),
1368*da0073e9SAndroid Build Coastguard Worker        )
1369*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.testing.standard_test(
1370*da0073e9SAndroid Build Coastguard Worker            self,
1371*da0073e9SAndroid Build Coastguard Worker            get_test_fn(func=max),
1372*da0073e9SAndroid Build Coastguard Worker            2,
1373*da0073e9SAndroid Build Coastguard Worker            expected_ops=1,
1374*da0073e9SAndroid Build Coastguard Worker            expected_ops_dynamic=ifdynstaticdefault(1, 17),
1375*da0073e9SAndroid Build Coastguard Worker        )
1376*da0073e9SAndroid Build Coastguard Worker
1377*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(capture_scalar_outputs=True)
1378*da0073e9SAndroid Build Coastguard Worker    def test_torch_check(self):
1379*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1380*da0073e9SAndroid Build Coastguard Worker
1381*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend=cnts, fullgraph=True)
1382*da0073e9SAndroid Build Coastguard Worker        def f(x):
1383*da0073e9SAndroid Build Coastguard Worker            y = x.item()
1384*da0073e9SAndroid Build Coastguard Worker            torch._check(y >= 0)
1385*da0073e9SAndroid Build Coastguard Worker            return torch.arange(0, y)
1386*da0073e9SAndroid Build Coastguard Worker
1387*da0073e9SAndroid Build Coastguard Worker        f(torch.tensor([3]))
1388*da0073e9SAndroid Build Coastguard Worker        f(torch.tensor([4]))
1389*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
1390*da0073e9SAndroid Build Coastguard Worker
1391*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(capture_scalar_outputs=True)
1392*da0073e9SAndroid Build Coastguard Worker    def test_torch_check_symbolic_shape_rel(self):
1393*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1394*da0073e9SAndroid Build Coastguard Worker
1395*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend=cnts, fullgraph=True)
1396*da0073e9SAndroid Build Coastguard Worker        def f(x):
1397*da0073e9SAndroid Build Coastguard Worker            y = x.item()
1398*da0073e9SAndroid Build Coastguard Worker            torch._check(x.shape[0] == 1)
1399*da0073e9SAndroid Build Coastguard Worker            torch._check(x.shape[0] != 2)
1400*da0073e9SAndroid Build Coastguard Worker            torch._check(x.shape[0] >= 0)
1401*da0073e9SAndroid Build Coastguard Worker            torch._check(x.shape[0] > 0)
1402*da0073e9SAndroid Build Coastguard Worker            torch._check(x.shape[0] < 4)
1403*da0073e9SAndroid Build Coastguard Worker            torch._check(x.shape[0] <= 3)
1404*da0073e9SAndroid Build Coastguard Worker            return torch.arange(0, y)
1405*da0073e9SAndroid Build Coastguard Worker
1406*da0073e9SAndroid Build Coastguard Worker        f(torch.tensor([3]))
1407*da0073e9SAndroid Build Coastguard Worker        f(torch.tensor([4]))
1408*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
1409*da0073e9SAndroid Build Coastguard Worker
1410*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(capture_scalar_outputs=True)
1411*da0073e9SAndroid Build Coastguard Worker    # Translation validation changes the exception type, don't run with it
1412*da0073e9SAndroid Build Coastguard Worker    @torch.fx.experimental._config.patch(translation_validation=False)
1413*da0073e9SAndroid Build Coastguard Worker    def test_torch_check_is_size(self):
1414*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1415*da0073e9SAndroid Build Coastguard Worker
1416*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend=cnts, fullgraph=True)
1417*da0073e9SAndroid Build Coastguard Worker        def f(x):
1418*da0073e9SAndroid Build Coastguard Worker            y = x.item()
1419*da0073e9SAndroid Build Coastguard Worker            torch._check_is_size(y)
1420*da0073e9SAndroid Build Coastguard Worker            # Cannot conditional on unbacked SymInt
1421*da0073e9SAndroid Build Coastguard Worker            if y == 0:
1422*da0073e9SAndroid Build Coastguard Worker                assert False
1423*da0073e9SAndroid Build Coastguard Worker            else:
1424*da0073e9SAndroid Build Coastguard Worker                return torch.arange(0, y)
1425*da0073e9SAndroid Build Coastguard Worker
1426*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(torch._dynamo.exc.UserError, lambda: f(torch.tensor([3])))
1427*da0073e9SAndroid Build Coastguard Worker
1428*da0073e9SAndroid Build Coastguard Worker    def test_assert(self):
1429*da0073e9SAndroid Build Coastguard Worker        @torch.compile
1430*da0073e9SAndroid Build Coastguard Worker        def fn1(x):
1431*da0073e9SAndroid Build Coastguard Worker            assert x.shape != x.shape
1432*da0073e9SAndroid Build Coastguard Worker
1433*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
1434*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(10)
1435*da0073e9SAndroid Build Coastguard Worker            fn1(a)
1436*da0073e9SAndroid Build Coastguard Worker
1437*da0073e9SAndroid Build Coastguard Worker        def fn2(x):
1438*da0073e9SAndroid Build Coastguard Worker            assert x.shape == x.shape
1439*da0073e9SAndroid Build Coastguard Worker            return x.abs()
1440*da0073e9SAndroid Build Coastguard Worker
1441*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.testing.standard_test(self, fn=fn2, nargs=1, expected_ops=1)
1442*da0073e9SAndroid Build Coastguard Worker
1443*da0073e9SAndroid Build Coastguard Worker    def test_config_obj(self):
1444*da0073e9SAndroid Build Coastguard Worker        class Cfg:
1445*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
1446*da0073e9SAndroid Build Coastguard Worker                self.val = 0.5
1447*da0073e9SAndroid Build Coastguard Worker                self.count = 3
1448*da0073e9SAndroid Build Coastguard Worker
1449*da0073e9SAndroid Build Coastguard Worker        def fn(x, cfg):
1450*da0073e9SAndroid Build Coastguard Worker            for i in range(cfg.count):
1451*da0073e9SAndroid Build Coastguard Worker                x = x + cfg.val
1452*da0073e9SAndroid Build Coastguard Worker            return x
1453*da0073e9SAndroid Build Coastguard Worker
1454*da0073e9SAndroid Build Coastguard Worker        cfg1 = Cfg()
1455*da0073e9SAndroid Build Coastguard Worker        cfg1.val = 1.0
1456*da0073e9SAndroid Build Coastguard Worker        cfg2 = Cfg()
1457*da0073e9SAndroid Build Coastguard Worker        v = torch.zeros(1)
1458*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1459*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
1460*da0073e9SAndroid Build Coastguard Worker        v = opt_fn(v, cfg1)  # 3
1461*da0073e9SAndroid Build Coastguard Worker        v = opt_fn(v, cfg2)  # 4.5
1462*da0073e9SAndroid Build Coastguard Worker        cfg2.count = 1
1463*da0073e9SAndroid Build Coastguard Worker        v = opt_fn(v, cfg2)  # 5
1464*da0073e9SAndroid Build Coastguard Worker        cfg2.val = 2.0
1465*da0073e9SAndroid Build Coastguard Worker        v = opt_fn(v, cfg2)  # 7
1466*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[0], 7)
1467*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 8)
1468*da0073e9SAndroid Build Coastguard Worker
1469*da0073e9SAndroid Build Coastguard Worker    def test_config_getattr_default(self):
1470*da0073e9SAndroid Build Coastguard Worker        class Cfg:
1471*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
1472*da0073e9SAndroid Build Coastguard Worker                self.val = 0.5
1473*da0073e9SAndroid Build Coastguard Worker                self.count = 10
1474*da0073e9SAndroid Build Coastguard Worker
1475*da0073e9SAndroid Build Coastguard Worker        def fn(x, cfg):
1476*da0073e9SAndroid Build Coastguard Worker            if getattr(cfg, "just_add_7", False):
1477*da0073e9SAndroid Build Coastguard Worker                return x + 7
1478*da0073e9SAndroid Build Coastguard Worker            for i in range(cfg.count):
1479*da0073e9SAndroid Build Coastguard Worker                x = x + cfg.val
1480*da0073e9SAndroid Build Coastguard Worker            return x
1481*da0073e9SAndroid Build Coastguard Worker
1482*da0073e9SAndroid Build Coastguard Worker        cfg1 = Cfg()
1483*da0073e9SAndroid Build Coastguard Worker        v = torch.zeros(1)
1484*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1485*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
1486*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(v, cfg1)[0], 5)
1487*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(v, cfg1)[0], 5)
1488*da0073e9SAndroid Build Coastguard Worker        cfg1.just_add_7 = True
1489*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(v, cfg1)[0], 7)
1490*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(v, cfg1)[0], 7)
1491*da0073e9SAndroid Build Coastguard Worker        cfg1.just_add_7 = False
1492*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(v, cfg1)[0], 5)
1493*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(v, cfg1)[0], 5)
1494*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 3)
1495*da0073e9SAndroid Build Coastguard Worker
1496*da0073e9SAndroid Build Coastguard Worker    def test_size_input(self):
1497*da0073e9SAndroid Build Coastguard Worker        def fn(x, s):
1498*da0073e9SAndroid Build Coastguard Worker            a, b = s
1499*da0073e9SAndroid Build Coastguard Worker            return x + (a - b)
1500*da0073e9SAndroid Build Coastguard Worker
1501*da0073e9SAndroid Build Coastguard Worker        v = torch.zeros(10, 20)
1502*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1503*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
1504*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(v, v.size())[0, 0], -10)
1505*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(v, (10, 20))[0, 0], -10)
1506*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(v, [10, 20])[0, 0], -10)
1507*da0073e9SAndroid Build Coastguard Worker        # One recompile per differing input type
1508*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 3)
1509*da0073e9SAndroid Build Coastguard Worker
1510*da0073e9SAndroid Build Coastguard Worker    def test_cell_output1(self):
1511*da0073e9SAndroid Build Coastguard Worker        out = None
1512*da0073e9SAndroid Build Coastguard Worker
1513*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
1514*da0073e9SAndroid Build Coastguard Worker            nonlocal out
1515*da0073e9SAndroid Build Coastguard Worker            out = a + b * 10
1516*da0073e9SAndroid Build Coastguard Worker
1517*da0073e9SAndroid Build Coastguard Worker        v = torch.Tensor([100])
1518*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1519*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
1520*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(opt_fn(v, v))
1521*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out[0], 1100)
1522*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 2)
1523*da0073e9SAndroid Build Coastguard Worker
1524*da0073e9SAndroid Build Coastguard Worker    def test_cell_output2(self):
1525*da0073e9SAndroid Build Coastguard Worker        out = None
1526*da0073e9SAndroid Build Coastguard Worker
1527*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
1528*da0073e9SAndroid Build Coastguard Worker            nonlocal out
1529*da0073e9SAndroid Build Coastguard Worker            c = unsupported(a, b)
1530*da0073e9SAndroid Build Coastguard Worker            out = a + b * 10 + c
1531*da0073e9SAndroid Build Coastguard Worker
1532*da0073e9SAndroid Build Coastguard Worker        v = torch.Tensor([100])
1533*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1534*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
1535*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(opt_fn(v, v))
1536*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out[0], 1200)
1537*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 3)
1538*da0073e9SAndroid Build Coastguard Worker
1539*da0073e9SAndroid Build Coastguard Worker    def test_return_nested_function(self):
1540*da0073e9SAndroid Build Coastguard Worker        out = None
1541*da0073e9SAndroid Build Coastguard Worker
1542*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
1543*da0073e9SAndroid Build Coastguard Worker            nonlocal out
1544*da0073e9SAndroid Build Coastguard Worker            c = a + b
1545*da0073e9SAndroid Build Coastguard Worker            d = a + 1.0
1546*da0073e9SAndroid Build Coastguard Worker
1547*da0073e9SAndroid Build Coastguard Worker            def fn2(f: int = 7, g: float = 9.0):
1548*da0073e9SAndroid Build Coastguard Worker                nonlocal out
1549*da0073e9SAndroid Build Coastguard Worker                out = a + b * 10
1550*da0073e9SAndroid Build Coastguard Worker                return c * f - d * g
1551*da0073e9SAndroid Build Coastguard Worker
1552*da0073e9SAndroid Build Coastguard Worker            return fn2
1553*da0073e9SAndroid Build Coastguard Worker
1554*da0073e9SAndroid Build Coastguard Worker        v1 = torch.Tensor([100])
1555*da0073e9SAndroid Build Coastguard Worker        v2 = torch.Tensor([200])
1556*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1557*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
1558*da0073e9SAndroid Build Coastguard Worker        opt_fn_ret = torch._dynamo.optimize(cnts)(opt_fn(v1, v2))
1559*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn_ret(1.5)[0], -459)
1560*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out[0], 2100)
1561*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
1562*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 7)
1563*da0073e9SAndroid Build Coastguard Worker
1564*da0073e9SAndroid Build Coastguard Worker    def test_tensor_dict1(self):
1565*da0073e9SAndroid Build Coastguard Worker        def fn(inputs):
1566*da0073e9SAndroid Build Coastguard Worker            return inputs["a"] - inputs["b"] * 1.5
1567*da0073e9SAndroid Build Coastguard Worker
1568*da0073e9SAndroid Build Coastguard Worker        v1 = torch.Tensor([100])
1569*da0073e9SAndroid Build Coastguard Worker        v2 = torch.Tensor([200])
1570*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1571*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
1572*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn({"a": v1, "b": v2})[0], -200)
1573*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
1574*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 2)
1575*da0073e9SAndroid Build Coastguard Worker
1576*da0073e9SAndroid Build Coastguard Worker    def test_tensor_dict3(self):
1577*da0073e9SAndroid Build Coastguard Worker        def fn(inputs_a, inputs_b):
1578*da0073e9SAndroid Build Coastguard Worker            total = torch.zeros(1)
1579*da0073e9SAndroid Build Coastguard Worker            input_keys = inputs_a.keys() | inputs_b.keys()
1580*da0073e9SAndroid Build Coastguard Worker            for k in input_keys:
1581*da0073e9SAndroid Build Coastguard Worker                if k in inputs_a:
1582*da0073e9SAndroid Build Coastguard Worker                    total += inputs_a[k]
1583*da0073e9SAndroid Build Coastguard Worker                if k in inputs_b:
1584*da0073e9SAndroid Build Coastguard Worker                    total += inputs_b[k]
1585*da0073e9SAndroid Build Coastguard Worker            return total
1586*da0073e9SAndroid Build Coastguard Worker
1587*da0073e9SAndroid Build Coastguard Worker        v1 = torch.Tensor([100])
1588*da0073e9SAndroid Build Coastguard Worker        v2 = torch.Tensor([200])
1589*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1590*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
1591*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1592*da0073e9SAndroid Build Coastguard Worker            opt_fn({"a": v1, "b": v2}, {"b": v1, "c": v2}),
1593*da0073e9SAndroid Build Coastguard Worker            fn({"a": v1, "b": v2}, {"b": v1, "c": v2}),
1594*da0073e9SAndroid Build Coastguard Worker        )
1595*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
1596*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 5)
1597*da0073e9SAndroid Build Coastguard Worker
1598*da0073e9SAndroid Build Coastguard Worker    def test_tensor_dict2(self):
1599*da0073e9SAndroid Build Coastguard Worker        def fn1(inputs):
1600*da0073e9SAndroid Build Coastguard Worker            total = torch.zeros(1)
1601*da0073e9SAndroid Build Coastguard Worker            for k, v in inputs.items():
1602*da0073e9SAndroid Build Coastguard Worker                total += v
1603*da0073e9SAndroid Build Coastguard Worker            return total
1604*da0073e9SAndroid Build Coastguard Worker
1605*da0073e9SAndroid Build Coastguard Worker        def fn2(inputs):
1606*da0073e9SAndroid Build Coastguard Worker            total = torch.zeros(1)
1607*da0073e9SAndroid Build Coastguard Worker            for v in inputs.values():
1608*da0073e9SAndroid Build Coastguard Worker                total += v
1609*da0073e9SAndroid Build Coastguard Worker            return total
1610*da0073e9SAndroid Build Coastguard Worker
1611*da0073e9SAndroid Build Coastguard Worker        def fn3(inputs):
1612*da0073e9SAndroid Build Coastguard Worker            total = torch.zeros(1)
1613*da0073e9SAndroid Build Coastguard Worker            for k in inputs.keys():
1614*da0073e9SAndroid Build Coastguard Worker                total += inputs[k]
1615*da0073e9SAndroid Build Coastguard Worker            return total
1616*da0073e9SAndroid Build Coastguard Worker
1617*da0073e9SAndroid Build Coastguard Worker        v1 = torch.Tensor([100])
1618*da0073e9SAndroid Build Coastguard Worker        v2 = torch.Tensor([200])
1619*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1620*da0073e9SAndroid Build Coastguard Worker        opt_fn1 = torch._dynamo.optimize(cnts, nopython=True)(fn1)
1621*da0073e9SAndroid Build Coastguard Worker        opt_fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2)
1622*da0073e9SAndroid Build Coastguard Worker        opt_fn3 = torch._dynamo.optimize(cnts, nopython=True)(fn3)
1623*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn1({"a": v1, "b": v2})[0], 300)
1624*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn2({"a": v1, "b": v2})[0], 300)
1625*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn3({"a": v1, "b": v2})[0], 300)
1626*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 3)
1627*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 9)
1628*da0073e9SAndroid Build Coastguard Worker
1629*da0073e9SAndroid Build Coastguard Worker    def test_dictcomp(self):
1630*da0073e9SAndroid Build Coastguard Worker        def fn1(inputs):
1631*da0073e9SAndroid Build Coastguard Worker            return {k: v + 1 for k, v in inputs.items()}
1632*da0073e9SAndroid Build Coastguard Worker
1633*da0073e9SAndroid Build Coastguard Worker        v1 = torch.Tensor([100])
1634*da0073e9SAndroid Build Coastguard Worker        v2 = torch.Tensor([200])
1635*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1636*da0073e9SAndroid Build Coastguard Worker        opt_fn1 = torch._dynamo.optimize(cnts)(fn1)
1637*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn1({"a": v1, "b": v2})["a"], 101)
1638*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn1({"a": v1, "b": v2})["b"], 201)
1639*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
1640*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 2)
1641*da0073e9SAndroid Build Coastguard Worker
1642*da0073e9SAndroid Build Coastguard Worker    def test_listcomp(self):
1643*da0073e9SAndroid Build Coastguard Worker        def fn2(inputs):
1644*da0073e9SAndroid Build Coastguard Worker            return torch.sum(torch.cat([v + 1 for k, v in inputs.items()], 0))
1645*da0073e9SAndroid Build Coastguard Worker
1646*da0073e9SAndroid Build Coastguard Worker        v1 = torch.Tensor([100])
1647*da0073e9SAndroid Build Coastguard Worker        v2 = torch.Tensor([200])
1648*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
1649*da0073e9SAndroid Build Coastguard Worker        opt_fn2 = torch._dynamo.optimize(cnts)(fn2)
1650*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn2({"a": v1, "b": v2}), 302)
1651*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
1652*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 4)
1653*da0073e9SAndroid Build Coastguard Worker
1654*da0073e9SAndroid Build Coastguard Worker    def test_is_floating_point(self):
1655*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
1656*da0073e9SAndroid Build Coastguard Worker            x = a + 1.0
1657*da0073e9SAndroid Build Coastguard Worker            if torch.is_floating_point(b):
1658*da0073e9SAndroid Build Coastguard Worker                x = x + b
1659*da0073e9SAndroid Build Coastguard Worker            return x + 2.0
1660*da0073e9SAndroid Build Coastguard Worker
1661*da0073e9SAndroid Build Coastguard Worker        return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)
1662*da0073e9SAndroid Build Coastguard Worker
def test_is_floating_point2(self):
    # Same shape as test_is_floating_point, but the branch uses the method
    # form b.is_floating_point() instead of the torch.* function.
    def fn(a, b):
        x = a + 1.0
        if b.is_floating_point():
            x = x + b
        return x + 2.0

    run_standard_test = torch._dynamo.testing.standard_test
    return run_standard_test(self, fn=fn, nargs=2, expected_ops=3)
1671*da0073e9SAndroid Build Coastguard Worker
def test_is_tensor(self):
    # The traced function branches on torch.is_tensor(b).
    def fn(a, b):
        x = a + 1.0
        if torch.is_tensor(b):
            x = x + b
        return x + 2.0

    run_standard_test = torch._dynamo.testing.standard_test
    return run_standard_test(self, fn=fn, nargs=2, expected_ops=3)
1680*da0073e9SAndroid Build Coastguard Worker
def test_is_tensor2(self):
    def fn(x):
        if torch.is_tensor(x):
            return x + 1
        else:
            return torch.ones([2, 3])

    # Drive both branches of fn: a dict (not a tensor) and a real tensor.
    dict_arg = {"input": torch.rand(2, 3)}
    tensor_arg = torch.rand(2, 3)

    # Eager references are taken before the function is compiled.
    expected_for_dict = fn(dict_arg)
    expected_for_tensor = fn(tensor_arg)

    compiled = torch._dynamo.optimize("eager")(fn)
    actual_for_dict = compiled(dict_arg)
    actual_for_tensor = compiled(tensor_arg)

    self.assertEqual(expected_for_dict, actual_for_dict)
    self.assertEqual(expected_for_tensor, actual_for_tensor)
1697*da0073e9SAndroid Build Coastguard Worker
def test_numel(self):
    # Covers the three element-count spellings: Tensor.numel(),
    # torch.numel() and Tensor.nelement().
    def fn(a):
        return (a + a.numel() + torch.numel(a), a + a.nelement())

    # Expected op count for the dynamic-shape run; which of the two values
    # applies is decided by ifdynstaticdefault.
    dynamic_ops = ifdynstaticdefault(3, 6)
    return torch._dynamo.testing.standard_test(
        self, fn=fn, nargs=1, expected_ops=3, expected_ops_dynamic=dynamic_ops
    )
1709*da0073e9SAndroid Build Coastguard Worker
def test_pair(self):
    # Uses the torch.nn.modules.utils tuple helpers (_pair and _ntuple)
    # inside a traced function.
    def fn(a):
        return (
            torch.zeros(torch.nn.modules.utils._pair(a.size()))
            + a
            + torch.ones(torch.nn.modules.utils._ntuple(3)(3)).sum()
        )

    # Expected op count for the dynamic-shape run; which of the two values
    # applies is decided by ifdynstaticdefault.
    dynamic_ops = ifdynstaticdefault(5, 8)
    return torch._dynamo.testing.standard_test(
        self, fn=fn, nargs=1, expected_ops=5, expected_ops_dynamic=dynamic_ops
    )
1725*da0073e9SAndroid Build Coastguard Worker
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_tensor_item_capture(self):
    # With capture_scalar_outputs patched to True the assertions expect
    # four ops in a single frame for add -> sum -> item.
    def fn(a, b):
        return (a + b).sum().item()

    lhs, rhs = torch.randn((10, 10)), torch.randn((10, 10))
    expected = fn(lhs, rhs)

    counter = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(counter)(fn)
    actual = compiled(lhs, rhs)

    self.assertEqual(actual, expected)
    self.assertEqual(counter.frame_count, 1)
    self.assertEqual(counter.op_count, 4)
1739*da0073e9SAndroid Build Coastguard Worker
@patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
def test_tensor_item_no_capture(self):
    # With capture_scalar_outputs patched to False the assertions expect
    # only two ops in the captured frame for the same add -> sum -> item.
    def fn(a, b):
        return (a + b).sum().item()

    lhs, rhs = torch.randn((10, 10)), torch.randn((10, 10))
    expected = fn(lhs, rhs)

    counter = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(counter)(fn)
    actual = compiled(lhs, rhs)

    self.assertEqual(actual, expected)
    self.assertEqual(counter.frame_count, 1)
    self.assertEqual(counter.op_count, 2)
1753*da0073e9SAndroid Build Coastguard Worker
def test_namedtuple1(self):
    # Builds mytuple instances inside the traced function and reads them
    # back both by field name and by index.
    def fn(a, b):
        tmp = mytuple(a, b, a + b)
        return mytuple(tmp.a, tmp[1], tmp.ab + b)

    counter = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(counter)(fn)
    ten, twenty = torch.Tensor([10]), torch.Tensor([20])
    # ab == (10 + 20) + 20 == 50
    result = compiled(ten, twenty)
    self.assertEqual(result.ab, 50)
    self.assertEqual(counter.frame_count, 1)
    self.assertEqual(counter.op_count, 2)
1766*da0073e9SAndroid Build Coastguard Worker
def test_namedtuple2(self):
    # A mytuple passed in as an argument is unpacked, probed with hasattr,
    # and indexed inside the traced function.
    def fn(packed):
        a, b, c = packed
        if hasattr(packed, "b"):
            b = packed.b + 1
        c = packed[2]
        return a + b + c

    counter = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(counter)(fn)
    packed_input = mytuple(torch.Tensor([1]), torch.Tensor([2]), torch.Tensor([3]))
    # 1 + (2 + 1) + 3 == 7
    result = compiled(packed_input)
    self.assertEqual(result[0], 7)
    self.assertEqual(counter.frame_count, 1)
    self.assertEqual(counter.op_count, 3)
1783*da0073e9SAndroid Build Coastguard Worker
def test_namedtuple3(self):
    # isinstance() against the mytuple type selects the branch taken in
    # the traced function.
    def fn(x, packed):
        if isinstance(packed, mytuple):
            return x + 1
        else:
            return x - 1

    tensor_in = torch.rand([2, 3])
    tuple_in = mytuple(1, 2, 3)
    expected = fn(tensor_in, tuple_in)
    compiled = torch._dynamo.optimize("eager")(fn)
    actual = compiled(tensor_in, tuple_in)
    self.assertTrue(same(expected, actual))
1797*da0073e9SAndroid Build Coastguard Worker
def test_range_input(self):
    # fn iterates over a range object it receives as an argument; fn1 is
    # the one-argument entry point that supplies range(3).
    def fn(a, rng):
        x = a
        for i in rng:
            x = x + i
        return x

    def fn1(a):
        return fn(a, rng=range(3))

    run_standard_test = torch._dynamo.testing.standard_test
    return run_standard_test(self, fn=fn1, nargs=1, expected_ops=3)
1811*da0073e9SAndroid Build Coastguard Worker
def test_range_with_shape(self):
    # The loop bound comes from the input's leading dimension, and the
    # body does one in-place add per iteration.
    def fn(a):
        for i in range(1, a.shape[0]):
            a += 1
        return a

    run_standard_test = torch._dynamo.testing.standard_test
    return run_standard_test(self, fn=fn, nargs=1, expected_ops=9)
1824*da0073e9SAndroid Build Coastguard Worker
def test_build_tuple_unpack(self):
    # fn1 is the shared callee; fn2 reaches it by splatting two tuples
    # into a new tuple, fn3 by forwarding *args.
    def fn1(a, b, c):
        return a - b / c

    def fn2(a, b, c):
        tmp1 = (a,)
        tmp2 = (b, c)
        args = (*tmp1, *tmp2)
        return fn1(*args)

    def fn3(a, *args):
        return fn1(a, *args)

    # Both entry points are expected to produce the same two ops.
    for entry_point in (fn2, fn3):
        torch._dynamo.testing.standard_test(
            self, fn=entry_point, nargs=3, expected_ops=2
        )
1840*da0073e9SAndroid Build Coastguard Worker
def test_list_mul(self):
    # An int multiplies a list from both sides: count * [None] * count.
    def fn(count):
        head_mask = count * [None] * count
        return head_mask

    counter = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(counter)(fn)
    result = compiled(2)
    self.assertEqual(result, [None] * 4)
    # TODO: the captured frame here is a bit goofy, because we don't
    # output anything and none of the traced operations have side
    # effects.  Probably need better heuristic for bailing on
    # dynamo if there are no outputs
    # NOTE: the expected values stay as inline literals in each call.
    if torch._dynamo.config.assume_static_by_default:
        self.assertExpectedInline(counter.frame_count, """0""")
        self.assertExpectedInline(counter.op_count, """0""")
    else:
        self.assertExpectedInline(counter.frame_count, """1""")
        self.assertExpectedInline(counter.op_count, """2""")
1859*da0073e9SAndroid Build Coastguard Worker
def test_list_slice_mul(self):
    # Like test_list_mul, but the list operand is a slice (a[1:]).
    def fn(count):
        a = [1, 2, 3]
        head_mask = count * a[1:] * count
        return head_mask

    counter = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(counter)(fn)
    result = compiled(2)
    self.assertEqual(result, [2, 3] * 4)
    # NOTE: the expected values stay as inline literals in each call.
    if torch._dynamo.config.assume_static_by_default:
        self.assertExpectedInline(counter.frame_count, """0""")
        self.assertExpectedInline(counter.op_count, """0""")
    else:
        self.assertExpectedInline(counter.frame_count, """1""")
        self.assertExpectedInline(counter.op_count, """2""")
1875*da0073e9SAndroid Build Coastguard Worker
def test_tuple_mul(self):
    # An int multiplies a constant tuple from both sides.
    def fn(count):
        head_mask = count * (2, 3) * count
        return head_mask

    counter = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(counter)(fn)
    result = compiled(2)
    self.assertEqual(result, (2, 3) * 4)
    # NOTE: the expected values stay as inline literals in each call.
    if torch._dynamo.config.assume_static_by_default:
        self.assertExpectedInline(counter.frame_count, """0""")
        self.assertExpectedInline(counter.op_count, """0""")
    else:
        self.assertExpectedInline(counter.frame_count, """1""")
        self.assertExpectedInline(counter.op_count, """2""")
1890*da0073e9SAndroid Build Coastguard Worker
def test_tuple_mul_with_shape(self):
    # The repeated tuple contains a value read from a.shape, and the
    # result is indexed at position 4 before being added to the input.
    def fn(a):
        x = a.shape[0]
        y = 2 * (x, 3) * 2
        return a + y[4]

    # expect 3 ops post folding for dynamic case: size, index, add
    dynamic_ops = ifdynstaticdefault(1, 3)
    torch._dynamo.testing.standard_test(
        self, fn=fn, nargs=1, expected_ops=1, expected_ops_dynamic=dynamic_ops
    )
1901*da0073e9SAndroid Build Coastguard Worker
def test_tuple_iadd_with_shape(self):
    # In-place += on a tuple, first with another tuple of tensors and
    # then with a constant tuple.
    def fn(a):
        output = (a + a.shape[0], a - a.shape[0])
        # tuple += tuple
        output += (a - a.shape[0], a + a.shape[0])
        # tuple += constant tuple
        output += (2, 3)
        return output

    # expect 4 add / subs for static, 4 * 3 (size, index, math op) for dynamic
    dynamic_ops = ifdynstaticdefault(4, 12)
    torch._dynamo.testing.standard_test(
        self, fn=fn, nargs=1, expected_ops=4, expected_ops_dynamic=dynamic_ops
    )
1915*da0073e9SAndroid Build Coastguard Worker
def test_list_iadd_with_shape(self):
    # In-place += on a list, first with another list and then with a
    # tuple on the right-hand side.
    def fn(a):
        output = [a + a.shape[0], a - a.shape[0]]
        # list += list
        output += [a - a.shape[0], a + a.shape[0]]
        # list += tuple
        output += (a + a.shape[0], a - a.shape[0])
        return output

    # expect 6 add / subs for static, 6 * 3 (size, index, math op) for dynamic
    dynamic_ops = ifdynstaticdefault(6, 18)
    torch._dynamo.testing.standard_test(
        self, fn=fn, nargs=1, expected_ops=6, expected_ops_dynamic=dynamic_ops
    )
1930*da0073e9SAndroid Build Coastguard Worker
def test_list_iadd_side_effect(self):
    # fn extends its list argument in place and then forces a graph
    # break before returning it.
    def fn(a, b):
        a += [b]
        torch._dynamo.graph_break()
        return a

    compiled = torch._dynamo.optimize("eager")(fn)

    # fn mutates the list it is given, so the eager run and the compiled
    # run each receive their own freshly built inputs.
    expected = fn([1, 2, 3], torch.ones(2, 2))
    actual = compiled([1, 2, 3], torch.ones(2, 2))

    self.assertEqual(expected, actual)
1949*da0073e9SAndroid Build Coastguard Worker
def test_user_getattr1(self):
    # A dict subclass whose __getattr__ forwards attribute access to item
    # lookup, so cfg.offset reads cfg["offset"].
    class MyConfig(dict):
        def __getattr__(self, name):
            return self[name]

    def fn(cfg, x, y):
        return x + y + cfg.offset

    sample = torch.randn(10)
    config = MyConfig(offset=5)
    counter = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(counter)(fn)
    # The same tensor is passed for both x and y, hence 2 * sample + 5.
    actual = compiled(config, sample, sample)
    self.assertTrue(same(actual, 2 * sample + 5))
    self.assertEqual(counter.frame_count, 1)
    self.assertEqual(counter.op_count, 2)
1965*da0073e9SAndroid Build Coastguard Worker
def test_user_getattr2(self):
    # Three lookup paths on one object: a class attribute, an instance
    # attribute, and a name that only the __getattr__ fallback answers.
    class MyConfig:
        defined_on_class = 1

        def __init__(self):
            self.defined_on_object = 2

        def __getattr__(self, name):
            return 3

    def fn(cfg, x):
        return x + cfg.defined_on_class - cfg.defined_on_object + cfg.not_defined

    sample = torch.randn(10)
    config = MyConfig()
    counter = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(counter)(fn)
    actual = compiled(config, sample)
    self.assertTrue(same(actual, sample + 1 - 2 + 3))
    self.assertEqual(counter.frame_count, 1)
    self.assertEqual(counter.op_count, 3)
1986*da0073e9SAndroid Build Coastguard Worker
def test_getset_descriptor(self):
    # Calls __get__ on the torch.Tensor.shape descriptor from inside a
    # function compiled with fullgraph=True, and checks that the result
    # matches the eager call.
    def fn(g, x):
        return g.__get__(x)

    # The backend is "eager", so no CompileCounter is needed here; the
    # previously created (and never used) counter has been dropped.
    opt_fn = torch.compile(fullgraph=True, backend="eager")(fn)
    g = torch.Tensor.shape

    res = opt_fn(g, torch.ones(2, 2))
    exp_res = fn(g, torch.ones(2, 2))
    self.assertEqual(res, exp_res)
1998*da0073e9SAndroid Build Coastguard Worker
def test_get_attr_function(self):
    # The bound __get__ of the torch.Tensor.shape descriptor is passed in
    # as a plain callable and invoked inside the compiled function.
    def fn(g, x):
        return g(x)

    counter = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(counter)(fn)
    shape_getter = torch.Tensor.shape.__get__

    actual = compiled(shape_getter, torch.ones(2, 2))
    expected = fn(shape_getter, torch.ones(2, 2))
    self.assertEqual(actual, expected)
2010*da0073e9SAndroid Build Coastguard Worker
def test_user_getattribute(self):
    # __getattribute__ consults custom_dict first, so self.a resolves to
    # the stored tensor while self.my_number falls through to the normal
    # lookup.
    class MyObject:
        def __init__(self):
            self.custom_dict = {"a": torch.rand((2, 2))}
            self.my_number = 42

        def __getattribute__(self, name):
            custom_dict = super().__getattribute__("custom_dict")
            if name in custom_dict:
                return custom_dict[name]
            return super().__getattribute__(name)

        def run(self, x):
            return self.my_number * x + self.a * x

    def fn(obj, x):
        return obj.run(x)

    target = MyObject()
    sample = torch.rand((2, 2))
    counter = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(counter)(fn)
    # The compiled call runs first, then the eager reference.
    actual = compiled(target, sample)
    expected = fn(target, sample)
    self.assertTrue(same(actual, expected))
2034*da0073e9SAndroid Build Coastguard Worker
def test_nn_module_getattr(self):
    # An nn.Module that overrides __getattr__: forward reads self.queue,
    # which the override serves from custom_dict, and delegates any other
    # missing name to nn.Module.__getattr__.
    class MyMod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
            self.other_attr = torch.rand((2, 2))

        def __getattr__(self, name):
            custom_dict = self.custom_dict
            if name in custom_dict:
                return custom_dict[name]
            return super().__getattr__(name)

        def forward(self, x):
            return x @ self.other_attr + self.queue[-1]

    x = torch.rand((2, 2))
    mod = MyMod()
    cnts = torch._dynamo.testing.CompileCounter()
    opt_mod = torch._dynamo.optimize(cnts)(mod)
    self.assertTrue(same(opt_mod(x), mod(x)))
    # These were assertTrue(value, expected), which treats the second
    # argument as the failure message and only checks truthiness.
    # assertEqual actually compares the counters to the intended values.
    self.assertEqual(cnts.frame_count, 1)
    self.assertEqual(cnts.op_count, 2)
2058*da0073e9SAndroid Build Coastguard Worker
def test_nn_module_getattribute(self):
    # An nn.Module whose __getattribute__ synthesizes "special_attr" on
    # every access and defers to the default lookup for all other names.
    class MyMod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.my_number = 42

        def __getattribute__(self, name):
            if name == "special_attr":
                return torch.tensor([[1, 2], [3, 4]])
            return super().__getattribute__(name)

        def forward(self, x):
            return self.my_number * x + self.special_attr * x

    def fn(mod, x):
        return mod(x)

    module = MyMod()
    sample = torch.rand((2, 2))
    counter = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(counter)(fn)
    # The compiled call runs first, then the eager reference.
    actual = compiled(module, sample)
    expected = fn(module, sample)
    self.assertTrue(same(actual, expected))
2081*da0073e9SAndroid Build Coastguard Worker
2082*da0073e9SAndroid Build Coastguard Worker    def test_constant_getattr(self):
2083*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/97480
2084*da0073e9SAndroid Build Coastguard Worker        def fn():
2085*da0073e9SAndroid Build Coastguard Worker            return getattr(None, "arg", 3)
2086*da0073e9SAndroid Build Coastguard Worker
2087*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
2088*da0073e9SAndroid Build Coastguard Worker        optimized_fn = torch._dynamo.optimize(cnt)(fn)
2089*da0073e9SAndroid Build Coastguard Worker        res = optimized_fn()
2090*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(res, 3))
2091*da0073e9SAndroid Build Coastguard Worker
2092*da0073e9SAndroid Build Coastguard Worker    def test_user_property(self):
2093*da0073e9SAndroid Build Coastguard Worker        class MyConfig:
2094*da0073e9SAndroid Build Coastguard Worker            @property
2095*da0073e9SAndroid Build Coastguard Worker            def prop5(self):
2096*da0073e9SAndroid Build Coastguard Worker                return 5
2097*da0073e9SAndroid Build Coastguard Worker
2098*da0073e9SAndroid Build Coastguard Worker        def fn(cfg, x, y):
2099*da0073e9SAndroid Build Coastguard Worker            return x + y + cfg.prop5
2100*da0073e9SAndroid Build Coastguard Worker
2101*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10)
2102*da0073e9SAndroid Build Coastguard Worker        cfg = MyConfig()
2103*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2104*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2105*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(opt_fn(cfg, x, x), 2 * x + 5))
2106*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2107*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 2)
2108*da0073e9SAndroid Build Coastguard Worker
2109*da0073e9SAndroid Build Coastguard Worker    def test_dataclass_fields(self):
2110*da0073e9SAndroid Build Coastguard Worker        @dataclasses.dataclass
2111*da0073e9SAndroid Build Coastguard Worker        class MyDataClass:
2112*da0073e9SAndroid Build Coastguard Worker            a: torch.Tensor
2113*da0073e9SAndroid Build Coastguard Worker            b: torch.Tensor = None
2114*da0073e9SAndroid Build Coastguard Worker            c: torch.Tensor = None
2115*da0073e9SAndroid Build Coastguard Worker            d: torch.Tensor = None
2116*da0073e9SAndroid Build Coastguard Worker            e: torch.Tensor = None
2117*da0073e9SAndroid Build Coastguard Worker
2118*da0073e9SAndroid Build Coastguard Worker        def fn(obj):
2119*da0073e9SAndroid Build Coastguard Worker            class_fields = dataclasses.fields(obj)
2120*da0073e9SAndroid Build Coastguard Worker            assert len(class_fields)
2121*da0073e9SAndroid Build Coastguard Worker            assert all(field.default is None for field in class_fields[1:])
2122*da0073e9SAndroid Build Coastguard Worker            other_fields_are_none = all(
2123*da0073e9SAndroid Build Coastguard Worker                getattr(obj, field.name) is None for field in class_fields[1:]
2124*da0073e9SAndroid Build Coastguard Worker            )
2125*da0073e9SAndroid Build Coastguard Worker            assert not other_fields_are_none
2126*da0073e9SAndroid Build Coastguard Worker
2127*da0073e9SAndroid Build Coastguard Worker            if not hasattr(obj, "a"):
2128*da0073e9SAndroid Build Coastguard Worker                return -1
2129*da0073e9SAndroid Build Coastguard Worker            if hasattr(obj, "z"):
2130*da0073e9SAndroid Build Coastguard Worker                return -2
2131*da0073e9SAndroid Build Coastguard Worker
2132*da0073e9SAndroid Build Coastguard Worker            total = getattr(obj, class_fields[0].name)
2133*da0073e9SAndroid Build Coastguard Worker            for field in class_fields[1:]:
2134*da0073e9SAndroid Build Coastguard Worker                v = getattr(obj, field.name)
2135*da0073e9SAndroid Build Coastguard Worker                if v is not None:
2136*da0073e9SAndroid Build Coastguard Worker                    total += v
2137*da0073e9SAndroid Build Coastguard Worker
2138*da0073e9SAndroid Build Coastguard Worker            return total
2139*da0073e9SAndroid Build Coastguard Worker
2140*da0073e9SAndroid Build Coastguard Worker        obj1 = MyDataClass(torch.randn(10), torch.randn(10), torch.randn(10))
2141*da0073e9SAndroid Build Coastguard Worker        obj2 = MyDataClass(torch.randn(10), e=torch.randn(10))
2142*da0073e9SAndroid Build Coastguard Worker        correct1 = fn(obj1)
2143*da0073e9SAndroid Build Coastguard Worker        correct2 = fn(obj2)
2144*da0073e9SAndroid Build Coastguard Worker
2145*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2146*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2147*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(opt_fn(obj1), correct1))
2148*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2149*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 2)
2150*da0073e9SAndroid Build Coastguard Worker
2151*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
2152*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2153*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2154*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(opt_fn(obj2), correct2))
2155*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2156*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 1)
2157*da0073e9SAndroid Build Coastguard Worker
2158*da0073e9SAndroid Build Coastguard Worker        # guard failure
2159*da0073e9SAndroid Build Coastguard Worker        obj2.z = True
2160*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(obj2), -2)
2161*da0073e9SAndroid Build Coastguard Worker
2162*da0073e9SAndroid Build Coastguard Worker    def test_dataclass_local_hasattr(self):
2163*da0073e9SAndroid Build Coastguard Worker        cnt = CompileCounter()
2164*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10)
2165*da0073e9SAndroid Build Coastguard Worker
2166*da0073e9SAndroid Build Coastguard Worker        @dataclasses.dataclass
2167*da0073e9SAndroid Build Coastguard Worker        class MyDataClass:
2168*da0073e9SAndroid Build Coastguard Worker            a: torch.Tensor
2169*da0073e9SAndroid Build Coastguard Worker            b: torch.Tensor
2170*da0073e9SAndroid Build Coastguard Worker
2171*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend=cnt, fullgraph=True)
2172*da0073e9SAndroid Build Coastguard Worker        def fn():
2173*da0073e9SAndroid Build Coastguard Worker            obj = MyDataClass(x + 1, x - 1)
2174*da0073e9SAndroid Build Coastguard Worker            if not hasattr(obj, "a"):
2175*da0073e9SAndroid Build Coastguard Worker                return -1
2176*da0073e9SAndroid Build Coastguard Worker            if hasattr(obj, "z"):
2177*da0073e9SAndroid Build Coastguard Worker                return -2
2178*da0073e9SAndroid Build Coastguard Worker            return obj
2179*da0073e9SAndroid Build Coastguard Worker
2180*da0073e9SAndroid Build Coastguard Worker        result = fn()
2181*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(result, MyDataClass)
2182*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.a, x + 1)
2183*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.b, x - 1)
2184*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
2185*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.op_count, 2)
2186*da0073e9SAndroid Build Coastguard Worker
2187*da0073e9SAndroid Build Coastguard Worker    def test_catch_watchings1(self):
2188*da0073e9SAndroid Build Coastguard Worker        cnt = CompileCounter()
2189*da0073e9SAndroid Build Coastguard Worker
2190*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend=cnt, fullgraph=True)
2191*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2192*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True):
2193*da0073e9SAndroid Build Coastguard Worker                return x.sin()
2194*da0073e9SAndroid Build Coastguard Worker
2195*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(8)
2196*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(x), x.sin())
2197*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
2198*da0073e9SAndroid Build Coastguard Worker
2199*da0073e9SAndroid Build Coastguard Worker    def test_catch_watchings2(self):
2200*da0073e9SAndroid Build Coastguard Worker        cnt = CompileCounter()
2201*da0073e9SAndroid Build Coastguard Worker
2202*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend=cnt, fullgraph=True)
2203*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2204*da0073e9SAndroid Build Coastguard Worker            return x.sin(), warnings.catch_warnings(record=True)
2205*da0073e9SAndroid Build Coastguard Worker
2206*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(8)
2207*da0073e9SAndroid Build Coastguard Worker        _, a = fn(x)
2208*da0073e9SAndroid Build Coastguard Worker        _, b = fn(x)
2209*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
2210*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(a, warnings.catch_warnings)
2211*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(b, warnings.catch_warnings)
2212*da0073e9SAndroid Build Coastguard Worker        self.assertIsNot(a, b)
2213*da0073e9SAndroid Build Coastguard Worker
2214*da0073e9SAndroid Build Coastguard Worker    def test_tensor_build_list_unpack(self):
2215*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2216*da0073e9SAndroid Build Coastguard Worker            # seen in fastNLP_Bert
2217*da0073e9SAndroid Build Coastguard Worker            return torch.cat([*x], dim=-1)
2218*da0073e9SAndroid Build Coastguard Worker
2219*da0073e9SAndroid Build Coastguard Worker        val = torch.randn([1, 1, 473, 768])
2220*da0073e9SAndroid Build Coastguard Worker        correct = fn(val)
2221*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2222*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2223*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(opt_fn(val), correct))
2224*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2225*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 2)
2226*da0073e9SAndroid Build Coastguard Worker
2227*da0073e9SAndroid Build Coastguard Worker    def test_numpy_int_constant(self):
2228*da0073e9SAndroid Build Coastguard Worker        def fn(x, a, b):
2229*da0073e9SAndroid Build Coastguard Worker            return x + (a % b)
2230*da0073e9SAndroid Build Coastguard Worker
2231*da0073e9SAndroid Build Coastguard Worker        args = [torch.randn(10), 4096, np.int64(8)]
2232*da0073e9SAndroid Build Coastguard Worker        correct = fn(*args)
2233*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2234*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, dynamic=True, nopython=True)(fn)
2235*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(opt_fn(*args), correct))
2236*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(opt_fn(*args), correct))
2237*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2238*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 2)
2239*da0073e9SAndroid Build Coastguard Worker
2240*da0073e9SAndroid Build Coastguard Worker    def test_numpy_subdtype(self):
2241*da0073e9SAndroid Build Coastguard Worker        def fn(x, n):
2242*da0073e9SAndroid Build Coastguard Worker            return np.issubdtype(type(n), np.integer) + x
2243*da0073e9SAndroid Build Coastguard Worker
2244*da0073e9SAndroid Build Coastguard Worker        args = [torch.randn(10), 4096]
2245*da0073e9SAndroid Build Coastguard Worker        correct = fn(*args)
2246*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2247*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
2248*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(*args), correct)
2249*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2250*da0073e9SAndroid Build Coastguard Worker
2251*da0073e9SAndroid Build Coastguard Worker    def test_numpy_take_along_axis(self):
2252*da0073e9SAndroid Build Coastguard Worker        def fn(x, i, a):
2253*da0073e9SAndroid Build Coastguard Worker            return np.take_along_axis(x, i, a)
2254*da0073e9SAndroid Build Coastguard Worker
2255*da0073e9SAndroid Build Coastguard Worker        def sample_to_args(s):
2256*da0073e9SAndroid Build Coastguard Worker            args = (s.input, *sample.args)
2257*da0073e9SAndroid Build Coastguard Worker            return tuple(a.numpy() if isinstance(a, torch.Tensor) else a for a in args)
2258*da0073e9SAndroid Build Coastguard Worker
2259*da0073e9SAndroid Build Coastguard Worker        samples = list(
2260*da0073e9SAndroid Build Coastguard Worker            sample_inputs_take_along_dim(
2261*da0073e9SAndroid Build Coastguard Worker                None, "cpu", torch.float32, requires_grad=False
2262*da0073e9SAndroid Build Coastguard Worker            )
2263*da0073e9SAndroid Build Coastguard Worker        )
2264*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2265*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2266*da0073e9SAndroid Build Coastguard Worker        i = 1
2267*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
2268*da0073e9SAndroid Build Coastguard Worker            args = sample_to_args(sample)
2269*da0073e9SAndroid Build Coastguard Worker            if len(args) < 3:
2270*da0073e9SAndroid Build Coastguard Worker                # if axis is None, second argument is treated as 1d array
2271*da0073e9SAndroid Build Coastguard Worker                args = (args[0], np.ravel(args[1]), None)
2272*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fn(*args), opt_fn(*args))
2273*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnts.frame_count, i)
2274*da0073e9SAndroid Build Coastguard Worker            i += 1
2275*da0073e9SAndroid Build Coastguard Worker
2276*da0073e9SAndroid Build Coastguard Worker    def test_numpy_torch_operators(self):
2277*da0073e9SAndroid Build Coastguard Worker        def fn(op, t1, t2):
2278*da0073e9SAndroid Build Coastguard Worker            return op(t1, t2)
2279*da0073e9SAndroid Build Coastguard Worker
2280*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.variables.builtin import BuiltinVariable
2281*da0073e9SAndroid Build Coastguard Worker
2282*da0073e9SAndroid Build Coastguard Worker        operators = BuiltinVariable._fx_graph_functions()
2283*da0073e9SAndroid Build Coastguard Worker
2284*da0073e9SAndroid Build Coastguard Worker        for op, t1_np, t2_np in itertools.product(
2285*da0073e9SAndroid Build Coastguard Worker            operators, (True, False), (True, False)
2286*da0073e9SAndroid Build Coastguard Worker        ):
2287*da0073e9SAndroid Build Coastguard Worker            if op in [operator.eq, operator.ne]:
2288*da0073e9SAndroid Build Coastguard Worker                # returns equivalent of torch.eq/ne
2289*da0073e9SAndroid Build Coastguard Worker                continue
2290*da0073e9SAndroid Build Coastguard Worker            if op is operator.getitem:
2291*da0073e9SAndroid Build Coastguard Worker                # skip
2292*da0073e9SAndroid Build Coastguard Worker                # Did you know that tensor[ndarray_of_floats] works?
2293*da0073e9SAndroid Build Coastguard Worker                continue
2294*da0073e9SAndroid Build Coastguard Worker            if op is operator.imatmul and (t1_np or t2_np):
2295*da0073e9SAndroid Build Coastguard Worker                # skip
2296*da0073e9SAndroid Build Coastguard Worker                # in numpy, in place matmul does not work single
2297*da0073e9SAndroid Build Coastguard Worker                # dimensional arrays
2298*da0073e9SAndroid Build Coastguard Worker                continue
2299*da0073e9SAndroid Build Coastguard Worker            t1 = torch.rand(5)
2300*da0073e9SAndroid Build Coastguard Worker            if t1_np:
2301*da0073e9SAndroid Build Coastguard Worker                t1 = t1.numpy()
2302*da0073e9SAndroid Build Coastguard Worker            t2 = torch.rand(5)
2303*da0073e9SAndroid Build Coastguard Worker            if t2_np:
2304*da0073e9SAndroid Build Coastguard Worker                t2 = t2.numpy()
2305*da0073e9SAndroid Build Coastguard Worker            try:
2306*da0073e9SAndroid Build Coastguard Worker                # TODO try a bit harder
2307*da0073e9SAndroid Build Coastguard Worker                result = op(t1, t2)
2308*da0073e9SAndroid Build Coastguard Worker            except (RuntimeError, TypeError, IndexError):
2309*da0073e9SAndroid Build Coastguard Worker                continue
2310*da0073e9SAndroid Build Coastguard Worker            cnts = torch._dynamo.testing.CompileCounter()
2311*da0073e9SAndroid Build Coastguard Worker            opt_fn = torch._dynamo.optimize(cnts)(fn)
2312*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, opt_fn(op, t1, t2), msg=f"{op=} {t1_np=} {t2_np=}")
2313*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnts.frame_count, 1, msg=f"{op=} {t1_np=} {t2_np=}")
2314*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.reset()
2315*da0073e9SAndroid Build Coastguard Worker
2316*da0073e9SAndroid Build Coastguard Worker    def test_numpy_ndarray_graph_break(self):
2317*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2318*da0073e9SAndroid Build Coastguard Worker            a = x.numpy()
2319*da0073e9SAndroid Build Coastguard Worker            b = a.real
2320*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
2321*da0073e9SAndroid Build Coastguard Worker            c = np.multiply(b, 2.0)
2322*da0073e9SAndroid Build Coastguard Worker            return c
2323*da0073e9SAndroid Build Coastguard Worker
2324*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2325*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2326*da0073e9SAndroid Build Coastguard Worker        for _ in range(10):
2327*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(3)
2328*da0073e9SAndroid Build Coastguard Worker            ref = fn(x)
2329*da0073e9SAndroid Build Coastguard Worker            res = opt_fn(x)
2330*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref, res)
2331*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
2332*da0073e9SAndroid Build Coastguard Worker
2333*da0073e9SAndroid Build Coastguard Worker    def test_numpy_ndarray_graph_break_with_multiple_outputs(self):
2334*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
2335*da0073e9SAndroid Build Coastguard Worker            a = x.numpy()
2336*da0073e9SAndroid Build Coastguard Worker            b = y.numpy()
2337*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
2338*da0073e9SAndroid Build Coastguard Worker            return np.add(a, 1), np.add(b, 1)
2339*da0073e9SAndroid Build Coastguard Worker
2340*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2341*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2342*da0073e9SAndroid Build Coastguard Worker        for _ in range(10):
2343*da0073e9SAndroid Build Coastguard Worker            x = torch.randn([1, 3])
2344*da0073e9SAndroid Build Coastguard Worker            y = torch.randn([1, 3])
2345*da0073e9SAndroid Build Coastguard Worker            ref = fn(x, y)
2346*da0073e9SAndroid Build Coastguard Worker            res = opt_fn(x, y)
2347*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref, res)
2348*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
2349*da0073e9SAndroid Build Coastguard Worker
2350*da0073e9SAndroid Build Coastguard Worker    def test_numpy_force(self):
2351*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2352*da0073e9SAndroid Build Coastguard Worker            return x.numpy(force=False)
2353*da0073e9SAndroid Build Coastguard Worker
2354*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2355*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2356*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3)
2357*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
2358*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(type(res), np.ndarray)
2359*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2360*da0073e9SAndroid Build Coastguard Worker
2361*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2362*da0073e9SAndroid Build Coastguard Worker            return x.numpy(force=True)
2363*da0073e9SAndroid Build Coastguard Worker
2364*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2365*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2366*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, requires_grad=True)
2367*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
2368*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(type(res), np.ndarray)
2369*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2370*da0073e9SAndroid Build Coastguard Worker
2371*da0073e9SAndroid Build Coastguard Worker    def test_numpy_recompilation_scalar(self):
2372*da0073e9SAndroid Build Coastguard Worker        def fn(x, a):
2373*da0073e9SAndroid Build Coastguard Worker            return np.where(x < 0.5, a, x)
2374*da0073e9SAndroid Build Coastguard Worker
2375*da0073e9SAndroid Build Coastguard Worker        x = np.random.randn(8)
2376*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2377*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, dynamic=True)(fn)
2378*da0073e9SAndroid Build Coastguard Worker
2379*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, 3)
2380*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, 3)
2381*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
2382*da0073e9SAndroid Build Coastguard Worker
2383*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, 4)
2384*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, 4)
2385*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
2386*da0073e9SAndroid Build Coastguard Worker
2387*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2388*da0073e9SAndroid Build Coastguard Worker
2389*da0073e9SAndroid Build Coastguard Worker    def test_tensor_interacts_with_numpy_ndarray(self):
2390*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
2391*da0073e9SAndroid Build Coastguard Worker            a = x.numpy()
2392*da0073e9SAndroid Build Coastguard Worker            b = y.numpy()
2393*da0073e9SAndroid Build Coastguard Worker            c = np.ones_like(a)
2394*da0073e9SAndroid Build Coastguard Worker            d = np.ones_like(b)
2395*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
2396*da0073e9SAndroid Build Coastguard Worker            return np.add(a, c), np.add(b, d)
2397*da0073e9SAndroid Build Coastguard Worker
2398*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2399*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2400*da0073e9SAndroid Build Coastguard Worker        for _ in range(10):
2401*da0073e9SAndroid Build Coastguard Worker            x = torch.randn([1, 3])
2402*da0073e9SAndroid Build Coastguard Worker            y = torch.randn([1, 3])
2403*da0073e9SAndroid Build Coastguard Worker            ref = fn(x, y)
2404*da0073e9SAndroid Build Coastguard Worker            res = opt_fn(x, y)
2405*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref, res)
2406*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
2407*da0073e9SAndroid Build Coastguard Worker
2408*da0073e9SAndroid Build Coastguard Worker    def test_numpy_ndarray_works_with_builtin_function(self):
2409*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2410*da0073e9SAndroid Build Coastguard Worker            v = x.sum() / len(x)
2411*da0073e9SAndroid Build Coastguard Worker            return v
2412*da0073e9SAndroid Build Coastguard Worker
2413*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2414*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
2415*da0073e9SAndroid Build Coastguard Worker        for _ in range(10):
2416*da0073e9SAndroid Build Coastguard Worker            x = np.random.randn(2, 3)
2417*da0073e9SAndroid Build Coastguard Worker            ref = fn(x)
2418*da0073e9SAndroid Build Coastguard Worker            res = opt_fn(x)
2419*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref, res)
2420*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2421*da0073e9SAndroid Build Coastguard Worker
2422*da0073e9SAndroid Build Coastguard Worker    def test_numpy_array_of_arrays(self):
2423*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
2424*da0073e9SAndroid Build Coastguard Worker            return np.array([x, y])
2425*da0073e9SAndroid Build Coastguard Worker
2426*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2427*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
2428*da0073e9SAndroid Build Coastguard Worker
2429*da0073e9SAndroid Build Coastguard Worker        x, y = np.float64(1), np.float64(2)
2430*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, y)
2431*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, np.array([1, 2], dtype=float))
2432*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(type(res), np.ndarray)
2433*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2434*da0073e9SAndroid Build Coastguard Worker
2435*da0073e9SAndroid Build Coastguard Worker        x, y = np.arange(2), np.arange(2) + 2
2436*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, y)
2437*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, np.array([[0, 1], [2, 3]]))
2438*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(type(res), np.ndarray)
2439*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
2440*da0073e9SAndroid Build Coastguard Worker
2441*da0073e9SAndroid Build Coastguard Worker    def test_numpy_readonly(self):
2442*da0073e9SAndroid Build Coastguard Worker        @torch.compile(fullgraph=True)
2443*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2444*da0073e9SAndroid Build Coastguard Worker            return x
2445*da0073e9SAndroid Build Coastguard Worker
2446*da0073e9SAndroid Build Coastguard Worker        x = np.broadcast_to(np.arange(3), (2, 3))
2447*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(x.flags.writeable)
2448*da0073e9SAndroid Build Coastguard Worker
2449*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings():
2450*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("error")
2451*da0073e9SAndroid Build Coastguard Worker            y = fn(x)
2452*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(y.flags.writeable)  # XXX: differs from numpy
2453*da0073e9SAndroid Build Coastguard Worker
2454*da0073e9SAndroid Build Coastguard Worker    def test_numpy_tolist(self):
2455*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2456*da0073e9SAndroid Build Coastguard Worker            return x.tolist()
2457*da0073e9SAndroid Build Coastguard Worker
2458*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2459*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
2460*da0073e9SAndroid Build Coastguard Worker
2461*da0073e9SAndroid Build Coastguard Worker        x = np.arange(5)
2462*da0073e9SAndroid Build Coastguard Worker        r = opt_fn(x)
2463*da0073e9SAndroid Build Coastguard Worker
2464*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, [0, 1, 2, 3, 4])
2465*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(type(r), list)
2466*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2467*da0073e9SAndroid Build Coastguard Worker
2468*da0073e9SAndroid Build Coastguard Worker    def test_numpy_size_attr(self):
2469*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2470*da0073e9SAndroid Build Coastguard Worker            return x.size + x
2471*da0073e9SAndroid Build Coastguard Worker
2472*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2473*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
2474*da0073e9SAndroid Build Coastguard Worker
2475*da0073e9SAndroid Build Coastguard Worker        x = np.arange(5)
2476*da0073e9SAndroid Build Coastguard Worker        r = opt_fn(x)
2477*da0073e9SAndroid Build Coastguard Worker
2478*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, fn(x))
2479*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(type(r), np.ndarray)
2480*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2481*da0073e9SAndroid Build Coastguard Worker
    def test_numpy_no_raise(self):
        # Compile, using a CompileCounterWithBackend("inductor") backend, a
        # function that mixes torch and NumPy calls with a print() side effect,
        # then check the number of compiled frames.
        def _inf_nan_preprocess(t, t_np):
            # Only the NumPy array goes through nan_to_num; ``t`` is returned
            # unchanged.
            t_np = np.nan_to_num(t_np)
            return t, t_np

        def fn():
            # each entry below is used directly as a tensor/array shape
            test_cases = (
                (3, 3),
                (4, 4),
                (5, 5),
            )

            for shape in test_cases:
                t = torch.randn(shape, dtype=torch.complex64)
                t_np = np.random.randn(*shape).astype(np.complex64)

                _, t_np = _inf_nan_preprocess(t, t_np)
                print(t, t_np)  # Just a side effect so that compilation kicks in

        cnt = CompileCounterWithBackend("inductor")
        # The local name ``fn`` is rebound to the compiled wrapper.
        fn = torch._dynamo.optimize(cnt)(fn)
        fn()
        # The expected count is chosen by ifdynstaticdefault (defined elsewhere
        # in this file) -- presumably based on the dynamic-shape configuration;
        # TODO confirm.
        self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1))
2506*da0073e9SAndroid Build Coastguard Worker
2507*da0073e9SAndroid Build Coastguard Worker    def test_mandelbrot_numpy(self):
2508*da0073e9SAndroid Build Coastguard Worker        def mandelbrot_numpy(max_iter):
2509*da0073e9SAndroid Build Coastguard Worker            # Define the boundaries of the complex plane
2510*da0073e9SAndroid Build Coastguard Worker            xn = 450
2511*da0073e9SAndroid Build Coastguard Worker            yn = 375
2512*da0073e9SAndroid Build Coastguard Worker            xmin = -2.25
2513*da0073e9SAndroid Build Coastguard Worker            xmax = 0.75
2514*da0073e9SAndroid Build Coastguard Worker            ymin = -1.25
2515*da0073e9SAndroid Build Coastguard Worker            ymax = 1.25
2516*da0073e9SAndroid Build Coastguard Worker
2517*da0073e9SAndroid Build Coastguard Worker            # Create the grid of complex numbers
2518*da0073e9SAndroid Build Coastguard Worker            x_values = np.linspace(xmin, xmax, xn, dtype=np.float64)
2519*da0073e9SAndroid Build Coastguard Worker            y_values = np.linspace(ymin, ymax, yn, dtype=np.float64)
2520*da0073e9SAndroid Build Coastguard Worker            rx, iy = np.meshgrid(x_values, y_values, indexing="xy")
2521*da0073e9SAndroid Build Coastguard Worker
2522*da0073e9SAndroid Build Coastguard Worker            x = rx.copy()
2523*da0073e9SAndroid Build Coastguard Worker            y = iy.copy()
2524*da0073e9SAndroid Build Coastguard Worker            mask = np.zeros_like(x)
2525*da0073e9SAndroid Build Coastguard Worker            for i in range(max_iter):
2526*da0073e9SAndroid Build Coastguard Worker                x_prev = x
2527*da0073e9SAndroid Build Coastguard Worker                y_prev = y
2528*da0073e9SAndroid Build Coastguard Worker                x = x_prev**2 - y_prev**2 + rx
2529*da0073e9SAndroid Build Coastguard Worker                y = 2 * x_prev * y_prev + iy
2530*da0073e9SAndroid Build Coastguard Worker                inside = np.sqrt(x**2 + y**2) <= 2
2531*da0073e9SAndroid Build Coastguard Worker                mask += inside
2532*da0073e9SAndroid Build Coastguard Worker            return mask
2533*da0073e9SAndroid Build Coastguard Worker
2534*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2535*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(mandelbrot_numpy)
2536*da0073e9SAndroid Build Coastguard Worker        n_iter = torch._dynamo.config.cache_size_limit - 2
2537*da0073e9SAndroid Build Coastguard Worker        for i in range(n_iter):
2538*da0073e9SAndroid Build Coastguard Worker            x = i + 3
2539*da0073e9SAndroid Build Coastguard Worker            ref = mandelbrot_numpy(x)
2540*da0073e9SAndroid Build Coastguard Worker            res = opt_fn(x)
2541*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref, res)
2542*da0073e9SAndroid Build Coastguard Worker        # We need to specialise the number as it's in a forloop
2543*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, n_iter)
2544*da0073e9SAndroid Build Coastguard Worker
2545*da0073e9SAndroid Build Coastguard Worker    def test_numpy_as_global(self):
2546*da0073e9SAndroid Build Coastguard Worker        global x
2547*da0073e9SAndroid Build Coastguard Worker        x = np.arange(10)
2548*da0073e9SAndroid Build Coastguard Worker
2549*da0073e9SAndroid Build Coastguard Worker        @torch.compile(fullgraph=True)
2550*da0073e9SAndroid Build Coastguard Worker        def fn(y):
2551*da0073e9SAndroid Build Coastguard Worker            return y + x + x
2552*da0073e9SAndroid Build Coastguard Worker
2553*da0073e9SAndroid Build Coastguard Worker        r = fn(np.arange(10))
2554*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(type(r), np.ndarray)
2555*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, x * 3)
2556*da0073e9SAndroid Build Coastguard Worker        del x
2557*da0073e9SAndroid Build Coastguard Worker
2558*da0073e9SAndroid Build Coastguard Worker    def test_numpy_gt(self):
2559*da0073e9SAndroid Build Coastguard Worker        x = np.arange(10)
2560*da0073e9SAndroid Build Coastguard Worker
2561*da0073e9SAndroid Build Coastguard Worker        @torch.compile
2562*da0073e9SAndroid Build Coastguard Worker        def fn(y):
2563*da0073e9SAndroid Build Coastguard Worker            return y >= 3
2564*da0073e9SAndroid Build Coastguard Worker
2565*da0073e9SAndroid Build Coastguard Worker        r = fn(x)
2566*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(type(r), np.ndarray)
2567*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, x >= 3)
2568*da0073e9SAndroid Build Coastguard Worker
2569*da0073e9SAndroid Build Coastguard Worker    def test_numpy_min(self):
2570*da0073e9SAndroid Build Coastguard Worker        x = np.arange(10)
2571*da0073e9SAndroid Build Coastguard Worker
2572*da0073e9SAndroid Build Coastguard Worker        @torch.compile
2573*da0073e9SAndroid Build Coastguard Worker        def fn(y):
2574*da0073e9SAndroid Build Coastguard Worker            return min(y, 3), min(y, y - 1)
2575*da0073e9SAndroid Build Coastguard Worker
2576*da0073e9SAndroid Build Coastguard Worker        r1, r2 = fn(x)
2577*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(type(r1), np.ndarray)
2578*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(type(r2), np.ndarray)
2579*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r1, np.minimum(x, 3))
2580*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r2, np.minimum(x, x - 1))
2581*da0073e9SAndroid Build Coastguard Worker
2582*da0073e9SAndroid Build Coastguard Worker    def test_graph_break_correctly_when_passing_numpy_ndarray_to_torch_function(self):
2583*da0073e9SAndroid Build Coastguard Worker        # from transformers/models/big_bird/modeling_big_bird.py
2584*da0073e9SAndroid Build Coastguard Worker        def fn(x: int, y: torch.Tensor):
2585*da0073e9SAndroid Build Coastguard Worker            ndarray_list = [np.ones([2, x])]
2586*da0073e9SAndroid Build Coastguard Worker            ndarray = np.stack(ndarray_list, axis=0)
2587*da0073e9SAndroid Build Coastguard Worker            tensor = torch.tensor(ndarray, dtype=torch.long)
2588*da0073e9SAndroid Build Coastguard Worker            tensor.unsqueeze_(0)
2589*da0073e9SAndroid Build Coastguard Worker            return tensor + y
2590*da0073e9SAndroid Build Coastguard Worker
2591*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2592*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2593*da0073e9SAndroid Build Coastguard Worker        for x in range(1, 10):
2594*da0073e9SAndroid Build Coastguard Worker            y = torch.randn([1, 2, x])
2595*da0073e9SAndroid Build Coastguard Worker            ref = fn(x, y)
2596*da0073e9SAndroid Build Coastguard Worker            res = opt_fn(x, y)
2597*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref, res)
2598*da0073e9SAndroid Build Coastguard Worker        # It's all traced once with x = 1, x = 2 and then x = ks0
2599*da0073e9SAndroid Build Coastguard Worker        # For dynamic it's x=1 and x=ks0
2600*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, ifdynstaticdefault(3, 2))
2601*da0073e9SAndroid Build Coastguard Worker
2602*da0073e9SAndroid Build Coastguard Worker    def test_numpy_with_builtin_type(self):
2603*da0073e9SAndroid Build Coastguard Worker        x = np.random.rand(5)
2604*da0073e9SAndroid Build Coastguard Worker
2605*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2606*da0073e9SAndroid Build Coastguard Worker            return (x * 5).astype(bool).astype(float).astype(int) + 8
2607*da0073e9SAndroid Build Coastguard Worker
2608*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2609*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2610*da0073e9SAndroid Build Coastguard Worker
2611*da0073e9SAndroid Build Coastguard Worker        r = opt_fn(x)
2612*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.dtype, int)
2613*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2614*da0073e9SAndroid Build Coastguard Worker
2615*da0073e9SAndroid Build Coastguard Worker    def test_with_builtin_type(self):
2616*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5)
2617*da0073e9SAndroid Build Coastguard Worker
2618*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2619*da0073e9SAndroid Build Coastguard Worker            return (x * 5).to(bool).to(float).to(int) + 8
2620*da0073e9SAndroid Build Coastguard Worker
2621*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2622*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2623*da0073e9SAndroid Build Coastguard Worker
2624*da0073e9SAndroid Build Coastguard Worker        r = opt_fn(x)
2625*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.dtype, torch.int64)
2626*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2627*da0073e9SAndroid Build Coastguard Worker
2628*da0073e9SAndroid Build Coastguard Worker    def test_numpy_unique_f16(self):
2629*da0073e9SAndroid Build Coastguard Worker        def fn():
2630*da0073e9SAndroid Build Coastguard Worker            x = np.asarray([1, 1, 2, 2, 3], dtype=np.float16)
2631*da0073e9SAndroid Build Coastguard Worker            return np.unique(x)
2632*da0073e9SAndroid Build Coastguard Worker
2633*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2634*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2635*da0073e9SAndroid Build Coastguard Worker
2636*da0073e9SAndroid Build Coastguard Worker        r = opt_fn()
2637*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.dtype, np.float16)
2638*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2639*da0073e9SAndroid Build Coastguard Worker
2640*da0073e9SAndroid Build Coastguard Worker    def test_numpy_fallback_on_eager(self):
2641*da0073e9SAndroid Build Coastguard Worker        def fn():
2642*da0073e9SAndroid Build Coastguard Worker            return np.asarray(["L", "U"])
2643*da0073e9SAndroid Build Coastguard Worker
2644*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2645*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2646*da0073e9SAndroid Build Coastguard Worker
2647*da0073e9SAndroid Build Coastguard Worker        r = opt_fn()
2648*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 0)  # graph break
2649*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, np.asarray(["L", "U"]))
2650*da0073e9SAndroid Build Coastguard Worker
2651*da0073e9SAndroid Build Coastguard Worker        # repeat with a different function
2652*da0073e9SAndroid Build Coastguard Worker        def fn2():
2653*da0073e9SAndroid Build Coastguard Worker            return np.random.choice(["L", "U"])
2654*da0073e9SAndroid Build Coastguard Worker
2655*da0073e9SAndroid Build Coastguard Worker        cnts2 = torch._dynamo.testing.CompileCounter()
2656*da0073e9SAndroid Build Coastguard Worker        opt_fn2 = torch._dynamo.optimize(cnts2)(fn2)
2657*da0073e9SAndroid Build Coastguard Worker
2658*da0073e9SAndroid Build Coastguard Worker        r2 = fn2()
2659*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 0)
2660*da0073e9SAndroid Build Coastguard Worker        assert r2 in ("L", "U")
2661*da0073e9SAndroid Build Coastguard Worker
2662*da0073e9SAndroid Build Coastguard Worker    def test_trace_ndarray_frame(self):
2663*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2664*da0073e9SAndroid Build Coastguard Worker            x = x**2
2665*da0073e9SAndroid Build Coastguard Worker            print("graph break.")
2666*da0073e9SAndroid Build Coastguard Worker            return 2 * x
2667*da0073e9SAndroid Build Coastguard Worker
2668*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
2669*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch._dynamo.optimize(counter)(fn)
2670*da0073e9SAndroid Build Coastguard Worker
2671*da0073e9SAndroid Build Coastguard Worker        x = np.arange(8)
2672*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(x), compiled_fn(x))
2673*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 2)
2674*da0073e9SAndroid Build Coastguard Worker
2675*da0073e9SAndroid Build Coastguard Worker    def test_trace_ndarray_frame_2(self):
2676*da0073e9SAndroid Build Coastguard Worker        # no tensors/ndarray as inputs in the frame
2677*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2678*da0073e9SAndroid Build Coastguard Worker            print("graph break.")
2679*da0073e9SAndroid Build Coastguard Worker            return 2 * np.arange(x)
2680*da0073e9SAndroid Build Coastguard Worker
2681*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
2682*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch._dynamo.optimize(counter)(fn)
2683*da0073e9SAndroid Build Coastguard Worker
2684*da0073e9SAndroid Build Coastguard Worker        x = 8
2685*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(x), compiled_fn(x))
2686*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 1)
2687*da0073e9SAndroid Build Coastguard Worker
2688*da0073e9SAndroid Build Coastguard Worker    def test_numpy_non_torch_dtype(self):
2689*da0073e9SAndroid Build Coastguard Worker        # test that we gracefully graph break on dtypes
2690*da0073e9SAndroid Build Coastguard Worker        # that do not have pytorch equivalents.
2691*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2692*da0073e9SAndroid Build Coastguard Worker            return isinstance(x, torch.Tensor)
2693*da0073e9SAndroid Build Coastguard Worker
2694*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2695*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2696*da0073e9SAndroid Build Coastguard Worker
2697*da0073e9SAndroid Build Coastguard Worker        # torch does not have the `uint16` dtype
2698*da0073e9SAndroid Build Coastguard Worker        for x in [np.array([42], dtype=np.uint16), np.uint16(42), np.dtype("uint16")]:
2699*da0073e9SAndroid Build Coastguard Worker            r = opt_fn(x)
2700*da0073e9SAndroid Build Coastguard Worker
2701*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(r, False)
2702*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnts.frame_count, 0)  # graph break
2703*da0073e9SAndroid Build Coastguard Worker
2704*da0073e9SAndroid Build Coastguard Worker    def test_numpy_iter(self):
2705*da0073e9SAndroid Build Coastguard Worker        # test that iteration over an ndarray produces ndarrays not bare tensors
2706*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2707*da0073e9SAndroid Build Coastguard Worker            return [bm for bm in x]
2708*da0073e9SAndroid Build Coastguard Worker
2709*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2710*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2711*da0073e9SAndroid Build Coastguard Worker
2712*da0073e9SAndroid Build Coastguard Worker        proba_map = np.arange(3)[:, None]
2713*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(proba_map)
2714*da0073e9SAndroid Build Coastguard Worker
2715*da0073e9SAndroid Build Coastguard Worker        self.assertEqual([type(r) for r in res], [np.ndarray, np.ndarray, np.ndarray])
2716*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, [np.array([0]), np.array([1]), np.array([2])])
2717*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2718*da0073e9SAndroid Build Coastguard Worker
2719*da0073e9SAndroid Build Coastguard Worker    # cache size limit needs to be larger than the `dtypes` list size
2720*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(cache_size_limit=12)
2721*da0073e9SAndroid Build Coastguard Worker    def test_dtypes_no_graphbreaks(self):
2722*da0073e9SAndroid Build Coastguard Worker        dtypes = [
2723*da0073e9SAndroid Build Coastguard Worker            # floats
2724*da0073e9SAndroid Build Coastguard Worker            float,
2725*da0073e9SAndroid Build Coastguard Worker            np.float64,
2726*da0073e9SAndroid Build Coastguard Worker            "float64",
2727*da0073e9SAndroid Build Coastguard Worker            np.float32,
2728*da0073e9SAndroid Build Coastguard Worker            "float32",
2729*da0073e9SAndroid Build Coastguard Worker            # np.dtype('float64')   # XXX: this is not supported, yet
2730*da0073e9SAndroid Build Coastguard Worker            # integers
2731*da0073e9SAndroid Build Coastguard Worker            int,
2732*da0073e9SAndroid Build Coastguard Worker            "int",
2733*da0073e9SAndroid Build Coastguard Worker            np.intp,
2734*da0073e9SAndroid Build Coastguard Worker            np.int32,
2735*da0073e9SAndroid Build Coastguard Worker            np.uint8
2736*da0073e9SAndroid Build Coastguard Worker            # np.dtype('int')       # XXX: as above
2737*da0073e9SAndroid Build Coastguard Worker        ]
2738*da0073e9SAndroid Build Coastguard Worker
2739*da0073e9SAndroid Build Coastguard Worker        def fn(dt):
2740*da0073e9SAndroid Build Coastguard Worker            return np.arange(5, dtype=dt)
2741*da0073e9SAndroid Build Coastguard Worker
2742*da0073e9SAndroid Build Coastguard Worker        for dtyp in dtypes:
2743*da0073e9SAndroid Build Coastguard Worker            cnts = torch._dynamo.testing.CompileCounter()
2744*da0073e9SAndroid Build Coastguard Worker            opt_fn = torch._dynamo.optimize(cnts)(fn)
2745*da0073e9SAndroid Build Coastguard Worker
2746*da0073e9SAndroid Build Coastguard Worker            val = fn(dtyp)
2747*da0073e9SAndroid Build Coastguard Worker            opt_val = opt_fn(dtyp)
2748*da0073e9SAndroid Build Coastguard Worker
2749*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnts.frame_count, 1)  # no graph break
2750*da0073e9SAndroid Build Coastguard Worker
2751*da0073e9SAndroid Build Coastguard Worker    # setting the config value makes the PRNG identical to numpy's
2752*da0073e9SAndroid Build Coastguard Worker    # NB this may involve a graph break
2753*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(use_numpy_random_stream=True)
2754*da0073e9SAndroid Build Coastguard Worker    def test_numpy_random_config_to_numpy(self):
2755*da0073e9SAndroid Build Coastguard Worker        @torch.compile
2756*da0073e9SAndroid Build Coastguard Worker        def fn():
2757*da0073e9SAndroid Build Coastguard Worker            return np.random.uniform(size=13)
2758*da0073e9SAndroid Build Coastguard Worker
2759*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn().shape, (13,))
2760*da0073e9SAndroid Build Coastguard Worker
2761*da0073e9SAndroid Build Coastguard Worker    def test_inplace_view_on_graph_input(self):
2762*da0073e9SAndroid Build Coastguard Worker        # graph break when calling methods with inplace_view tag on graph input
2763*da0073e9SAndroid Build Coastguard Worker        func_args_map = {
2764*da0073e9SAndroid Build Coastguard Worker            lambda x: x.resize_(6).mul_(2): torch.ones(4),
2765*da0073e9SAndroid Build Coastguard Worker            lambda x: x.t_().mul_(2): torch.rand(2, 3),
2766*da0073e9SAndroid Build Coastguard Worker            lambda x: x.transpose_(0, 1).mul_(2): torch.rand(2, 3),
2767*da0073e9SAndroid Build Coastguard Worker            lambda x: x.squeeze_().mul_(2): torch.rand(1, 2, 3),
2768*da0073e9SAndroid Build Coastguard Worker            lambda x: x.unsqueeze_(0).mul_(2): torch.rand(2, 3),
2769*da0073e9SAndroid Build Coastguard Worker            lambda x: x.resize_as_(torch.rand(200, 300)): torch.rand(2, 3),
2770*da0073e9SAndroid Build Coastguard Worker            lambda x: x.swapaxes_(0, 1).mul_(2): torch.rand(2, 3),
2771*da0073e9SAndroid Build Coastguard Worker            lambda x: x.swapdims_(0, 1).mul_(2): torch.rand(2, 3),
2772*da0073e9SAndroid Build Coastguard Worker            lambda x: x.rename_("N", "C").mul_(2): torch.zeros(2, 3),
2773*da0073e9SAndroid Build Coastguard Worker            lambda x: x.as_strided_((3, 2), (2, 1)).mul_(2): torch.zeros(2, 3),
2774*da0073e9SAndroid Build Coastguard Worker            lambda x: x.detach_().mul_(2): torch.zeros(2, 3),
2775*da0073e9SAndroid Build Coastguard Worker        }
2776*da0073e9SAndroid Build Coastguard Worker        for func, args in func_args_map.items():
2777*da0073e9SAndroid Build Coastguard Worker            args_clone = args.clone()
2778*da0073e9SAndroid Build Coastguard Worker            cnts = torch._dynamo.testing.CompileCounter()
2779*da0073e9SAndroid Build Coastguard Worker            opt_f = torch._dynamo.optimize(cnts)(func)
2780*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(func(args).shape, opt_f(args_clone).shape))
2781*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnts.frame_count, 1)
2782*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnts.op_count, 1)  # mul_
2783*da0073e9SAndroid Build Coastguard Worker
2784*da0073e9SAndroid Build Coastguard Worker    def test_out_variants_with_resizing_on_graph_inputs(self):
2785*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
2786*da0073e9SAndroid Build Coastguard Worker            return torch.cosh(x, out=y) + 1
2787*da0073e9SAndroid Build Coastguard Worker
2788*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, 3)
2789*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(4)
2790*da0073e9SAndroid Build Coastguard Worker
2791*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2792*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend=cnts)
2793*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(fn(x, y), opt_fn(x.clone(), y.clone())))
2794*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2795*da0073e9SAndroid Build Coastguard Worker
2796*da0073e9SAndroid Build Coastguard Worker    def test_out_variants_with_resizing_on_graph_inputs_with_dynamic(self):
2797*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/120482
2798*da0073e9SAndroid Build Coastguard Worker        class CustomModel(torch.nn.Module):
2799*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
2800*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2801*da0073e9SAndroid Build Coastguard Worker
2802*da0073e9SAndroid Build Coastguard Worker            def forward(self, inputs):
2803*da0073e9SAndroid Build Coastguard Worker                return torch.outer(**inputs)
2804*da0073e9SAndroid Build Coastguard Worker
2805*da0073e9SAndroid Build Coastguard Worker        compile_fn = torch.compile(CustomModel(), fullgraph=True)
2806*da0073e9SAndroid Build Coastguard Worker
2807*da0073e9SAndroid Build Coastguard Worker        shapes = [(2, 1), (6, 1), (4, 1)]
2808*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
2809*da0073e9SAndroid Build Coastguard Worker            vec1, vec2 = shape
2810*da0073e9SAndroid Build Coastguard Worker            input_tensor1 = torch.randn(vec1)
2811*da0073e9SAndroid Build Coastguard Worker            input_tensor2 = torch.randn(vec2)
2812*da0073e9SAndroid Build Coastguard Worker            out_tensor = torch.empty(shape)
2813*da0073e9SAndroid Build Coastguard Worker            args = {"input": input_tensor1, "vec2": input_tensor2, "out": out_tensor}
2814*da0073e9SAndroid Build Coastguard Worker            res = compile_fn(args)
2815*da0073e9SAndroid Build Coastguard Worker            opt_res = res.clone()  # cuz this is out and we mutate it
2816*da0073e9SAndroid Build Coastguard Worker            res = CustomModel()(args)
2817*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, opt_res)
2818*da0073e9SAndroid Build Coastguard Worker
2819*da0073e9SAndroid Build Coastguard Worker    def test_dict_mutation_side_effect(self):
2820*da0073e9SAndroid Build Coastguard Worker        def fn(d):
2821*da0073e9SAndroid Build Coastguard Worker            d["c"] = d["a"] + d.pop("b")
2822*da0073e9SAndroid Build Coastguard Worker            return d
2823*da0073e9SAndroid Build Coastguard Worker
2824*da0073e9SAndroid Build Coastguard Worker        args1 = {"a": torch.randn(10), "b": torch.randn(10)}
2825*da0073e9SAndroid Build Coastguard Worker        args2 = dict(args1)
2826*da0073e9SAndroid Build Coastguard Worker        assert fn(args1) is args1
2827*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2828*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2829*da0073e9SAndroid Build Coastguard Worker        self.assertIs(opt_fn(args2), args2)
2830*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(args1, args2))
2831*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2832*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 1)
2833*da0073e9SAndroid Build Coastguard Worker
2834*da0073e9SAndroid Build Coastguard Worker    def test_dict_order_keys(self):
2835*da0073e9SAndroid Build Coastguard Worker        def fn(d):
2836*da0073e9SAndroid Build Coastguard Worker            c = 0
2837*da0073e9SAndroid Build Coastguard Worker            for v in d.values():
2838*da0073e9SAndroid Build Coastguard Worker                c += v
2839*da0073e9SAndroid Build Coastguard Worker            return c
2840*da0073e9SAndroid Build Coastguard Worker
2841*da0073e9SAndroid Build Coastguard Worker        args1 = {}
2842*da0073e9SAndroid Build Coastguard Worker        args1["a"] = torch.rand(10)
2843*da0073e9SAndroid Build Coastguard Worker        args1["b"] = torch.rand(10)
2844*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2845*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2846*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(args1), opt_fn(args1))
2847*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2848*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 2)
2849*da0073e9SAndroid Build Coastguard Worker
2850*da0073e9SAndroid Build Coastguard Worker        # A different order of keys recompiles
2851*da0073e9SAndroid Build Coastguard Worker        args2 = {}
2852*da0073e9SAndroid Build Coastguard Worker        args2["b"] = args1["b"]
2853*da0073e9SAndroid Build Coastguard Worker        args2["a"] = args1["a"]
2854*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(args2), opt_fn(args2))
2855*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
2856*da0073e9SAndroid Build Coastguard Worker        # Extra calls don't recompile
2857*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
2858*da0073e9SAndroid Build Coastguard Worker
2859*da0073e9SAndroid Build Coastguard Worker    def test_dict_namedtuple(self):
2860*da0073e9SAndroid Build Coastguard Worker        def fn(d):
2861*da0073e9SAndroid Build Coastguard Worker            return d[3] * 2
2862*da0073e9SAndroid Build Coastguard Worker
2863*da0073e9SAndroid Build Coastguard Worker        args1 = {collections.namedtuple: None, 3: torch.randn(3)}
2864*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2865*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2866*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(args1), opt_fn(args1))
2867*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2868*da0073e9SAndroid Build Coastguard Worker        # Test a failing namedtuple guard
2869*da0073e9SAndroid Build Coastguard Worker        args2 = {2: None, 3: torch.randn(3)}
2870*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(args2), opt_fn(args2))
2871*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
2872*da0073e9SAndroid Build Coastguard Worker
2873*da0073e9SAndroid Build Coastguard Worker    def test_dict_order_keys_tensors(self):
2874*da0073e9SAndroid Build Coastguard Worker        def fn(d, x):
2875*da0073e9SAndroid Build Coastguard Worker            return d[x] + 3
2876*da0073e9SAndroid Build Coastguard Worker
2877*da0073e9SAndroid Build Coastguard Worker        args1 = {}
2878*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10)
2879*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(10)
2880*da0073e9SAndroid Build Coastguard Worker        z = torch.randn(10)
2881*da0073e9SAndroid Build Coastguard Worker        args1[x] = y
2882*da0073e9SAndroid Build Coastguard Worker        args1[3] = z
2883*da0073e9SAndroid Build Coastguard Worker
2884*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2885*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2886*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(args1, x), opt_fn(args1, x))
2887*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2888*da0073e9SAndroid Build Coastguard Worker
2889*da0073e9SAndroid Build Coastguard Worker        # Calling again doesn't recompile (same id and key order)
2890*da0073e9SAndroid Build Coastguard Worker        opt_fn(args1, x)
2891*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2892*da0073e9SAndroid Build Coastguard Worker        args2 = {}
2893*da0073e9SAndroid Build Coastguard Worker        args2[3] = z
2894*da0073e9SAndroid Build Coastguard Worker        args2[x] = y
2895*da0073e9SAndroid Build Coastguard Worker
2896*da0073e9SAndroid Build Coastguard Worker        # Different order recompiles
2897*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(args2, x), opt_fn(args2, x))
2898*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
2899*da0073e9SAndroid Build Coastguard Worker
2900*da0073e9SAndroid Build Coastguard Worker    def test_dict_order_keys_modules(self):
2901*da0073e9SAndroid Build Coastguard Worker        def fn(d, x):
2902*da0073e9SAndroid Build Coastguard Worker            return d[x](torch.ones(2, 2))
2903*da0073e9SAndroid Build Coastguard Worker
2904*da0073e9SAndroid Build Coastguard Worker        args1 = {}
2905*da0073e9SAndroid Build Coastguard Worker        x = torch.nn.Linear(2, 2)
2906*da0073e9SAndroid Build Coastguard Worker        y = torch.nn.Linear(2, 2)
2907*da0073e9SAndroid Build Coastguard Worker        z = torch.nn.Linear(2, 2)
2908*da0073e9SAndroid Build Coastguard Worker        args1[x] = y
2909*da0073e9SAndroid Build Coastguard Worker        args1[3] = z
2910*da0073e9SAndroid Build Coastguard Worker
2911*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2912*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
2913*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(args1, x), opt_fn(args1, x))
2914*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2915*da0073e9SAndroid Build Coastguard Worker
2916*da0073e9SAndroid Build Coastguard Worker        # Calling again doesn't recompile (same id and key order)
2917*da0073e9SAndroid Build Coastguard Worker        opt_fn(args1, x)
2918*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2919*da0073e9SAndroid Build Coastguard Worker        args2 = {}
2920*da0073e9SAndroid Build Coastguard Worker        args2[3] = z
2921*da0073e9SAndroid Build Coastguard Worker        args2[x] = y
2922*da0073e9SAndroid Build Coastguard Worker
2923*da0073e9SAndroid Build Coastguard Worker        # Different order recompiles
2924*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(args2, x), opt_fn(args2, x))
2925*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
2926*da0073e9SAndroid Build Coastguard Worker
2927*da0073e9SAndroid Build Coastguard Worker    def test_dunder_new_function_inlining(self):
2928*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/107460
2929*da0073e9SAndroid Build Coastguard Worker
2930*da0073e9SAndroid Build Coastguard Worker        counters.clear()
2931*da0073e9SAndroid Build Coastguard Worker
2932*da0073e9SAndroid Build Coastguard Worker        class ModelA(torch.nn.Module):
2933*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
2934*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2935*da0073e9SAndroid Build Coastguard Worker
2936*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2937*da0073e9SAndroid Build Coastguard Worker                return torch.tanh(x + 1)
2938*da0073e9SAndroid Build Coastguard Worker
2939*da0073e9SAndroid Build Coastguard Worker        class ModelB(torch.nn.Module):
2940*da0073e9SAndroid Build Coastguard Worker            def __new__(cls):
2941*da0073e9SAndroid Build Coastguard Worker                return ModelA()
2942*da0073e9SAndroid Build Coastguard Worker
2943*da0073e9SAndroid Build Coastguard Worker        class Model(torch.nn.Module):
2944*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
2945*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2946*da0073e9SAndroid Build Coastguard Worker                self.layer = torch.nn.Linear(2, 2)
2947*da0073e9SAndroid Build Coastguard Worker
2948*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2949*da0073e9SAndroid Build Coastguard Worker                other = ModelB()
2950*da0073e9SAndroid Build Coastguard Worker                return self.layer(x) + other(x)
2951*da0073e9SAndroid Build Coastguard Worker
2952*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, 2)
2953*da0073e9SAndroid Build Coastguard Worker        m = Model()
2954*da0073e9SAndroid Build Coastguard Worker
2955*da0073e9SAndroid Build Coastguard Worker        opt_m = torch.compile(backend="eager")(m)
2956*da0073e9SAndroid Build Coastguard Worker        ref = m(x)
2957*da0073e9SAndroid Build Coastguard Worker        res = opt_m(x)
2958*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
2959*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(counters["graph_break"]), 1)
2960*da0073e9SAndroid Build Coastguard Worker        self.assertFalse("super() nn.Module.__init__" in counters["graph_break"])
2961*da0073e9SAndroid Build Coastguard Worker
2962*da0073e9SAndroid Build Coastguard Worker    def test_class_duner_mro(self):
2963*da0073e9SAndroid Build Coastguard Worker        class ModuleA(torch.nn.Module):
2964*da0073e9SAndroid Build Coastguard Worker            pass
2965*da0073e9SAndroid Build Coastguard Worker
2966*da0073e9SAndroid Build Coastguard Worker        class ModuleB(ModuleA):
2967*da0073e9SAndroid Build Coastguard Worker            pass
2968*da0073e9SAndroid Build Coastguard Worker
2969*da0073e9SAndroid Build Coastguard Worker        def fn(x, mod):
2970*da0073e9SAndroid Build Coastguard Worker            if ModuleA in type(mod).__mro__:
2971*da0073e9SAndroid Build Coastguard Worker                return x + 1
2972*da0073e9SAndroid Build Coastguard Worker            else:
2973*da0073e9SAndroid Build Coastguard Worker                return x - 1
2974*da0073e9SAndroid Build Coastguard Worker
2975*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, 3)
2976*da0073e9SAndroid Build Coastguard Worker        mod = ModuleB()
2977*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
2978*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, mod)
2979*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, mod)
2980*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
2981*da0073e9SAndroid Build Coastguard Worker
2982*da0073e9SAndroid Build Coastguard Worker    def test_nested_wraps(self):
2983*da0073e9SAndroid Build Coastguard Worker        def foo(x, y):
2984*da0073e9SAndroid Build Coastguard Worker            def add(x, y):
2985*da0073e9SAndroid Build Coastguard Worker                return x + y
2986*da0073e9SAndroid Build Coastguard Worker
2987*da0073e9SAndroid Build Coastguard Worker            @functools.wraps(add)
2988*da0073e9SAndroid Build Coastguard Worker            def wrapped_call(x, y):
2989*da0073e9SAndroid Build Coastguard Worker                return add(x, y)
2990*da0073e9SAndroid Build Coastguard Worker
2991*da0073e9SAndroid Build Coastguard Worker            return wrapped_call(x, y)
2992*da0073e9SAndroid Build Coastguard Worker
2993*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3)
2994*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(3, 3)
2995*da0073e9SAndroid Build Coastguard Worker
2996*da0073e9SAndroid Build Coastguard Worker        o = torch.compile(foo, fullgraph=True, backend="eager")(x, y)
2997*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(o, x + y)
2998*da0073e9SAndroid Build Coastguard Worker
2999*da0073e9SAndroid Build Coastguard Worker        def foo(x, y):
3000*da0073e9SAndroid Build Coastguard Worker            def nested_call(x, y):
3001*da0073e9SAndroid Build Coastguard Worker                def mul(x, y):
3002*da0073e9SAndroid Build Coastguard Worker                    return x * y
3003*da0073e9SAndroid Build Coastguard Worker
3004*da0073e9SAndroid Build Coastguard Worker                @functools.wraps(mul)
3005*da0073e9SAndroid Build Coastguard Worker                def double_nested_call(x, y):
3006*da0073e9SAndroid Build Coastguard Worker                    return mul(x, y)
3007*da0073e9SAndroid Build Coastguard Worker
3008*da0073e9SAndroid Build Coastguard Worker                return double_nested_call(x, y)
3009*da0073e9SAndroid Build Coastguard Worker
3010*da0073e9SAndroid Build Coastguard Worker            return nested_call(x, y)
3011*da0073e9SAndroid Build Coastguard Worker
3012*da0073e9SAndroid Build Coastguard Worker        o = torch.compile(foo, fullgraph=True, backend="eager")(x, y)
3013*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(o, x * y)
3014*da0073e9SAndroid Build Coastguard Worker
3015*da0073e9SAndroid Build Coastguard Worker    def test_module_deepcopy(self):
3016*da0073e9SAndroid Build Coastguard Worker        m1 = torch.nn.Sequential(
3017*da0073e9SAndroid Build Coastguard Worker            torch.nn.Linear(10, 10),
3018*da0073e9SAndroid Build Coastguard Worker            torch.nn.ReLU(),
3019*da0073e9SAndroid Build Coastguard Worker            torch.nn.Linear(10, 10),
3020*da0073e9SAndroid Build Coastguard Worker            torch.nn.ReLU(),
3021*da0073e9SAndroid Build Coastguard Worker        )
3022*da0073e9SAndroid Build Coastguard Worker        m2 = torch.nn.Sequential(
3023*da0073e9SAndroid Build Coastguard Worker            torch.nn.Linear(10, 10),
3024*da0073e9SAndroid Build Coastguard Worker            torch.nn.ReLU(),
3025*da0073e9SAndroid Build Coastguard Worker            torch.nn.Linear(10, 10),
3026*da0073e9SAndroid Build Coastguard Worker            torch.nn.ReLU(),
3027*da0073e9SAndroid Build Coastguard Worker        )
3028*da0073e9SAndroid Build Coastguard Worker
3029*da0073e9SAndroid Build Coastguard Worker        def fn(m, x):
3030*da0073e9SAndroid Build Coastguard Worker            m_copy = copy.deepcopy(m)
3031*da0073e9SAndroid Build Coastguard Worker            return m_copy(x)
3032*da0073e9SAndroid Build Coastguard Worker
3033*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(10)
3034*da0073e9SAndroid Build Coastguard Worker        correct1 = fn(m1, v)
3035*da0073e9SAndroid Build Coastguard Worker        correct2 = fn(m2, v)
3036*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3037*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
3038*da0073e9SAndroid Build Coastguard Worker        for _ in range(10):
3039*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(opt_fn(m1, v), correct1))
3040*da0073e9SAndroid Build Coastguard Worker        for _ in range(10):
3041*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(opt_fn(m2, v), correct2))
3042*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
3043*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 4)
3044*da0073e9SAndroid Build Coastguard Worker
3045*da0073e9SAndroid Build Coastguard Worker    def test_type_copy(self):
3046*da0073e9SAndroid Build Coastguard Worker        def fn(seq):
3047*da0073e9SAndroid Build Coastguard Worker            a, b = seq
3048*da0073e9SAndroid Build Coastguard Worker            return type(seq)([a + 1, b + 2, a + b])
3049*da0073e9SAndroid Build Coastguard Worker
3050*da0073e9SAndroid Build Coastguard Worker        args1 = [torch.randn(10), torch.randn(10)]
3051*da0073e9SAndroid Build Coastguard Worker        args2 = (torch.randn(10), torch.randn(10))
3052*da0073e9SAndroid Build Coastguard Worker        correct1 = fn(args1)
3053*da0073e9SAndroid Build Coastguard Worker        correct2 = fn(args2)
3054*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3055*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
3056*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(opt_fn(args1), correct1))
3057*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(opt_fn(args2), correct2))
3058*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(opt_fn(args1), list)
3059*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(opt_fn(args2), tuple)
3060*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
3061*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 6)
3062*da0073e9SAndroid Build Coastguard Worker
3063*da0073e9SAndroid Build Coastguard Worker    def test_setattr_mutation1(self):
3064*da0073e9SAndroid Build Coastguard Worker        class MyObj:  # noqa: B903
3065*da0073e9SAndroid Build Coastguard Worker            def __init__(self, a, b):
3066*da0073e9SAndroid Build Coastguard Worker                self.a = a
3067*da0073e9SAndroid Build Coastguard Worker                self.b = b
3068*da0073e9SAndroid Build Coastguard Worker
3069*da0073e9SAndroid Build Coastguard Worker        def fn(obj):
3070*da0073e9SAndroid Build Coastguard Worker            obj.c = obj.a * obj.b + 1
3071*da0073e9SAndroid Build Coastguard Worker            obj.b = obj.a * obj.c + 2
3072*da0073e9SAndroid Build Coastguard Worker            obj.a = obj.b * obj.c + 3
3073*da0073e9SAndroid Build Coastguard Worker            obj.c = obj.a * obj.b + 4
3074*da0073e9SAndroid Build Coastguard Worker            obj.b = obj.a * obj.c + 5
3075*da0073e9SAndroid Build Coastguard Worker            obj.a = obj.b * obj.c + 6
3076*da0073e9SAndroid Build Coastguard Worker            return obj
3077*da0073e9SAndroid Build Coastguard Worker
3078*da0073e9SAndroid Build Coastguard Worker        x1 = torch.randn(10)
3079*da0073e9SAndroid Build Coastguard Worker        x2 = torch.randn(10)
3080*da0073e9SAndroid Build Coastguard Worker        obj1 = MyObj(x1, x2)
3081*da0073e9SAndroid Build Coastguard Worker        obj2 = MyObj(x1, x2)
3082*da0073e9SAndroid Build Coastguard Worker        fn(obj2)
3083*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3084*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
3085*da0073e9SAndroid Build Coastguard Worker        self.assertIs(opt_fn(obj1), obj1)
3086*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj1.a, obj2.a))
3087*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj1.b, obj2.b))
3088*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj1.c, obj2.c))
3089*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
3090*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 12)
3091*da0073e9SAndroid Build Coastguard Worker
3092*da0073e9SAndroid Build Coastguard Worker    def test_setattr_mutation2(self):
3093*da0073e9SAndroid Build Coastguard Worker        class MyObj:
3094*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x):
3095*da0073e9SAndroid Build Coastguard Worker                self.a = x + 1
3096*da0073e9SAndroid Build Coastguard Worker                self.b = x + 2
3097*da0073e9SAndroid Build Coastguard Worker
3098*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3099*da0073e9SAndroid Build Coastguard Worker            x = x / 3.0
3100*da0073e9SAndroid Build Coastguard Worker            obj = MyObj(x)
3101*da0073e9SAndroid Build Coastguard Worker            obj.c = obj.a * obj.b + 1
3102*da0073e9SAndroid Build Coastguard Worker            obj.b = obj.a * obj.c + 2
3103*da0073e9SAndroid Build Coastguard Worker            obj.a = obj.b * obj.c + 3
3104*da0073e9SAndroid Build Coastguard Worker            return obj
3105*da0073e9SAndroid Build Coastguard Worker
3106*da0073e9SAndroid Build Coastguard Worker        x1 = torch.randn(10)
3107*da0073e9SAndroid Build Coastguard Worker        obj2 = fn(x1)
3108*da0073e9SAndroid Build Coastguard Worker
3109*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3110*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
3111*da0073e9SAndroid Build Coastguard Worker        obj1 = opt_fn(x1)
3112*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj1.a, obj2.a))
3113*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj1.b, obj2.b))
3114*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj1.c, obj2.c))
3115*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
3116*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 9)
3117*da0073e9SAndroid Build Coastguard Worker
3118*da0073e9SAndroid Build Coastguard Worker    def test_setattr_mutation3(self):
3119*da0073e9SAndroid Build Coastguard Worker        # TODO(jansel): dead code eliminate the object creation
3120*da0073e9SAndroid Build Coastguard Worker        class MyObj:
3121*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x):
3122*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3123*da0073e9SAndroid Build Coastguard Worker                self.a = x + 1
3124*da0073e9SAndroid Build Coastguard Worker                self.b = x + 2
3125*da0073e9SAndroid Build Coastguard Worker
3126*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3127*da0073e9SAndroid Build Coastguard Worker            x = x / 3.0
3128*da0073e9SAndroid Build Coastguard Worker            obj = MyObj(x)
3129*da0073e9SAndroid Build Coastguard Worker            obj.c = obj.a * obj.b + 1
3130*da0073e9SAndroid Build Coastguard Worker            obj.b = obj.a * obj.c + 2
3131*da0073e9SAndroid Build Coastguard Worker            obj.a = obj.b * obj.c + 3
3132*da0073e9SAndroid Build Coastguard Worker            return obj.a, obj.b, obj.c
3133*da0073e9SAndroid Build Coastguard Worker
3134*da0073e9SAndroid Build Coastguard Worker        x1 = torch.randn(10)
3135*da0073e9SAndroid Build Coastguard Worker        obj2 = fn(x1)
3136*da0073e9SAndroid Build Coastguard Worker
3137*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3138*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
3139*da0073e9SAndroid Build Coastguard Worker        obj1 = opt_fn(x1)
3140*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj1, obj2))
3141*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
3142*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 9)
3143*da0073e9SAndroid Build Coastguard Worker
3144*da0073e9SAndroid Build Coastguard Worker    def test_object_setattr(self):
3145*da0073e9SAndroid Build Coastguard Worker        @dataclasses.dataclass
3146*da0073e9SAndroid Build Coastguard Worker        class A:
3147*da0073e9SAndroid Build Coastguard Worker            x: torch.Tensor
3148*da0073e9SAndroid Build Coastguard Worker
3149*da0073e9SAndroid Build Coastguard Worker        def fn1(x) -> None:
3150*da0073e9SAndroid Build Coastguard Worker            a = A(x)
3151*da0073e9SAndroid Build Coastguard Worker            object.__setattr__(a, "x", x + 2)
3152*da0073e9SAndroid Build Coastguard Worker            return a
3153*da0073e9SAndroid Build Coastguard Worker
3154*da0073e9SAndroid Build Coastguard Worker        x1 = torch.randn(10)
3155*da0073e9SAndroid Build Coastguard Worker        obj11 = fn1(x1.clone())
3156*da0073e9SAndroid Build Coastguard Worker
3157*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3158*da0073e9SAndroid Build Coastguard Worker        opt_fn1 = torch._dynamo.optimize(cnts, nopython=True)(fn1)
3159*da0073e9SAndroid Build Coastguard Worker        obj12 = opt_fn1(x1.clone())
3160*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj11.x, x1 + 2))
3161*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj12.x, x1 + 2))
3162*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj11.x, obj12.x))
3163*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
3164*da0073e9SAndroid Build Coastguard Worker
3165*da0073e9SAndroid Build Coastguard Worker        @dataclasses.dataclass(frozen=True)
3166*da0073e9SAndroid Build Coastguard Worker        class B:
3167*da0073e9SAndroid Build Coastguard Worker            x: torch.Tensor
3168*da0073e9SAndroid Build Coastguard Worker
3169*da0073e9SAndroid Build Coastguard Worker        def fn2(x) -> None:
3170*da0073e9SAndroid Build Coastguard Worker            b = B(x)
3171*da0073e9SAndroid Build Coastguard Worker            return b
3172*da0073e9SAndroid Build Coastguard Worker
3173*da0073e9SAndroid Build Coastguard Worker        x2 = torch.randn(10)
3174*da0073e9SAndroid Build Coastguard Worker        obj21 = fn2(x2.clone())
3175*da0073e9SAndroid Build Coastguard Worker
3176*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3177*da0073e9SAndroid Build Coastguard Worker        opt_fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2)
3178*da0073e9SAndroid Build Coastguard Worker        obj22 = opt_fn2(x2.clone())
3179*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj21.x, x2))
3180*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj22.x, x2))
3181*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj21.x, obj22.x))
3182*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 0)
3183*da0073e9SAndroid Build Coastguard Worker
3184*da0073e9SAndroid Build Coastguard Worker        @dataclasses.dataclass(frozen=True)
3185*da0073e9SAndroid Build Coastguard Worker        class C:
3186*da0073e9SAndroid Build Coastguard Worker            x: torch.Tensor
3187*da0073e9SAndroid Build Coastguard Worker
3188*da0073e9SAndroid Build Coastguard Worker        def fn3(x) -> None:
3189*da0073e9SAndroid Build Coastguard Worker            c = C(x)
3190*da0073e9SAndroid Build Coastguard Worker            object.__setattr__(c, "x", x + 2)
3191*da0073e9SAndroid Build Coastguard Worker            return c
3192*da0073e9SAndroid Build Coastguard Worker
3193*da0073e9SAndroid Build Coastguard Worker        x3 = torch.randn(10)
3194*da0073e9SAndroid Build Coastguard Worker        obj31 = fn3(x3.clone())
3195*da0073e9SAndroid Build Coastguard Worker
3196*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3197*da0073e9SAndroid Build Coastguard Worker        opt_fn3 = torch._dynamo.optimize(cnts, nopython=True)(fn3)
3198*da0073e9SAndroid Build Coastguard Worker        obj32 = opt_fn3(x3.clone())
3199*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj31.x, x3 + 2))
3200*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj32.x, x3 + 2))
3201*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj31.x, obj32.x))
3202*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
3203*da0073e9SAndroid Build Coastguard Worker
3204*da0073e9SAndroid Build Coastguard Worker        @dataclasses.dataclass(frozen=True)
3205*da0073e9SAndroid Build Coastguard Worker        class D:
3206*da0073e9SAndroid Build Coastguard Worker            x: torch.Tensor
3207*da0073e9SAndroid Build Coastguard Worker
3208*da0073e9SAndroid Build Coastguard Worker            def __post_init__(self):
3209*da0073e9SAndroid Build Coastguard Worker                object.__setattr__(self, "y", self.x + 2)
3210*da0073e9SAndroid Build Coastguard Worker
3211*da0073e9SAndroid Build Coastguard Worker        def fn4(x) -> None:
3212*da0073e9SAndroid Build Coastguard Worker            d = D(x)
3213*da0073e9SAndroid Build Coastguard Worker            return d
3214*da0073e9SAndroid Build Coastguard Worker
3215*da0073e9SAndroid Build Coastguard Worker        x4 = torch.randn(10)
3216*da0073e9SAndroid Build Coastguard Worker        obj41 = fn4(x4.clone())
3217*da0073e9SAndroid Build Coastguard Worker
3218*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3219*da0073e9SAndroid Build Coastguard Worker        opt_fn4 = torch._dynamo.optimize(cnts, nopython=True)(fn4)
3220*da0073e9SAndroid Build Coastguard Worker        obj42 = opt_fn4(x4.clone())
3221*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj41.x, x4))
3222*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj42.x, x4))
3223*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj41.x, obj42.x))
3224*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj41.y, x4 + 2))
3225*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj42.y, x4 + 2))
3226*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(obj41.y, obj42.y))
3227*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
3228*da0073e9SAndroid Build Coastguard Worker
3229*da0073e9SAndroid Build Coastguard Worker    def test_user_defined_class_name(self):
3230*da0073e9SAndroid Build Coastguard Worker        class MyClassFoo:
3231*da0073e9SAndroid Build Coastguard Worker            pass
3232*da0073e9SAndroid Build Coastguard Worker
3233*da0073e9SAndroid Build Coastguard Worker        def fn1(a, b, c):
3234*da0073e9SAndroid Build Coastguard Worker            tmp = MyClassFoo()
3235*da0073e9SAndroid Build Coastguard Worker            if tmp.__class__.__name__ == "MyClassFoo":
3236*da0073e9SAndroid Build Coastguard Worker                return a - b / c
3237*da0073e9SAndroid Build Coastguard Worker
3238*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.testing.standard_test(self, fn=fn1, nargs=3)
3239*da0073e9SAndroid Build Coastguard Worker
3240*da0073e9SAndroid Build Coastguard Worker    def test_user_defined_class_python_type(self):
3241*da0073e9SAndroid Build Coastguard Worker        class MyClass1:
3242*da0073e9SAndroid Build Coastguard Worker            pass
3243*da0073e9SAndroid Build Coastguard Worker
3244*da0073e9SAndroid Build Coastguard Worker        class ExampleMeta(type):
3245*da0073e9SAndroid Build Coastguard Worker            pass
3246*da0073e9SAndroid Build Coastguard Worker
3247*da0073e9SAndroid Build Coastguard Worker        class MyClass2(metaclass=ExampleMeta):
3248*da0073e9SAndroid Build Coastguard Worker            pass
3249*da0073e9SAndroid Build Coastguard Worker
3250*da0073e9SAndroid Build Coastguard Worker        def fn(x, c):
3251*da0073e9SAndroid Build Coastguard Worker            if isinstance(c, MyClass1):
3252*da0073e9SAndroid Build Coastguard Worker                return x + 1
3253*da0073e9SAndroid Build Coastguard Worker            elif isinstance(c, MyClass2):
3254*da0073e9SAndroid Build Coastguard Worker                return x + 2
3255*da0073e9SAndroid Build Coastguard Worker            else:
3256*da0073e9SAndroid Build Coastguard Worker                return x + 3
3257*da0073e9SAndroid Build Coastguard Worker
3258*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3)
3259*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
3260*da0073e9SAndroid Build Coastguard Worker        for c in [MyClass1, MyClass2]:
3261*da0073e9SAndroid Build Coastguard Worker            ref = fn(x, c)
3262*da0073e9SAndroid Build Coastguard Worker            res = opt_fn(x, c)
3263*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(ref, res))
3264*da0073e9SAndroid Build Coastguard Worker
3265*da0073e9SAndroid Build Coastguard Worker    def test_super_calling_with_metaclass(self):
3266*da0073e9SAndroid Build Coastguard Worker        class ExampleMeta(type):
3267*da0073e9SAndroid Build Coastguard Worker            pass
3268*da0073e9SAndroid Build Coastguard Worker
3269*da0073e9SAndroid Build Coastguard Worker        class MyClass1(metaclass=ExampleMeta):
3270*da0073e9SAndroid Build Coastguard Worker            coeff = 4  # Force the constant guard to test source in guards
3271*da0073e9SAndroid Build Coastguard Worker
3272*da0073e9SAndroid Build Coastguard Worker            @classmethod
3273*da0073e9SAndroid Build Coastguard Worker            def add(cls, x):
3274*da0073e9SAndroid Build Coastguard Worker                return x + 1
3275*da0073e9SAndroid Build Coastguard Worker
3276*da0073e9SAndroid Build Coastguard Worker        class MyClass2(MyClass1):
3277*da0073e9SAndroid Build Coastguard Worker            @classmethod
3278*da0073e9SAndroid Build Coastguard Worker            def add(cls, x):
3279*da0073e9SAndroid Build Coastguard Worker                torch._dynamo.graph_break()
3280*da0073e9SAndroid Build Coastguard Worker                return x + super().add(x) + super().coeff
3281*da0073e9SAndroid Build Coastguard Worker
3282*da0073e9SAndroid Build Coastguard Worker        def fn(x, obj):
3283*da0073e9SAndroid Build Coastguard Worker            return x + obj.add(x)
3284*da0073e9SAndroid Build Coastguard Worker
3285*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3)
3286*da0073e9SAndroid Build Coastguard Worker        obj = MyClass2()
3287*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
3288*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, obj)
3289*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, obj)
3290*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
3291*da0073e9SAndroid Build Coastguard Worker
3292*da0073e9SAndroid Build Coastguard Worker    def test_usr_cls_staticmethod(self):
3293*da0073e9SAndroid Build Coastguard Worker        class Foo:
3294*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3295*da0073e9SAndroid Build Coastguard Worker            def bar(a, b):
3296*da0073e9SAndroid Build Coastguard Worker                return a + b
3297*da0073e9SAndroid Build Coastguard Worker
3298*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
3299*da0073e9SAndroid Build Coastguard Worker            return Foo.bar(a, b) - 1
3300*da0073e9SAndroid Build Coastguard Worker
3301*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.testing.standard_test(self, fn=fn, nargs=2)
3302*da0073e9SAndroid Build Coastguard Worker
3303*da0073e9SAndroid Build Coastguard Worker    def test_usr_cls_classmethod(self):
3304*da0073e9SAndroid Build Coastguard Worker        class Foo:
3305*da0073e9SAndroid Build Coastguard Worker            @classmethod
3306*da0073e9SAndroid Build Coastguard Worker            def bar(cls, a, b):
3307*da0073e9SAndroid Build Coastguard Worker                return a + b
3308*da0073e9SAndroid Build Coastguard Worker
3309*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
3310*da0073e9SAndroid Build Coastguard Worker            return Foo.bar(a, b) - 1
3311*da0073e9SAndroid Build Coastguard Worker
3312*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.testing.standard_test(self, fn=fn, nargs=2)
3313*da0073e9SAndroid Build Coastguard Worker
def test_dunder_methods(self):
    # Foo wraps a value and forwards +, *, / and - to it through dunder
    # methods; the traced function combines five Foo instances and
    # standard_test is told to expect exactly four ops.
    class Foo:
        def __init__(self, val):
            super().__init__()
            self.val = val

        def __add__(self, rhs):
            return Foo(self.val + rhs.val)

        def __mul__(self, rhs):
            return Foo(self.val * rhs.val)

        def __truediv__(self, rhs):
            return Foo(self.val / rhs.val)

        def __sub__(self, rhs):
            return Foo(self.val - rhs.val)

    def fn(a, b, c):
        # Same precedence as `Foo(a) + Foo(b) * Foo(c) / Foo(a) - Foo(b)`:
        # multiply, divide, add, then subtract.
        head = Foo(a)
        product = Foo(b) * Foo(c)
        quotient = product / Foo(a)
        return head + quotient - Foo(b)

    torch._dynamo.testing.standard_test(self, fn=fn, nargs=3, expected_ops=4)
3336*da0073e9SAndroid Build Coastguard Worker
def test_function_annotation(self):
    # fn returns a closure whose parameter is annotated with a generic over
    # a locally defined class. Both fn and the returned closure are wrapped
    # with optimize_assert; the test expects 2 frames and 2 ops in total.
    class Variable:
        pass

    def fn(x):
        x = x / 3.0

        def inner(y: typing.List[Variable]):
            return x + 1

        return inner

    inp = torch.randn(10)
    expected = fn(inp)([])

    cnts = torch._dynamo.testing.CompileCounter()
    compiled_outer = torch._dynamo.optimize_assert(cnts)(fn)
    # The closure produced by the compiled outer function is wrapped again.
    compiled_inner = torch._dynamo.optimize_assert(cnts)(compiled_outer(inp))
    actual = compiled_inner([])

    self.assertTrue(same(actual, expected))
    self.assertEqual(cnts.frame_count, 2)
    self.assertEqual(cnts.op_count, 2)
3359*da0073e9SAndroid Build Coastguard Worker
def test_nested_closure(self):
    # Four levels of nested closures: fn4 reads v0 (this test's scope),
    # v1 (fn1), v2 (fn2, built from *args/**kwargs) and v3 (a default
    # argument of fn3).
    v0 = torch.randn(10)

    def fn1():
        v1 = torch.randn(10)

        def fn2(*args, **kwargs):
            assert len(args) == 1
            assert len(kwargs) == 1
            v2 = torch.randn(10) + args[0] + kwargs["b"]

            def fn3(v3=torch.randn(10)):
                def fn4():
                    return v0 + v1 + v2 + v3 + 1

                return fn4

            return fn3

        # fn2(...) yields fn3; calling that yields the innermost fn4.
        return fn2(1, b=2)()

    cnts = torch._dynamo.testing.CompileCounter()
    opt_fn1 = torch._dynamo.optimize_assert(cnts)(fn1)
    # Each opt_fn1() call returns a fresh fn4 closure, which is wrapped too.
    tmp1 = torch._dynamo.optimize_assert(cnts)(opt_fn1())
    tmp2 = torch._dynamo.optimize_assert(cnts)(opt_fn1())
    # This used to be assertTrue(shape, (10,)), which treats (10,) as the
    # failure message and passes for any non-empty shape. Compare for real.
    self.assertEqual(tuple(tmp1().shape), (10,))
    self.assertTrue(same(tmp1(), tmp1()))
    self.assertFalse(same(tmp1(), tmp2()))
    self.assertEqual(cnts.frame_count, 2)
    self.assertEqual(cnts.op_count, 9)
3390*da0073e9SAndroid Build Coastguard Worker
def test_nested_closure_mutation(self):
    # fn3 mutates two enclosing-scope tensors through `nonlocal`. The same
    # seed is used for the eager and the compiled run, and the sequences of
    # values returned by the two counters are compared.
    def fn1():
        v1 = torch.randn(10)

        def fn2():
            v2 = torch.randn(10)

            def fn3():
                nonlocal v1, v2
                v1 += 1
                v2 += 2
                return v1 + v2

            return fn3

        rv = fn2()
        rv()
        rv()
        return rv

    torch.manual_seed(9000)
    eager_counter = fn1()
    eager_results = [eager_counter() for _ in range(3)]

    torch.manual_seed(9000)
    cnts = torch._dynamo.testing.CompileCounter()
    opt_fn1 = torch._dynamo.optimize_assert(cnts)(fn1)
    compiled_counter = torch._dynamo.optimize_assert(cnts)(opt_fn1())
    compiled_results = [compiled_counter() for _ in range(3)]

    # One extra call on each side after both lists have been built.
    eager_results.append(eager_counter())
    compiled_results.append(compiled_counter())

    self.assertTrue(same(eager_results, compiled_results))
    self.assertEqual(cnts.frame_count, 2)
    self.assertEqual(cnts.op_count, 11)
3426*da0073e9SAndroid Build Coastguard Worker
def test_write_to_closures_in_inlining(self):
    # A counter closure rebinds its enclosing variable on every call. The
    # sum of two calls is computed once eagerly and once from inside a
    # nopython-compiled function, starting from the same seed.
    def make_counter():
        x = torch.randn(10)

        def counter():
            nonlocal x
            x = x + 1
            return x

        return counter

    # Eager reference.
    torch.manual_seed(0)
    eager_counter = make_counter()
    eager_out = eager_counter() + eager_counter()

    # Compiled run on a fresh counter built from the same seed.
    torch.manual_seed(0)
    compiled_counter = make_counter()
    cnts = torch._dynamo.testing.CompileCounter()

    @torch._dynamo.optimize(cnts, nopython=True)
    def fn(counter):
        return counter() + counter()

    compiled_out = fn(compiled_counter)
    self.assertEqual(cnts.frame_count, 1)
    self.assertEqual(cnts.op_count, 3)
    # Calling the counter again outside the compiled function must give a
    # different sum than the one just returned.
    self.assertFalse(
        same(compiled_counter() + compiled_counter(), compiled_out)
    )

    self.assertTrue(same(eager_out, compiled_out))
3458*da0073e9SAndroid Build Coastguard Worker
def test_closure_out_of_scope_cell(self):
    # inner() reads two variables owned by this test method (a Python float
    # and a tensor) and is reached through two levels of plain function
    # calls from the optimized entry point.
    scalar_cell = torch.rand(1).item()
    tensor_cell = torch.rand(3, 3)

    def direct():
        def inner():
            return scalar_cell + 1, tensor_cell + 3

        return inner()

    def indirect():
        return direct()

    cnts = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(cnts)(indirect)
    scalar_out, tensor_out = compiled()

    self.assertAlmostEqual(scalar_cell + 1, scalar_out)
    self.assertTrue(torch.allclose(tensor_cell + 3, tensor_out))
    self.assertEqual(cnts.frame_count, 1)
    self.assertEqual(cnts.op_count, 1)
3479*da0073e9SAndroid Build Coastguard Worker
def test_closure_out_of_scope_cell_with_mutation(self):
    # inner() updates two variables owned by this test method via
    # `nonlocal`. The optimized entry point is called three times and the
    # accumulated values are checked against the saved starting values.
    cell1 = torch.rand(1).item()
    start_scalar = cell1
    cell2 = torch.rand(3, 3)
    start_tensor = cell2.clone()

    def indirect():
        return direct()

    def direct():
        def inner():
            nonlocal cell1, cell2
            x = cell2 + 1
            cell1 += 1
            cell2 += 10
            x = x + cell2
            return cell1, cell2, x

        return inner()

    cnts = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(cnts, nopython=True)(indirect)
    for step in (1, 2, 3):
        scalar_out, tensor_out, _ = compiled()
        self.assertAlmostEqual(start_scalar + step, scalar_out)
        self.assertTrue(torch.allclose(start_tensor + 10 * step, tensor_out))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 3)
        # Reset the counter so each iteration is checked independently.
        cnts.clear()
3509*da0073e9SAndroid Build Coastguard Worker
def test_closure_with_mutation_and_graph_break(self):
    # subfunc writes into the enclosing tensor using `backup`, which is only
    # assigned after a data-dependent `if`. The statement order inside fn
    # is the scenario under test and is kept as-is.
    def fn():
        x = torch.zeros(1)

        def subfunc():
            x[0] = backup

        if x[0] >= -1e5:
            pass

        backup = 1
        subfunc()
        return x

    eager_result = fn()

    cnts = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(cnts)(fn)
    compiled_result = compiled()

    self.assertTrue(same(eager_result, compiled_result))
    self.assertEqual(cnts.frame_count, 2)
3530*da0073e9SAndroid Build Coastguard Worker
def test_closure_out_of_scope_cell_with_cond(self):
    # Closure with an out-of-scope cell variable used inside cond, where
    # the two branches read different closure variables. The module is run
    # eagerly and through optimize("eager") for both predicate values.
    from functorch.experimental.control_flow import cond

    # Only referenced by the disabled line inside deep() below.
    def g(x):
        return x

    class ModuleCondDeep(torch.nn.Module):
        def forward(self, pred, x):
            return self._indirection(pred, x)

        def _indirection(self, pred, x):
            return self.indirection(pred, x)

        def indirection(self, pred, x):
            def true_fn(y):
                return y + 2

            def false_fn(y):
                return y - 2

            def shallow(x):
                return x * 2

            def deep(x):
                # y = g(x)
                y = x
                return cond(
                    x[0][0] > 0,
                    true_fn,
                    false_fn,
                    [y],
                )

            return cond(pred, shallow, deep, [x])

    mod = ModuleCondDeep()
    opt_mod = torch._dynamo.optimize("eager")(mod)
    inp = torch.randn(3, 3)

    # Run every call first (False, then True), and assert afterwards.
    outcomes = []
    for flag in (False, True):
        expected = mod(torch.tensor(flag), inp)
        actual = opt_mod(torch.tensor(flag), inp)
        outcomes.append((expected, actual))

    for expected, actual in outcomes:
        self.assertTrue(torch.allclose(expected, actual))
3577*da0073e9SAndroid Build Coastguard Worker
def test_top_package_import(self):
    # The traced function performs `import torch.fx` in its body and then
    # uses an attribute of that package before computing sin(x).
    def fn(x):
        import torch.fx

        assert not isinstance(x, torch.fx.Proxy)
        return torch.sin(x)

    inp = torch.randn(4, 5)
    expected = fn(inp)

    cnts = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize_assert(cnts)(fn)
    actual = compiled(inp)

    self.assertTrue(same(expected, actual))
3591*da0073e9SAndroid Build Coastguard Worker
def test_typing_typevar(self):
    # Nested helpers carry annotations built from a TypeVar `T`; `T` is not
    # defined in this method and is expected to come from module scope.
    # The whole function should compile as a single frame.
    def fn(x):
        def sumt(y: torch.Tensor) -> torch.Tensor:
            return torch.sum(y)

        def foo(c: typing.Callable[[T], T], y: T) -> T:
            return c(y)

        return foo(sumt, x)

    inp = torch.randn(3)
    expected = fn(inp)

    cnts = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize_assert(cnts)(fn)
    actual = compiled(inp)

    self.assertTrue(same(expected, actual))
    self.assertEqual(cnts.frame_count, 1)
3609*da0073e9SAndroid Build Coastguard Worker
def test_typing_union_and_optional(self):
    # torch.jit.annotate is called with Optional[...] and Union[..., None]
    # dictionary types inside a function compiled with the "eager" backend
    # (nopython=False); eager and compiled results are compared.
    def fn(x):
        a = torch.jit.annotate(
            typing.Dict[str, typing.Optional[torch.Tensor]], {}
        )
        b = torch.jit.annotate(
            typing.Dict[str, typing.Union[torch.Tensor, None]], {}
        )
        return a, b, x + 1

    inp = torch.randn(3)
    expected = fn(inp)
    compiled = torch._dynamo.optimize("eager", nopython=False)(fn)
    actual = compiled(inp)
    self.assertTrue(same(expected, actual))
3623*da0073e9SAndroid Build Coastguard Worker
def test_optimize_on_module(self):
    # optimize() is applied to an nn.Module instance. The wrapped object
    # must give the same forward result and still expose a custom method
    # defined on the original module class.
    class MockModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def custom_member(self):
            # Exists only so the test can call it through the object that
            # Dynamo returns.
            pass

        def forward(self, x):
            return self.relu(x)

    counter = torch._dynamo.testing.CompileCounter()
    mod = MockModule()
    optimized_mod = torch._dynamo.optimize(counter, nopython=True)(mod)

    inp = torch.randn(10)
    expected = mod(inp)
    actual = optimized_mod(inp)

    # Must be reachable on the wrapped module.
    optimized_mod.custom_member()

    self.assertTrue(same(expected, actual))
3649*da0073e9SAndroid Build Coastguard Worker
def test_nested_optimize_decorator(self):
    # fn3 (optimized) calls fn2 (optimized with a different counter), which
    # calls fn1 (decorated with run()). After one call to fn3, only fn3's
    # counter is expected to have recorded a frame.
    mid_cnts = torch._dynamo.testing.CompileCounter()
    outer_cnts = torch._dynamo.testing.CompileCounter()

    @torch._dynamo.run()
    def fn1(x):
        return torch.sin(x) * 10

    @torch._dynamo.optimize(mid_cnts, nopython=True)
    def fn2(x):
        return fn1(x) + 1

    @torch._dynamo.optimize(outer_cnts, nopython=True)
    def fn3(x):
        return torch.relu(fn2(x))

    fn3(torch.randn(4, 5))

    self.assertEqual(mid_cnts.frame_count, 0)
    self.assertEqual(outer_cnts.frame_count, 1)
    self.assertEqual(outer_cnts.op_count, 4)
3670*da0073e9SAndroid Build Coastguard Worker
def test_nested_optimize_run(self):
    # An optimized function is called with 1-D and then 2-D input, with the
    # frame count checked after each call. It is then wrapped in run(), and
    # a 3-D input must leave the frame count unchanged.
    cnts = torch._dynamo.testing.CompileCounter()

    @torch._dynamo.optimize(cnts, nopython=True)
    def fn(x):
        return torch.relu(torch.cos(x) + torch.sin(x))

    for shape, frames_so_far in (((4,), 1), ((4, 4), 2)):
        fn(torch.randn(*shape))
        self.assertEqual(cnts.frame_count, frames_so_far)

    # Test that run works on a decorated fn
    run_only = torch._dynamo.run(fn)
    run_only(torch.randn(4, 4, 4))
    self.assertEqual(cnts.frame_count, 2)
3688*da0073e9SAndroid Build Coastguard Worker
def test_nested_optimize(self):
    # One function is wrapped by two stacked optimize() calls, each with
    # its own CompileCounter; frame_count shows which counter compiled.
    def fn(x):
        return torch.relu(torch.cos(x) + torch.sin(x))

    def stack_optimizers():
        # Fresh counters plus the inner and outer wrapped callables.
        inner_cnts = torch._dynamo.testing.CompileCounter()
        outer_cnts = torch._dynamo.testing.CompileCounter()
        inner = torch._dynamo.optimize(inner_cnts, nopython=True)(fn)
        outer = torch._dynamo.optimize(outer_cnts, nopython=True)(inner)
        return inner_cnts, outer_cnts, inner, outer

    cnts1, cnts2, fn1, fn2 = stack_optimizers()

    # The first optimize in the nesting should be ignored
    fn2(torch.randn(4))
    self.assertEqual(cnts2.frame_count, 1)
    self.assertEqual(cnts1.frame_count, 0)

    # Since the fn code object is already compiled, calling fn1 should
    # directly call the compiled_fn callable.
    torch._dynamo.run()(fn1)(torch.randn(4))
    self.assertEqual(cnts1.frame_count, 0)

    # Test same behavior by reversing the calls
    torch._dynamo.reset()
    cnts1, cnts2, fn1, fn2 = stack_optimizers()
    fn1(torch.randn(4))
    self.assertEqual(cnts1.frame_count, 1)
    torch._dynamo.run()(fn2)(torch.randn(4))
    self.assertEqual(cnts2.frame_count, 0)
3719*da0073e9SAndroid Build Coastguard Worker
def test_torch_size(self):
    # A torch.Size is constructed inside the traced function and unpacked
    # into view(); the eager result on x is compared with the compiled
    # result on a clone of x.
    def fn(x):
        output_size = torch.Size([10, 10])
        x = x.view(*output_size)
        return (x,)

    cnts = torch._dynamo.testing.CompileCounter()

    inp = torch.randn(100, requires_grad=True)
    inp_copy = inp.clone()
    expected = fn(inp)

    compiled = torch._dynamo.optimize(cnts, nopython=True)(fn)
    actual = compiled(inp_copy)

    self.assertTrue(same(expected, actual))
3736*da0073e9SAndroid Build Coastguard Worker
def test_torch_size_numel(self):
    # numel() on a literal torch.Size, evaluated inside a nopython-compiled
    # function, must equal the value computed directly.
    def fn():
        return torch.Size([10, 8]).numel()

    cnts = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(cnts, nopython=True)(fn)

    expected = torch.Size([10, 8]).numel()
    actual = compiled()
    self.assertEqual(actual, expected)
3746*da0073e9SAndroid Build Coastguard Worker
def test_torch_size_numel_dynamic(self):
    # numel() on the size of an input tensor: the compiled result must
    # match the eager one.
    def fn(x):
        return x.size().numel()

    cnts = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(cnts, nopython=True)(fn)

    inp = torch.rand(10, 1, 8, 1)
    expected = fn(inp)
    actual = compiled(inp)
    self.assertEqual(actual, expected)
3757*da0073e9SAndroid Build Coastguard Worker
3758*da0073e9SAndroid Build Coastguard Worker    def test_shape_type(self):
3759*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3760*da0073e9SAndroid Build Coastguard Worker
3761*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3762*da0073e9SAndroid Build Coastguard Worker            return x + (type(x.shape) == torch.Size)
3763*da0073e9SAndroid Build Coastguard Worker
3764*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
3765*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(())
3766*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x), fn(x))
3767*da0073e9SAndroid Build Coastguard Worker
3768*da0073e9SAndroid Build Coastguard Worker    def test_size_dim(self):
3769*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3770*da0073e9SAndroid Build Coastguard Worker
3771*da0073e9SAndroid Build Coastguard Worker        def fn(x, dim):
3772*da0073e9SAndroid Build Coastguard Worker            return x.size(dim=dim)
3773*da0073e9SAndroid Build Coastguard Worker
3774*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
3775*da0073e9SAndroid Build Coastguard Worker        x = torch.empty([4, 9, 8])
3776*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x, 1), 9)
3777*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x, -2), 9)
3778*da0073e9SAndroid Build Coastguard Worker
3779*da0073e9SAndroid Build Coastguard Worker    def test_stride_dim(self):
3780*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3781*da0073e9SAndroid Build Coastguard Worker
3782*da0073e9SAndroid Build Coastguard Worker        def fn(x, dim):
3783*da0073e9SAndroid Build Coastguard Worker            return x.stride(dim=dim)
3784*da0073e9SAndroid Build Coastguard Worker
3785*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
3786*da0073e9SAndroid Build Coastguard Worker        x = torch.empty([4, 9, 8])
3787*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x, 0), 72)
3788*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x, -2), 8)
3789*da0073e9SAndroid Build Coastguard Worker
3790*da0073e9SAndroid Build Coastguard Worker    def test_torch_seed(self):
3791*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.utils import counters
3792*da0073e9SAndroid Build Coastguard Worker
3793*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3794*da0073e9SAndroid Build Coastguard Worker        counters.clear()
3795*da0073e9SAndroid Build Coastguard Worker
3796*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3797*da0073e9SAndroid Build Coastguard Worker            attention_seed = int(torch.seed() % sys.maxsize)
3798*da0073e9SAndroid Build Coastguard Worker            torch.manual_seed(attention_seed)
3799*da0073e9SAndroid Build Coastguard Worker            return (x,)
3800*da0073e9SAndroid Build Coastguard Worker
3801*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, requires_grad=True)
3802*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
3803*da0073e9SAndroid Build Coastguard Worker
3804*da0073e9SAndroid Build Coastguard Worker        # Python code is needed here, since torch.manual_seed graph-breaks.
3805*da0073e9SAndroid Build Coastguard Worker        # Refs: https://github.com/pytorch/pytorch/issues/107187
3806*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn)
3807*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
3808*da0073e9SAndroid Build Coastguard Worker
3809*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
3810*da0073e9SAndroid Build Coastguard Worker        # Only the torch.seed call is turned into an FX graph.
3811*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 1)
3812*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
3813*da0073e9SAndroid Build Coastguard Worker        # Graph breaks at manual_seed.
3814*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(counters["graph_break"]), 1)
3815*da0073e9SAndroid Build Coastguard Worker
3816*da0073e9SAndroid Build Coastguard Worker    def test_is_tensor_like(self):
3817*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3818*da0073e9SAndroid Build Coastguard Worker
3819*da0073e9SAndroid Build Coastguard Worker        def f(x):
3820*da0073e9SAndroid Build Coastguard Worker            if torch.overrides.is_tensor_like(x):
3821*da0073e9SAndroid Build Coastguard Worker                return (x * 2,)
3822*da0073e9SAndroid Build Coastguard Worker            return (torch.ones(10) + x,)
3823*da0073e9SAndroid Build Coastguard Worker
3824*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10)
3825*da0073e9SAndroid Build Coastguard Worker        ref0 = f(x)
3826*da0073e9SAndroid Build Coastguard Worker        ref1 = f(4)
3827*da0073e9SAndroid Build Coastguard Worker        opt_f = torch._dynamo.optimize(cnts, nopython=True)(f)
3828*da0073e9SAndroid Build Coastguard Worker        res0 = opt_f(x)
3829*da0073e9SAndroid Build Coastguard Worker        res1 = opt_f(4)
3830*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref0, res0))
3831*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref1, res1))
3832*da0073e9SAndroid Build Coastguard Worker
3833*da0073e9SAndroid Build Coastguard Worker    def test_is_tensor_like2(self):
3834*da0073e9SAndroid Build Coastguard Worker        class MyTensor:
3835*da0073e9SAndroid Build Coastguard Worker            @classmethod
3836*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(cls, func, types, args=(), kwargs=None):
3837*da0073e9SAndroid Build Coastguard Worker                if kwargs is None:
3838*da0073e9SAndroid Build Coastguard Worker                    kwargs = {}
3839*da0073e9SAndroid Build Coastguard Worker
3840*da0073e9SAndroid Build Coastguard Worker                if func is torch.max:
3841*da0073e9SAndroid Build Coastguard Worker                    return torch.tensor(123)
3842*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
3843*da0073e9SAndroid Build Coastguard Worker
3844*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3845*da0073e9SAndroid Build Coastguard Worker            if torch.overrides.is_tensor_like(x):
3846*da0073e9SAndroid Build Coastguard Worker                return torch.max(x)
3847*da0073e9SAndroid Build Coastguard Worker            else:
3848*da0073e9SAndroid Build Coastguard Worker                return torch.zeros(1)
3849*da0073e9SAndroid Build Coastguard Worker
3850*da0073e9SAndroid Build Coastguard Worker        x = MyTensor()
3851*da0073e9SAndroid Build Coastguard Worker        ref0 = fn(x)
3852*da0073e9SAndroid Build Coastguard Worker        ref1 = fn(4)
3853*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
3854*da0073e9SAndroid Build Coastguard Worker        res0 = opt_fn(x)
3855*da0073e9SAndroid Build Coastguard Worker        res1 = opt_fn(4)
3856*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref0, res0))
3857*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref1, res1))
3858*da0073e9SAndroid Build Coastguard Worker
3859*da0073e9SAndroid Build Coastguard Worker    def test_tensor_data(self):
3860*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
3861*da0073e9SAndroid Build Coastguard Worker            return x[y.data]
3862*da0073e9SAndroid Build Coastguard Worker
3863*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(8)
3864*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(8).to(torch.int)
3865*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, y)
3866*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
3867*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, y)
3868*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
3869*da0073e9SAndroid Build Coastguard Worker
3870*da0073e9SAndroid Build Coastguard Worker    def test_tensor_layout(self):
3871*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3872*da0073e9SAndroid Build Coastguard Worker            return torch.zeros(
3873*da0073e9SAndroid Build Coastguard Worker                [x.size()[0], x.size()[1]],
3874*da0073e9SAndroid Build Coastguard Worker                dtype=x.dtype,
3875*da0073e9SAndroid Build Coastguard Worker                layout=x.layout,
3876*da0073e9SAndroid Build Coastguard Worker                device=x.device,
3877*da0073e9SAndroid Build Coastguard Worker            )
3878*da0073e9SAndroid Build Coastguard Worker
3879*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, 3)
3880*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
3881*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
3882*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
3883*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
3884*da0073e9SAndroid Build Coastguard Worker
3885*da0073e9SAndroid Build Coastguard Worker    def test_version_ci(self):
3886*da0073e9SAndroid Build Coastguard Worker        # temporary test to check that the ci torch version is set correctly
3887*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hasattr(torch, "_subclasses"))
3888*da0073e9SAndroid Build Coastguard Worker
3889*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "requires cuda")
3890*da0073e9SAndroid Build Coastguard Worker    def test_rand(self):
3891*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3892*da0073e9SAndroid Build Coastguard Worker        device = "cuda"
3893*da0073e9SAndroid Build Coastguard Worker
3894*da0073e9SAndroid Build Coastguard Worker        def fn():
3895*da0073e9SAndroid Build Coastguard Worker            return torch.randn(10, device=device)
3896*da0073e9SAndroid Build Coastguard Worker
3897*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(10)
3898*da0073e9SAndroid Build Coastguard Worker        ref_run1 = fn()
3899*da0073e9SAndroid Build Coastguard Worker
3900*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(10)
3901*da0073e9SAndroid Build Coastguard Worker        ref_run2 = fn()
3902*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref_run1, ref_run2))
3903*da0073e9SAndroid Build Coastguard Worker
3904*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(10)
3905*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
3906*da0073e9SAndroid Build Coastguard Worker        res = opt_fn()
3907*da0073e9SAndroid Build Coastguard Worker
3908*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(res, ref_run1))
3909*da0073e9SAndroid Build Coastguard Worker
3910*da0073e9SAndroid Build Coastguard Worker    def test_slice_input(self):
3911*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3912*da0073e9SAndroid Build Coastguard Worker
3913*da0073e9SAndroid Build Coastguard Worker        def getitem(a, idx):
3914*da0073e9SAndroid Build Coastguard Worker            if isinstance(idx, slice):
3915*da0073e9SAndroid Build Coastguard Worker                return (
3916*da0073e9SAndroid Build Coastguard Worker                    torch.zeros(1),
3917*da0073e9SAndroid Build Coastguard Worker                    a[idx]
3918*da0073e9SAndroid Build Coastguard Worker                    + [
3919*da0073e9SAndroid Build Coastguard Worker                        100,
3920*da0073e9SAndroid Build Coastguard Worker                    ],
3921*da0073e9SAndroid Build Coastguard Worker                )
3922*da0073e9SAndroid Build Coastguard Worker            else:
3923*da0073e9SAndroid Build Coastguard Worker                return (torch.zeros(1), a[idx])
3924*da0073e9SAndroid Build Coastguard Worker
3925*da0073e9SAndroid Build Coastguard Worker        layers = list(range(10))
3926*da0073e9SAndroid Build Coastguard Worker        ref0 = getitem(layers, slice(0, 2, 1))
3927*da0073e9SAndroid Build Coastguard Worker        ref1 = getitem(layers, 2)
3928*da0073e9SAndroid Build Coastguard Worker        ref2 = getitem(layers, slice(3, 8, 2))
3929*da0073e9SAndroid Build Coastguard Worker        opt_getitem = torch._dynamo.optimize(cnts, nopython=True)(getitem)
3930*da0073e9SAndroid Build Coastguard Worker        res0 = opt_getitem(layers, slice(0, 2, 1))
3931*da0073e9SAndroid Build Coastguard Worker        res1 = opt_getitem(layers, 2)
3932*da0073e9SAndroid Build Coastguard Worker        res2 = opt_getitem(layers, slice(3, 8, 2))
3933*da0073e9SAndroid Build Coastguard Worker
3934*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(ref0 == res0)
3935*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(ref1 == res1)
3936*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(ref2 == res2)
3937*da0073e9SAndroid Build Coastguard Worker
3938*da0073e9SAndroid Build Coastguard Worker    def test_grad(self):
3939*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3940*da0073e9SAndroid Build Coastguard Worker
3941*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
3942*da0073e9SAndroid Build Coastguard Worker            out = a * b
3943*da0073e9SAndroid Build Coastguard Worker            out.sum().backward()
3944*da0073e9SAndroid Build Coastguard Worker            real_out = torch.sigmoid(a.grad + b)
3945*da0073e9SAndroid Build Coastguard Worker            return real_out
3946*da0073e9SAndroid Build Coastguard Worker
3947*da0073e9SAndroid Build Coastguard Worker        inps = [torch.randn(4, requires_grad=True) for _ in range(2)]
3948*da0073e9SAndroid Build Coastguard Worker        for inp in inps:
3949*da0073e9SAndroid Build Coastguard Worker            inp.grad = None
3950*da0073e9SAndroid Build Coastguard Worker        ref = fn(*inps)
3951*da0073e9SAndroid Build Coastguard Worker
3952*da0073e9SAndroid Build Coastguard Worker        for inp in inps:
3953*da0073e9SAndroid Build Coastguard Worker            inp.grad = None
3954*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
3955*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(*inps)
3956*da0073e9SAndroid Build Coastguard Worker
3957*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
3958*da0073e9SAndroid Build Coastguard Worker
3959*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(guard_nn_modules=True)
3960*da0073e9SAndroid Build Coastguard Worker    def test_source_non_input_grad_access(self):
3961*da0073e9SAndroid Build Coastguard Worker        # This test creates a model, and accesses the grads
3962*da0073e9SAndroid Build Coastguard Worker        # from its parameter. This means that within dynamo,
3963*da0073e9SAndroid Build Coastguard Worker        # the tensor we are reading the grad from HAS a source,
3964*da0073e9SAndroid Build Coastguard Worker        # but is not known to graphargs.
3965*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3966*da0073e9SAndroid Build Coastguard Worker
3967*da0073e9SAndroid Build Coastguard Worker        class TrivialModel(torch.nn.Module):
3968*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
3969*da0073e9SAndroid Build Coastguard Worker                super(TrivialModel, self).__init__()
3970*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(2, 1)
3971*da0073e9SAndroid Build Coastguard Worker
3972*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3973*da0073e9SAndroid Build Coastguard Worker                return self.linear(x)
3974*da0073e9SAndroid Build Coastguard Worker
3975*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
3976*da0073e9SAndroid Build Coastguard Worker            outs = []
3977*da0073e9SAndroid Build Coastguard Worker            for param in model.parameters():
3978*da0073e9SAndroid Build Coastguard Worker                outs.append(torch.ones(param.grad.size()))
3979*da0073e9SAndroid Build Coastguard Worker            return outs, param.grad + 1
3980*da0073e9SAndroid Build Coastguard Worker
3981*da0073e9SAndroid Build Coastguard Worker        model = TrivialModel()
3982*da0073e9SAndroid Build Coastguard Worker        # Eager
3983*da0073e9SAndroid Build Coastguard Worker        a = torch.ones([2, 2], requires_grad=True)
3984*da0073e9SAndroid Build Coastguard Worker        b = torch.ones([2, 2])
3985*da0073e9SAndroid Build Coastguard Worker        out = model(a)
3986*da0073e9SAndroid Build Coastguard Worker        out_sum = out.sum()
3987*da0073e9SAndroid Build Coastguard Worker        out_sum.backward()
3988*da0073e9SAndroid Build Coastguard Worker        ref = fn(a, b)
3989*da0073e9SAndroid Build Coastguard Worker
3990*da0073e9SAndroid Build Coastguard Worker        # Compiled
3991*da0073e9SAndroid Build Coastguard Worker        model = TrivialModel()
3992*da0073e9SAndroid Build Coastguard Worker        a = torch.ones([2, 2], requires_grad=True)
3993*da0073e9SAndroid Build Coastguard Worker        b = torch.ones([2, 2])
3994*da0073e9SAndroid Build Coastguard Worker        out = model(a)
3995*da0073e9SAndroid Build Coastguard Worker        out_sum = out.sum()
3996*da0073e9SAndroid Build Coastguard Worker        out_sum.backward()
3997*da0073e9SAndroid Build Coastguard Worker
3998*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
3999*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(a, b)
4000*da0073e9SAndroid Build Coastguard Worker
4001*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
4002*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
4003*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 3)
4004*da0073e9SAndroid Build Coastguard Worker
4005*da0073e9SAndroid Build Coastguard Worker    def test_intermediary_tensor_grad_access(self):
4006*da0073e9SAndroid Build Coastguard Worker        # This test creates a model, and accesses the grads
4007*da0073e9SAndroid Build Coastguard Worker        # from its parameters and an entirely intermediary tensor.
4008*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4009*da0073e9SAndroid Build Coastguard Worker
4010*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
4011*da0073e9SAndroid Build Coastguard Worker            intermediary = torch.ones(2, 2)
4012*da0073e9SAndroid Build Coastguard Worker            c = a + intermediary
4013*da0073e9SAndroid Build Coastguard Worker            outs = []
4014*da0073e9SAndroid Build Coastguard Worker            outs.append(intermediary.grad)
4015*da0073e9SAndroid Build Coastguard Worker            return outs
4016*da0073e9SAndroid Build Coastguard Worker
4017*da0073e9SAndroid Build Coastguard Worker        # Eager
4018*da0073e9SAndroid Build Coastguard Worker        a = torch.ones([2, 2], requires_grad=True)
4019*da0073e9SAndroid Build Coastguard Worker        b = torch.ones([2, 2])
4020*da0073e9SAndroid Build Coastguard Worker        ref = fn(a, b)
4021*da0073e9SAndroid Build Coastguard Worker
4022*da0073e9SAndroid Build Coastguard Worker        # Compiled
4023*da0073e9SAndroid Build Coastguard Worker        a = torch.ones([2, 2], requires_grad=True)
4024*da0073e9SAndroid Build Coastguard Worker        b = torch.ones([2, 2])
4025*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
4026*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(a, b)
4027*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
4028*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
4029*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 2)
4030*da0073e9SAndroid Build Coastguard Worker
4031*da0073e9SAndroid Build Coastguard Worker    def test_clone_sparse_input(self):
4032*da0073e9SAndroid Build Coastguard Worker        for layout in [
4033*da0073e9SAndroid Build Coastguard Worker            torch.sparse_coo,
4034*da0073e9SAndroid Build Coastguard Worker            torch.sparse_csr,
4035*da0073e9SAndroid Build Coastguard Worker            torch.sparse_csc,
4036*da0073e9SAndroid Build Coastguard Worker            torch.sparse_bsr,
4037*da0073e9SAndroid Build Coastguard Worker            torch.sparse_bsc,
4038*da0073e9SAndroid Build Coastguard Worker        ]:
4039*da0073e9SAndroid Build Coastguard Worker            for sparse_input in self.generate_simple_inputs(
4040*da0073e9SAndroid Build Coastguard Worker                layout,
4041*da0073e9SAndroid Build Coastguard Worker                device="cpu",
4042*da0073e9SAndroid Build Coastguard Worker                dtype=torch.float64,
4043*da0073e9SAndroid Build Coastguard Worker                index_dtype=torch.int64,
4044*da0073e9SAndroid Build Coastguard Worker            ):
4045*da0073e9SAndroid Build Coastguard Worker                # Invoke the dynamo clone input method directly.
4046*da0073e9SAndroid Build Coastguard Worker                sparse_copy = torch._dynamo.utils.clone_input(sparse_input)
4047*da0073e9SAndroid Build Coastguard Worker                # Make sure sparse clone is successful.
4048*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(sparse_input, sparse_copy)
4049*da0073e9SAndroid Build Coastguard Worker
4050*da0073e9SAndroid Build Coastguard Worker    def test_tensor_is_contiguous(self):
4051*da0073e9SAndroid Build Coastguard Worker        def fn(x):
4052*da0073e9SAndroid Build Coastguard Worker            input = torch.randn((1, 16, 1, 1))
4053*da0073e9SAndroid Build Coastguard Worker            weight = torch.randn((8, 16, 3, 3))
4054*da0073e9SAndroid Build Coastguard Worker            weight = weight.to(memory_format=x)
4055*da0073e9SAndroid Build Coastguard Worker            output = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1)
4056*da0073e9SAndroid Build Coastguard Worker            return output.is_contiguous(memory_format=x)
4057*da0073e9SAndroid Build Coastguard Worker
4058*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
4059*da0073e9SAndroid Build Coastguard Worker        for x in [torch.contiguous_format, torch.channels_last]:
4060*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fn(x), opt_fn(x))
4061*da0073e9SAndroid Build Coastguard Worker
4062*da0073e9SAndroid Build Coastguard Worker    def test_python_slice(self):
4063*da0073e9SAndroid Build Coastguard Worker        def f1(input):
4064*da0073e9SAndroid Build Coastguard Worker            y = 0
4065*da0073e9SAndroid Build Coastguard Worker            for i, x in enumerate(input[2:], 1):
4066*da0073e9SAndroid Build Coastguard Worker                y = y + x
4067*da0073e9SAndroid Build Coastguard Worker            return y
4068*da0073e9SAndroid Build Coastguard Worker
4069*da0073e9SAndroid Build Coastguard Worker        def f2(input):
4070*da0073e9SAndroid Build Coastguard Worker            y = 0
4071*da0073e9SAndroid Build Coastguard Worker            for i, x in enumerate(input.shape[2:], 1):
4072*da0073e9SAndroid Build Coastguard Worker                y = y + x
4073*da0073e9SAndroid Build Coastguard Worker            return y
4074*da0073e9SAndroid Build Coastguard Worker
4075*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4076*da0073e9SAndroid Build Coastguard Worker        opt_f1 = torch._dynamo.optimize(cnts)(f1)
4077*da0073e9SAndroid Build Coastguard Worker        opt_f2 = torch._dynamo.optimize(cnts)(f2)
4078*da0073e9SAndroid Build Coastguard Worker        res1 = opt_f1([1, 2, 3, 5])
4079*da0073e9SAndroid Build Coastguard Worker        res2 = opt_f2(torch.rand([2, 3, 4, 5]))
4080*da0073e9SAndroid Build Coastguard Worker
4081*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, 8)
4082*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res2, 9)
4083*da0073e9SAndroid Build Coastguard Worker
4084*da0073e9SAndroid Build Coastguard Worker    def test_enum_as_dict_key(self):
4085*da0073e9SAndroid Build Coastguard Worker        class MyEnum(enum.Enum):
4086*da0073e9SAndroid Build Coastguard Worker            FOO = 10
4087*da0073e9SAndroid Build Coastguard Worker            BAR = 20
4088*da0073e9SAndroid Build Coastguard Worker
4089*da0073e9SAndroid Build Coastguard Worker        def fn(x):
4090*da0073e9SAndroid Build Coastguard Worker            y = x + 2
4091*da0073e9SAndroid Build Coastguard Worker            z = {
4092*da0073e9SAndroid Build Coastguard Worker                MyEnum.FOO: torch.tensor(1),
4093*da0073e9SAndroid Build Coastguard Worker                MyEnum.BAR: 10,
4094*da0073e9SAndroid Build Coastguard Worker                "MyEnum.BAR": torch.tensor(8),
4095*da0073e9SAndroid Build Coastguard Worker                5: torch.rand(3),
4096*da0073e9SAndroid Build Coastguard Worker            }
4097*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
4098*da0073e9SAndroid Build Coastguard Worker            a = z[MyEnum.FOO] + z["MyEnum.BAR"]
4099*da0073e9SAndroid Build Coastguard Worker            b = y * 2
4100*da0073e9SAndroid Build Coastguard Worker            return a, b
4101*da0073e9SAndroid Build Coastguard Worker
4102*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4103*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
4104*da0073e9SAndroid Build Coastguard Worker        for _ in range(10):
4105*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(3)
4106*da0073e9SAndroid Build Coastguard Worker            ref = fn(x)
4107*da0073e9SAndroid Build Coastguard Worker            res = opt_fn(x)
4108*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(ref, res))
4109*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
4110*da0073e9SAndroid Build Coastguard Worker
4111*da0073e9SAndroid Build Coastguard Worker    def test_enum_as_dict_key_with_overloaded_str(self):
4112*da0073e9SAndroid Build Coastguard Worker        class MyEnum(enum.Enum):
4113*da0073e9SAndroid Build Coastguard Worker            FOO = 10
4114*da0073e9SAndroid Build Coastguard Worker            BAR = 20
4115*da0073e9SAndroid Build Coastguard Worker
4116*da0073e9SAndroid Build Coastguard Worker            def __str__(self):
4117*da0073e9SAndroid Build Coastguard Worker                return self.value
4118*da0073e9SAndroid Build Coastguard Worker
4119*da0073e9SAndroid Build Coastguard Worker        def fn(x):
4120*da0073e9SAndroid Build Coastguard Worker            y = x + 2
4121*da0073e9SAndroid Build Coastguard Worker            z = {
4122*da0073e9SAndroid Build Coastguard Worker                MyEnum.FOO: torch.tensor(1),
4123*da0073e9SAndroid Build Coastguard Worker                MyEnum.BAR: 10,
4124*da0073e9SAndroid Build Coastguard Worker                "MyEnum.BAR": torch.tensor(8),
4125*da0073e9SAndroid Build Coastguard Worker                5: torch.rand(3),
4126*da0073e9SAndroid Build Coastguard Worker            }
4127*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
4128*da0073e9SAndroid Build Coastguard Worker            a = z[MyEnum.FOO] + z["MyEnum.BAR"]
4129*da0073e9SAndroid Build Coastguard Worker            b = y * 2
4130*da0073e9SAndroid Build Coastguard Worker            return a, b
4131*da0073e9SAndroid Build Coastguard Worker
4132*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4133*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
4134*da0073e9SAndroid Build Coastguard Worker        for _ in range(10):
4135*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(3)
4136*da0073e9SAndroid Build Coastguard Worker            ref = fn(x)
4137*da0073e9SAndroid Build Coastguard Worker            res = opt_fn(x)
4138*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(ref, res))
4139*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
4140*da0073e9SAndroid Build Coastguard Worker
4141*da0073e9SAndroid Build Coastguard Worker    def test_const_dict_variable_python_type(self):
4142*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.variables import ConstantVariable, ConstDictVariable
4143*da0073e9SAndroid Build Coastguard Worker
4144*da0073e9SAndroid Build Coastguard Worker        make_key = ConstantVariable.create
4145*da0073e9SAndroid Build Coastguard Worker
4146*da0073e9SAndroid Build Coastguard Worker        d1 = {
4147*da0073e9SAndroid Build Coastguard Worker            make_key("a"): ConstantVariable.create(10),
4148*da0073e9SAndroid Build Coastguard Worker            make_key("b"): ConstantVariable.create(20),
4149*da0073e9SAndroid Build Coastguard Worker        }
4150*da0073e9SAndroid Build Coastguard Worker        d2 = collections.OrderedDict(
4151*da0073e9SAndroid Build Coastguard Worker            [
4152*da0073e9SAndroid Build Coastguard Worker                (make_key("x"), ConstantVariable.create(12)),
4153*da0073e9SAndroid Build Coastguard Worker                (make_key("y"), ConstantVariable.create(22)),
4154*da0073e9SAndroid Build Coastguard Worker            ]
4155*da0073e9SAndroid Build Coastguard Worker        )
4156*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ConstDictVariable(d1).python_type(), dict)
4157*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4158*da0073e9SAndroid Build Coastguard Worker            ConstDictVariable(d2, collections.OrderedDict).python_type(),
4159*da0073e9SAndroid Build Coastguard Worker            collections.OrderedDict,
4160*da0073e9SAndroid Build Coastguard Worker        )
4161*da0073e9SAndroid Build Coastguard Worker
4162*da0073e9SAndroid Build Coastguard Worker    def test_builtin_subclasses_as_method_on_class_type(self):
4163*da0073e9SAndroid Build Coastguard Worker        class Foo:
4164*da0073e9SAndroid Build Coastguard Worker            def __init__(self, name):
4165*da0073e9SAndroid Build Coastguard Worker                self.ame_ = name
4166*da0073e9SAndroid Build Coastguard Worker
4167*da0073e9SAndroid Build Coastguard Worker            def get_name(self):
4168*da0073e9SAndroid Build Coastguard Worker                return "Foo " + self.name_
4169*da0073e9SAndroid Build Coastguard Worker
4170*da0073e9SAndroid Build Coastguard Worker        class Bar(Foo):
4171*da0073e9SAndroid Build Coastguard Worker            def __init__(self, name):
4172*da0073e9SAndroid Build Coastguard Worker                self.name_ = name
4173*da0073e9SAndroid Build Coastguard Worker
4174*da0073e9SAndroid Build Coastguard Worker            def get_name(self):
4175*da0073e9SAndroid Build Coastguard Worker                return "Bar " + self.name_
4176*da0073e9SAndroid Build Coastguard Worker
4177*da0073e9SAndroid Build Coastguard Worker        class Baz(Foo):
4178*da0073e9SAndroid Build Coastguard Worker            def __init__(self, name):  # noqa: B903
4179*da0073e9SAndroid Build Coastguard Worker                self.name_ = name
4180*da0073e9SAndroid Build Coastguard Worker
4181*da0073e9SAndroid Build Coastguard Worker            def get_name(self):
4182*da0073e9SAndroid Build Coastguard Worker                return "Baz " + self.name_
4183*da0073e9SAndroid Build Coastguard Worker
4184*da0073e9SAndroid Build Coastguard Worker        subs_of_foo_reg = Foo.__subclasses__()
4185*da0073e9SAndroid Build Coastguard Worker
4186*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
4187*da0073e9SAndroid Build Coastguard Worker
4188*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize_assert(counter)
4189*da0073e9SAndroid Build Coastguard Worker        def fn():
4190*da0073e9SAndroid Build Coastguard Worker            return Foo.__subclasses__()
4191*da0073e9SAndroid Build Coastguard Worker
4192*da0073e9SAndroid Build Coastguard Worker        subs_of_foo_optim = fn()
4193*da0073e9SAndroid Build Coastguard Worker
4194*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(subs_of_foo_reg), 2)
4195*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(subs_of_foo_reg, subs_of_foo_optim)
4196*da0073e9SAndroid Build Coastguard Worker
4197*da0073e9SAndroid Build Coastguard Worker    def test_builtin_subclasses_as_method_on_var(self):
4198*da0073e9SAndroid Build Coastguard Worker        class Foo:
4199*da0073e9SAndroid Build Coastguard Worker            def __init__(self, name):
4200*da0073e9SAndroid Build Coastguard Worker                self.name_ = name
4201*da0073e9SAndroid Build Coastguard Worker
4202*da0073e9SAndroid Build Coastguard Worker            def get_name(self):
4203*da0073e9SAndroid Build Coastguard Worker                return "Foo " + self.name_
4204*da0073e9SAndroid Build Coastguard Worker
4205*da0073e9SAndroid Build Coastguard Worker        class Bar(Foo):
4206*da0073e9SAndroid Build Coastguard Worker            def __init__(self, name):
4207*da0073e9SAndroid Build Coastguard Worker                self.name_ = name
4208*da0073e9SAndroid Build Coastguard Worker
4209*da0073e9SAndroid Build Coastguard Worker            def get_name(self):
4210*da0073e9SAndroid Build Coastguard Worker                return "Bar " + self.name_
4211*da0073e9SAndroid Build Coastguard Worker
4212*da0073e9SAndroid Build Coastguard Worker        class Baz(Bar):
4213*da0073e9SAndroid Build Coastguard Worker            def __init__(self, name):
4214*da0073e9SAndroid Build Coastguard Worker                self.name_ = name
4215*da0073e9SAndroid Build Coastguard Worker
4216*da0073e9SAndroid Build Coastguard Worker            def get_name(self):
4217*da0073e9SAndroid Build Coastguard Worker                return "Baz " + self.name_
4218*da0073e9SAndroid Build Coastguard Worker
4219*da0073e9SAndroid Build Coastguard Worker        subs_of_foo_reg = Foo.__subclasses__()
4220*da0073e9SAndroid Build Coastguard Worker        sub_of_foo_subclass_var_reg = subs_of_foo_reg[0].__subclasses__()
4221*da0073e9SAndroid Build Coastguard Worker
4222*da0073e9SAndroid Build Coastguard Worker        sub_of_foo_subclass_var_optim = list()
4223*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
4224*da0073e9SAndroid Build Coastguard Worker
4225*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize_assert(counter)
4226*da0073e9SAndroid Build Coastguard Worker        def fn():
4227*da0073e9SAndroid Build Coastguard Worker            return Foo.__subclasses__()
4228*da0073e9SAndroid Build Coastguard Worker
4229*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize_assert(counter)
4230*da0073e9SAndroid Build Coastguard Worker        def fn_single(subs_of_foo_optim):
4231*da0073e9SAndroid Build Coastguard Worker            return subs_of_foo_optim[0].__subclasses__()
4232*da0073e9SAndroid Build Coastguard Worker
4233*da0073e9SAndroid Build Coastguard Worker        subs_of_foo_optim = fn()
4234*da0073e9SAndroid Build Coastguard Worker        sub_of_foo_subclass_var_optim = fn_single(subs_of_foo_optim)
4235*da0073e9SAndroid Build Coastguard Worker
4236*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(sub_of_foo_subclass_var_optim), 1)
4237*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(sub_of_foo_subclass_var_optim, sub_of_foo_subclass_var_reg)
4238*da0073e9SAndroid Build Coastguard Worker
4239*da0073e9SAndroid Build Coastguard Worker    def test_builtin_str_on_user_defined_function(self):
4240*da0073e9SAndroid Build Coastguard Worker        def another_fn():
4241*da0073e9SAndroid Build Coastguard Worker            pass
4242*da0073e9SAndroid Build Coastguard Worker
4243*da0073e9SAndroid Build Coastguard Worker        def fn():
4244*da0073e9SAndroid Build Coastguard Worker            return "another_fn" in str(another_fn)
4245*da0073e9SAndroid Build Coastguard Worker
4246*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(nopython=True)(fn)
4247*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(opt_fn())
4248*da0073e9SAndroid Build Coastguard Worker
4249*da0073e9SAndroid Build Coastguard Worker    def test_enum_no_graphbreaks(self):
4250*da0073e9SAndroid Build Coastguard Worker        class Foo(enum.Enum):
4251*da0073e9SAndroid Build Coastguard Worker            FOO = 0
4252*da0073e9SAndroid Build Coastguard Worker            BAR = 1
4253*da0073e9SAndroid Build Coastguard Worker
4254*da0073e9SAndroid Build Coastguard Worker        def fn(x, foo):
4255*da0073e9SAndroid Build Coastguard Worker            if foo is Foo.FOO:
4256*da0073e9SAndroid Build Coastguard Worker                x = torch.add(x, 1.0)
4257*da0073e9SAndroid Build Coastguard Worker            x = torch.mul(x, 1.0)
4258*da0073e9SAndroid Build Coastguard Worker            return x
4259*da0073e9SAndroid Build Coastguard Worker
4260*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1)
4261*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4262*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
4263*da0073e9SAndroid Build Coastguard Worker        opt_fn(x, Foo.FOO)
4264*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 2)
4265*da0073e9SAndroid Build Coastguard Worker
4266*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
4267*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4268*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
4269*da0073e9SAndroid Build Coastguard Worker        opt_fn(x, Foo.BAR)
4270*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 1)
4271*da0073e9SAndroid Build Coastguard Worker
4272*da0073e9SAndroid Build Coastguard Worker    def test_repeat_interleave_graphbreaks(self):
4273*da0073e9SAndroid Build Coastguard Worker        def fn_no_breaks(x):
4274*da0073e9SAndroid Build Coastguard Worker            # no breaks on self_int
4275*da0073e9SAndroid Build Coastguard Worker            x += 1
4276*da0073e9SAndroid Build Coastguard Worker            x = torch.repeat_interleave(x, 2, 3)
4277*da0073e9SAndroid Build Coastguard Worker            x += 1
4278*da0073e9SAndroid Build Coastguard Worker            return x
4279*da0073e9SAndroid Build Coastguard Worker
4280*da0073e9SAndroid Build Coastguard Worker        def fn_has_breaks(x):
4281*da0073e9SAndroid Build Coastguard Worker            # breaks on self_Tensor
4282*da0073e9SAndroid Build Coastguard Worker            x += 1
4283*da0073e9SAndroid Build Coastguard Worker            x = torch.repeat_interleave(x, torch.tensor(2), 3)
4284*da0073e9SAndroid Build Coastguard Worker            x += 1
4285*da0073e9SAndroid Build Coastguard Worker            return x
4286*da0073e9SAndroid Build Coastguard Worker
4287*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([4, 16, 1, 64])
4288*da0073e9SAndroid Build Coastguard Worker
4289*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4290*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn_no_breaks)
4291*da0073e9SAndroid Build Coastguard Worker        opt_fn(x)
4292*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
4293*da0073e9SAndroid Build Coastguard Worker
4294*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
4295*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4296*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn_has_breaks)
4297*da0073e9SAndroid Build Coastguard Worker        opt_fn(x)
4298*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
4299*da0073e9SAndroid Build Coastguard Worker
4300*da0073e9SAndroid Build Coastguard Worker    def test_id_guarded_object(self):
4301*da0073e9SAndroid Build Coastguard Worker        class UDO:
4302*da0073e9SAndroid Build Coastguard Worker            @torch.compile(backend="eager")
4303*da0073e9SAndroid Build Coastguard Worker            def call(self, x, ref_id):
4304*da0073e9SAndroid Build Coastguard Worker                self_id = id(self)
4305*da0073e9SAndroid Build Coastguard Worker                if self_id == ref_id:
4306*da0073e9SAndroid Build Coastguard Worker                    x = torch.mul(x, 1.0)
4307*da0073e9SAndroid Build Coastguard Worker                else:
4308*da0073e9SAndroid Build Coastguard Worker                    x = torch.mul(x, 0)
4309*da0073e9SAndroid Build Coastguard Worker                return x
4310*da0073e9SAndroid Build Coastguard Worker
4311*da0073e9SAndroid Build Coastguard Worker        # Make sure we do recompile when id(self) is executed on
4312*da0073e9SAndroid Build Coastguard Worker        # different self objects.
4313*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2)
4314*da0073e9SAndroid Build Coastguard Worker        obj1 = UDO()
4315*da0073e9SAndroid Build Coastguard Worker        obj1_id = id(obj1)
4316*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(obj1.call(x, obj1_id), torch.ones(2))
4317*da0073e9SAndroid Build Coastguard Worker
4318*da0073e9SAndroid Build Coastguard Worker        obj2 = UDO()
4319*da0073e9SAndroid Build Coastguard Worker        # if we do not install ID_MATCH: ___check_obj_id(L['self'], xxx) this fails.
4320*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(obj2.call(x, obj1_id), torch.zeros(2))
4321*da0073e9SAndroid Build Coastguard Worker
4322*da0073e9SAndroid Build Coastguard Worker    def test_id_guarded_module(self):
4323*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
4324*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, ref_id):
4325*da0073e9SAndroid Build Coastguard Worker                self_id = id(self)
4326*da0073e9SAndroid Build Coastguard Worker                if self_id == ref_id:
4327*da0073e9SAndroid Build Coastguard Worker                    x = torch.mul(x, 1.0)
4328*da0073e9SAndroid Build Coastguard Worker                else:
4329*da0073e9SAndroid Build Coastguard Worker                    x = torch.mul(x, 0)
4330*da0073e9SAndroid Build Coastguard Worker                return x
4331*da0073e9SAndroid Build Coastguard Worker
4332*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4333*da0073e9SAndroid Build Coastguard Worker
4334*da0073e9SAndroid Build Coastguard Worker        # Make sure we do recompile when id(self) is executed on
4335*da0073e9SAndroid Build Coastguard Worker        # different self objects.
4336*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2)
4337*da0073e9SAndroid Build Coastguard Worker        m1 = M()
4338*da0073e9SAndroid Build Coastguard Worker        m1_id = id(m1)
4339*da0073e9SAndroid Build Coastguard Worker        opt_m1 = torch._dynamo.optimize(cnts, nopython=True)(m1)
4340*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_m1(x, m1_id), torch.ones(2))
4341*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_m1(x, m1_id), torch.ones(2))
4342*da0073e9SAndroid Build Coastguard Worker
4343*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
4344*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 1)
4345*da0073e9SAndroid Build Coastguard Worker
4346*da0073e9SAndroid Build Coastguard Worker        m2 = M()
4347*da0073e9SAndroid Build Coastguard Worker        opt_m2 = torch._dynamo.optimize(cnts, nopython=True)(m2)
4348*da0073e9SAndroid Build Coastguard Worker        # if we do not install ID_MATCH: ___check_obj_id(L['self'], xxx) this fails.
4349*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_m2(x, m1_id), torch.zeros(2))
4350*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
4351*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 2)
4352*da0073e9SAndroid Build Coastguard Worker
4353*da0073e9SAndroid Build Coastguard Worker    def test_id_of_nn_module(self):
4354*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
4355*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, ref_id):
4356*da0073e9SAndroid Build Coastguard Worker                self_id = id(self)
4357*da0073e9SAndroid Build Coastguard Worker                if self_id == ref_id:
4358*da0073e9SAndroid Build Coastguard Worker                    x = torch.mul(x, 1.0)
4359*da0073e9SAndroid Build Coastguard Worker                x = torch.add(x, 1.0)
4360*da0073e9SAndroid Build Coastguard Worker                return x
4361*da0073e9SAndroid Build Coastguard Worker
4362*da0073e9SAndroid Build Coastguard Worker        m = M().eval()
4363*da0073e9SAndroid Build Coastguard Worker        data = torch.randn(1)
4364*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4365*da0073e9SAndroid Build Coastguard Worker        correct_ref_id = id(m)
4366*da0073e9SAndroid Build Coastguard Worker        opt_m = torch._dynamo.optimize(cnts, nopython=True)(m)
4367*da0073e9SAndroid Build Coastguard Worker        opt_m(data, correct_ref_id)
4368*da0073e9SAndroid Build Coastguard Worker        # Extra op is the recorded equality test (although once
4369*da0073e9SAndroid Build Coastguard Worker        # the trace is flattened this is dead!)
4370*da0073e9SAndroid Build Coastguard Worker        if torch._dynamo.config.assume_static_by_default:
4371*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(cnts.op_count, """2""")
4372*da0073e9SAndroid Build Coastguard Worker        else:
4373*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(cnts.op_count, """2""")
4374*da0073e9SAndroid Build Coastguard Worker
4375*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
4376*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4377*da0073e9SAndroid Build Coastguard Worker        incorrect_ref_id = id(m) + 1
4378*da0073e9SAndroid Build Coastguard Worker        opt_m = torch._dynamo.optimize(cnts, nopython=True)(m)
4379*da0073e9SAndroid Build Coastguard Worker        opt_m(data, incorrect_ref_id)
4380*da0073e9SAndroid Build Coastguard Worker        if torch._dynamo.config.assume_static_by_default:
4381*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(cnts.op_count, """1""")
4382*da0073e9SAndroid Build Coastguard Worker        else:
4383*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(cnts.op_count, """1""")
4384*da0073e9SAndroid Build Coastguard Worker
4385*da0073e9SAndroid Build Coastguard Worker    def test_inline_func_jump_on_tensor_condition(self):
4386*da0073e9SAndroid Build Coastguard Worker        def f1(input):
4387*da0073e9SAndroid Build Coastguard Worker            if input == 0:
4388*da0073e9SAndroid Build Coastguard Worker                return input + 1
4389*da0073e9SAndroid Build Coastguard Worker            else:
4390*da0073e9SAndroid Build Coastguard Worker                return input + 2
4391*da0073e9SAndroid Build Coastguard Worker
4392*da0073e9SAndroid Build Coastguard Worker        def f2(input):
4393*da0073e9SAndroid Build Coastguard Worker            return f1(input)
4394*da0073e9SAndroid Build Coastguard Worker
4395*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4396*da0073e9SAndroid Build Coastguard Worker        opt_f2 = torch._dynamo.optimize(cnts)(f2)
4397*da0073e9SAndroid Build Coastguard Worker        res1 = opt_f2(torch.tensor([1.0]))
4398*da0073e9SAndroid Build Coastguard Worker        res2 = opt_f2(torch.tensor([0.0]))
4399*da0073e9SAndroid Build Coastguard Worker
4400*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, 3)
4401*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res2, 1)
4402*da0073e9SAndroid Build Coastguard Worker
4403*da0073e9SAndroid Build Coastguard Worker    def test_frozenset_torch_func_contains(self):
4404*da0073e9SAndroid Build Coastguard Worker        funcs = frozenset([torch.add])
4405*da0073e9SAndroid Build Coastguard Worker
4406*da0073e9SAndroid Build Coastguard Worker        def fn(x, func):
4407*da0073e9SAndroid Build Coastguard Worker            if func in funcs:
4408*da0073e9SAndroid Build Coastguard Worker                x = torch.add(x, 1.0)
4409*da0073e9SAndroid Build Coastguard Worker            x = torch.mul(x, 1.0)
4410*da0073e9SAndroid Build Coastguard Worker            return x
4411*da0073e9SAndroid Build Coastguard Worker
4412*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1)
4413*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4414*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
4415*da0073e9SAndroid Build Coastguard Worker        opt_fn(x, torch.add)
4416*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 2)
4417*da0073e9SAndroid Build Coastguard Worker
4418*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
4419*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4420*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
4421*da0073e9SAndroid Build Coastguard Worker        opt_fn(x, torch.mul)
4422*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 1)
4423*da0073e9SAndroid Build Coastguard Worker
4424*da0073e9SAndroid Build Coastguard Worker    def test_inline_list_mutation(self):
4425*da0073e9SAndroid Build Coastguard Worker        def f1(x):
4426*da0073e9SAndroid Build Coastguard Worker            x.append(torch.ones(8))
4427*da0073e9SAndroid Build Coastguard Worker            return x
4428*da0073e9SAndroid Build Coastguard Worker
4429*da0073e9SAndroid Build Coastguard Worker        def f2():
4430*da0073e9SAndroid Build Coastguard Worker            x = [torch.ones(6)]
4431*da0073e9SAndroid Build Coastguard Worker            f1(x)
4432*da0073e9SAndroid Build Coastguard Worker            return x
4433*da0073e9SAndroid Build Coastguard Worker
4434*da0073e9SAndroid Build Coastguard Worker        res1 = f2()
4435*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4436*da0073e9SAndroid Build Coastguard Worker        opt_f2 = torch._dynamo.optimize(cnts)(f2)
4437*da0073e9SAndroid Build Coastguard Worker        res2 = opt_f2()
4438*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(res1, res2))
4439*da0073e9SAndroid Build Coastguard Worker
4440*da0073e9SAndroid Build Coastguard Worker    def test_inline_dict_mutation(self):
4441*da0073e9SAndroid Build Coastguard Worker        def f1(d):
4442*da0073e9SAndroid Build Coastguard Worker            d["c"] = d["a"] + d.pop("b")
4443*da0073e9SAndroid Build Coastguard Worker            return d
4444*da0073e9SAndroid Build Coastguard Worker
4445*da0073e9SAndroid Build Coastguard Worker        def f2():
4446*da0073e9SAndroid Build Coastguard Worker            d = {"a": torch.ones(5), "b": torch.ones(5)}
4447*da0073e9SAndroid Build Coastguard Worker            f1(d)
4448*da0073e9SAndroid Build Coastguard Worker            return d
4449*da0073e9SAndroid Build Coastguard Worker
4450*da0073e9SAndroid Build Coastguard Worker        res1 = f2()
4451*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4452*da0073e9SAndroid Build Coastguard Worker        opt_f2 = torch._dynamo.optimize(cnts)(f2)
4453*da0073e9SAndroid Build Coastguard Worker        res2 = opt_f2()
4454*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(res1, res2))
4455*da0073e9SAndroid Build Coastguard Worker
4456*da0073e9SAndroid Build Coastguard Worker    def test_inline_local_dict_clear(self):
4457*da0073e9SAndroid Build Coastguard Worker        def f(d):
4458*da0073e9SAndroid Build Coastguard Worker            d.clear()
4459*da0073e9SAndroid Build Coastguard Worker            return d
4460*da0073e9SAndroid Build Coastguard Worker
4461*da0073e9SAndroid Build Coastguard Worker        inp = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)}
4462*da0073e9SAndroid Build Coastguard Worker        out = torch.compile(f, backend="eager", fullgraph=True)(inp)
4463*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(out), 0)
4464*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(inp), 0)
4465*da0073e9SAndroid Build Coastguard Worker
4466*da0073e9SAndroid Build Coastguard Worker    def test_inline_module_attr_dict_clear(self):
4467*da0073e9SAndroid Build Coastguard Worker        class MyMod(torch.nn.Module):
4468*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
4469*da0073e9SAndroid Build Coastguard Worker                super().__init__()
4470*da0073e9SAndroid Build Coastguard Worker                self.a = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)}
4471*da0073e9SAndroid Build Coastguard Worker
4472*da0073e9SAndroid Build Coastguard Worker            def forward(self):
4473*da0073e9SAndroid Build Coastguard Worker                self.a.clear()
4474*da0073e9SAndroid Build Coastguard Worker                return self.a
4475*da0073e9SAndroid Build Coastguard Worker
4476*da0073e9SAndroid Build Coastguard Worker        m = MyMod()
4477*da0073e9SAndroid Build Coastguard Worker        out = torch.compile(m, backend="eager", fullgraph=True)()
4478*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(out), 0)
4479*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(m.a), 0)
4480*da0073e9SAndroid Build Coastguard Worker
4481*da0073e9SAndroid Build Coastguard Worker    def test_inline_user_defined_dict_attr_clear(self):
4482*da0073e9SAndroid Build Coastguard Worker        class MyMod:
4483*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
4484*da0073e9SAndroid Build Coastguard Worker                self.a = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)}
4485*da0073e9SAndroid Build Coastguard Worker
4486*da0073e9SAndroid Build Coastguard Worker        def f(obj, inp):
4487*da0073e9SAndroid Build Coastguard Worker            ret = len(obj.a) + inp
4488*da0073e9SAndroid Build Coastguard Worker            obj.a.clear()
4489*da0073e9SAndroid Build Coastguard Worker            return obj.a, ret
4490*da0073e9SAndroid Build Coastguard Worker
4491*da0073e9SAndroid Build Coastguard Worker        m = MyMod()
4492*da0073e9SAndroid Build Coastguard Worker        before_len = len(m.a)
4493*da0073e9SAndroid Build Coastguard Worker        t_inp = torch.ones(1)
4494*da0073e9SAndroid Build Coastguard Worker        d, ret = torch.compile(f, backend="eager", fullgraph=True)(m, t_inp)
4495*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(m.a), 0)
4496*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(d), 0)
4497*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ret, t_inp + before_len)
4498*da0073e9SAndroid Build Coastguard Worker
4499*da0073e9SAndroid Build Coastguard Worker    def test_recursive_inline_list_mutation(self):
4500*da0073e9SAndroid Build Coastguard Worker        def f1(x, y):
4501*da0073e9SAndroid Build Coastguard Worker            x.append(torch.tensor([1.1]))
4502*da0073e9SAndroid Build Coastguard Worker            y.append(torch.tensor([1.2]))
4503*da0073e9SAndroid Build Coastguard Worker            return x, y
4504*da0073e9SAndroid Build Coastguard Worker
4505*da0073e9SAndroid Build Coastguard Worker        def f2(x, y):
4506*da0073e9SAndroid Build Coastguard Worker            x.append(torch.tensor([2.1]))
4507*da0073e9SAndroid Build Coastguard Worker            y.append(torch.tensor([2.2]))
4508*da0073e9SAndroid Build Coastguard Worker            f1(x, y)
4509*da0073e9SAndroid Build Coastguard Worker            return x, y
4510*da0073e9SAndroid Build Coastguard Worker
4511*da0073e9SAndroid Build Coastguard Worker        def f3(x):
4512*da0073e9SAndroid Build Coastguard Worker            x.append(torch.tensor([3.1]))
4513*da0073e9SAndroid Build Coastguard Worker            y = [torch.tensor([3.2])]
4514*da0073e9SAndroid Build Coastguard Worker            f2(x, y)
4515*da0073e9SAndroid Build Coastguard Worker            return x, y
4516*da0073e9SAndroid Build Coastguard Worker
4517*da0073e9SAndroid Build Coastguard Worker        def f4():
4518*da0073e9SAndroid Build Coastguard Worker            x = [torch.tensor([4.1])]
4519*da0073e9SAndroid Build Coastguard Worker            return f3(x)
4520*da0073e9SAndroid Build Coastguard Worker
4521*da0073e9SAndroid Build Coastguard Worker        res1 = f4()
4522*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4523*da0073e9SAndroid Build Coastguard Worker        opt_f4 = torch._dynamo.optimize(cnts)(f4)
4524*da0073e9SAndroid Build Coastguard Worker        res2 = opt_f4()
4525*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(res1, res2))
4526*da0073e9SAndroid Build Coastguard Worker
4527*da0073e9SAndroid Build Coastguard Worker    def test_sample_input(self):
4528*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_methods_invocations import SampleInput
4529*da0073e9SAndroid Build Coastguard Worker
4530*da0073e9SAndroid Build Coastguard Worker        def fn(sample):
4531*da0073e9SAndroid Build Coastguard Worker            if isinstance(sample.input, torch.Tensor):
4532*da0073e9SAndroid Build Coastguard Worker                return sample.input * 2
4533*da0073e9SAndroid Build Coastguard Worker            return torch.zeros(())
4534*da0073e9SAndroid Build Coastguard Worker
4535*da0073e9SAndroid Build Coastguard Worker        sample = SampleInput(torch.ones(2))
4536*da0073e9SAndroid Build Coastguard Worker        ref = fn(sample)
4537*da0073e9SAndroid Build Coastguard Worker
4538*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
4539*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(sample)
4540*da0073e9SAndroid Build Coastguard Worker
4541*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
4542*da0073e9SAndroid Build Coastguard Worker
4543*da0073e9SAndroid Build Coastguard Worker    def test_release_input_memory(self):
4544*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([4])
4545*da0073e9SAndroid Build Coastguard Worker        x_ref = weakref.ref(x)
4546*da0073e9SAndroid Build Coastguard Worker
4547*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4548*da0073e9SAndroid Build Coastguard Worker
4549*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize(cnts)
4550*da0073e9SAndroid Build Coastguard Worker        def foo(x):
4551*da0073e9SAndroid Build Coastguard Worker            return x + x
4552*da0073e9SAndroid Build Coastguard Worker
4553*da0073e9SAndroid Build Coastguard Worker        out = foo(x)
4554*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(out, x + x))
4555*da0073e9SAndroid Build Coastguard Worker        del x
4556*da0073e9SAndroid Build Coastguard Worker        self.assertIs(x_ref(), None)
4557*da0073e9SAndroid Build Coastguard Worker
4558*da0073e9SAndroid Build Coastguard Worker    def test_release_module_memory(self):
4559*da0073e9SAndroid Build Coastguard Worker        mod = torch.nn.Linear(10, 10)
4560*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([10, 10])
4561*da0073e9SAndroid Build Coastguard Worker        mod_weight_ref = weakref.ref(mod.weight)
4562*da0073e9SAndroid Build Coastguard Worker        mod_ref = weakref.ref(mod)
4563*da0073e9SAndroid Build Coastguard Worker
4564*da0073e9SAndroid Build Coastguard Worker        # Modules that are passed into torch._dynamo optimized functions
4565*da0073e9SAndroid Build Coastguard Worker        # will normally be held onto through the generated GraphModule,
4566*da0073e9SAndroid Build Coastguard Worker        # which contains the modules. remove the reference in this backend
4567*da0073e9SAndroid Build Coastguard Worker        # and test that no additional references are being held.
4568*da0073e9SAndroid Build Coastguard Worker        class NoLeakBackend:
4569*da0073e9SAndroid Build Coastguard Worker            def __call__(self, gm: torch.fx.GraphModule, example_inputs):
4570*da0073e9SAndroid Build Coastguard Worker                gm.mod = None
4571*da0073e9SAndroid Build Coastguard Worker
4572*da0073e9SAndroid Build Coastguard Worker                def foo(*args, **kwargs):
4573*da0073e9SAndroid Build Coastguard Worker                    return (1,)
4574*da0073e9SAndroid Build Coastguard Worker
4575*da0073e9SAndroid Build Coastguard Worker                return foo
4576*da0073e9SAndroid Build Coastguard Worker
4577*da0073e9SAndroid Build Coastguard Worker        no_leak_backend = NoLeakBackend()
4578*da0073e9SAndroid Build Coastguard Worker
4579*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize(no_leak_backend)
4580*da0073e9SAndroid Build Coastguard Worker        def foo(mod, x):
4581*da0073e9SAndroid Build Coastguard Worker            return mod(x)
4582*da0073e9SAndroid Build Coastguard Worker
4583*da0073e9SAndroid Build Coastguard Worker        foo(mod, x)
4584*da0073e9SAndroid Build Coastguard Worker        del mod
4585*da0073e9SAndroid Build Coastguard Worker        del x
4586*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(mod_ref(), None)
4587*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(mod_weight_ref(), None)
4588*da0073e9SAndroid Build Coastguard Worker
4589*da0073e9SAndroid Build Coastguard Worker    def test_release_scope_memory(self):
4590*da0073e9SAndroid Build Coastguard Worker        def inner(y):
4591*da0073e9SAndroid Build Coastguard Worker            y
4592*da0073e9SAndroid Build Coastguard Worker
4593*da0073e9SAndroid Build Coastguard Worker        inner = torch._dynamo.optimize("eager")(inner)
4594*da0073e9SAndroid Build Coastguard Worker
4595*da0073e9SAndroid Build Coastguard Worker        p_ref = None
4596*da0073e9SAndroid Build Coastguard Worker
4597*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((10, 10))
4598*da0073e9SAndroid Build Coastguard Worker        inner(x)
4599*da0073e9SAndroid Build Coastguard Worker
4600*da0073e9SAndroid Build Coastguard Worker        p_ref = weakref.ref(x)
4601*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(p_ref() is not None)
4602*da0073e9SAndroid Build Coastguard Worker        del x
4603*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(p_ref() is None)
4604*da0073e9SAndroid Build Coastguard Worker
4605*da0073e9SAndroid Build Coastguard Worker    def test_update_locals_and_stack_uses_shared_cache(self):
4606*da0073e9SAndroid Build Coastguard Worker        def fn(x):
4607*da0073e9SAndroid Build Coastguard Worker            perm = [0, 3, 5]
4608*da0073e9SAndroid Build Coastguard Worker            perm = list(range(min(perm))) + perm
4609*da0073e9SAndroid Build Coastguard Worker            perm.extend(i for i in range(x.dim()) if i not in perm)
4610*da0073e9SAndroid Build Coastguard Worker            return perm
4611*da0073e9SAndroid Build Coastguard Worker
4612*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([2, 2, 2, 2, 2, 2])
4613*da0073e9SAndroid Build Coastguard Worker        res1 = fn(x)
4614*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4615*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
4616*da0073e9SAndroid Build Coastguard Worker        res2 = opt_fn(x)
4617*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(res1, res2))
4618*da0073e9SAndroid Build Coastguard Worker
4619*da0073e9SAndroid Build Coastguard Worker    def test_dict_reconstruct_keeps_original_order(self):
4620*da0073e9SAndroid Build Coastguard Worker        def fn():
4621*da0073e9SAndroid Build Coastguard Worker            modules = collections.OrderedDict([("act", torch.nn.ReLU())])
4622*da0073e9SAndroid Build Coastguard Worker            module_dict = torch.nn.ModuleDict(modules)
4623*da0073e9SAndroid Build Coastguard Worker
4624*da0073e9SAndroid Build Coastguard Worker            next_modules = {"fc4": torch.nn.Linear(5, 6), "act3": torch.nn.Sigmoid()}
4625*da0073e9SAndroid Build Coastguard Worker            modules.update(next_modules.items())
4626*da0073e9SAndroid Build Coastguard Worker            module_dict.update(next_modules)
4627*da0073e9SAndroid Build Coastguard Worker            return modules, module_dict
4628*da0073e9SAndroid Build Coastguard Worker
4629*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4630*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
4631*da0073e9SAndroid Build Coastguard Worker        modules, module_dict = opt_fn()
4632*da0073e9SAndroid Build Coastguard Worker
4633*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(module_dict), len(modules))
4634*da0073e9SAndroid Build Coastguard Worker        for k1, m2 in zip(modules, module_dict.children()):
4635*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(modules[k1] is m2)
4636*da0073e9SAndroid Build Coastguard Worker
4637*da0073e9SAndroid Build Coastguard Worker    def test_side_effects_codegen_update_mutated(self):
4638*da0073e9SAndroid Build Coastguard Worker        # codegen to update mutated variables with side effect
4639*da0073e9SAndroid Build Coastguard Worker        # should after stack value's codegen
4640*da0073e9SAndroid Build Coastguard Worker        def f1(x):
4641*da0073e9SAndroid Build Coastguard Worker            alist = [x]
4642*da0073e9SAndroid Build Coastguard Worker            alist.append(x + 1)
4643*da0073e9SAndroid Build Coastguard Worker            alist[0].sum().item()  # graph break
4644*da0073e9SAndroid Build Coastguard Worker            res = alist.pop()
4645*da0073e9SAndroid Build Coastguard Worker            res.sum().item()  # graph break
4646*da0073e9SAndroid Build Coastguard Worker            return res
4647*da0073e9SAndroid Build Coastguard Worker
4648*da0073e9SAndroid Build Coastguard Worker        def f2(a, b):
4649*da0073e9SAndroid Build Coastguard Worker            d = {"a": a + 1, "b": b + 2}
4650*da0073e9SAndroid Build Coastguard Worker            x = d.pop("b")
4651*da0073e9SAndroid Build Coastguard Worker            x.sum().item()  # graph break
4652*da0073e9SAndroid Build Coastguard Worker            y = d["a"] + x
4653*da0073e9SAndroid Build Coastguard Worker            y.sum().item()  # graph break
4654*da0073e9SAndroid Build Coastguard Worker            d["c"] = y
4655*da0073e9SAndroid Build Coastguard Worker            return d
4656*da0073e9SAndroid Build Coastguard Worker
4657*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([2, 3])
4658*da0073e9SAndroid Build Coastguard Worker        a = torch.rand([5, 6])
4659*da0073e9SAndroid Build Coastguard Worker        b = torch.rand([5, 6])
4660*da0073e9SAndroid Build Coastguard Worker        res11 = f1(x)
4661*da0073e9SAndroid Build Coastguard Worker        res21 = f2(a, b)
4662*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4663*da0073e9SAndroid Build Coastguard Worker        opt_f1 = torch._dynamo.optimize(cnts)(f1)
4664*da0073e9SAndroid Build Coastguard Worker        opt_f2 = torch._dynamo.optimize(cnts)(f2)
4665*da0073e9SAndroid Build Coastguard Worker        res12 = opt_f1(x)
4666*da0073e9SAndroid Build Coastguard Worker        res22 = opt_f2(a, b)
4667*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(res11, res12))
4668*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(res21, res22))
4669*da0073e9SAndroid Build Coastguard Worker
4670*da0073e9SAndroid Build Coastguard Worker    def test_list_append_return_none(self):
4671*da0073e9SAndroid Build Coastguard Worker        def fn(x):
4672*da0073e9SAndroid Build Coastguard Worker            alist = []
4673*da0073e9SAndroid Build Coastguard Worker            blist = alist.append(x + 1)
4674*da0073e9SAndroid Build Coastguard Worker            return alist, blist
4675*da0073e9SAndroid Build Coastguard Worker
4676*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([2.3])
4677*da0073e9SAndroid Build Coastguard Worker        res = fn(x)
4678*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4679*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
4680*da0073e9SAndroid Build Coastguard Worker        res2 = opt_fn(x)
4681*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, res2)
4682*da0073e9SAndroid Build Coastguard Worker
4683*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
4684*da0073e9SAndroid Build Coastguard Worker    def test_tensor_ctor_list_of_tensor(self):
4685*da0073e9SAndroid Build Coastguard Worker        def fn(x):
4686*da0073e9SAndroid Build Coastguard Worker            return torch.tensor([x], dtype=torch.int64)
4687*da0073e9SAndroid Build Coastguard Worker
4688*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(20)
4689*da0073e9SAndroid Build Coastguard Worker        res = fn(x)
4690*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4691*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
4692*da0073e9SAndroid Build Coastguard Worker        res2 = opt_fn(x)
4693*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, res2)
4694*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
4695*da0073e9SAndroid Build Coastguard Worker
4696*da0073e9SAndroid Build Coastguard Worker    def test_tensor_types(self):
4697*da0073e9SAndroid Build Coastguard Worker        def fn(dtype, tensor_type):
4698*da0073e9SAndroid Build Coastguard Worker            x = torch.empty(4, dtype=dtype)
4699*da0073e9SAndroid Build Coastguard Worker            assert isinstance(x, tensor_type)
4700*da0073e9SAndroid Build Coastguard Worker
4701*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
4702*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.float32, torch.FloatTensor)
4703*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.float64, torch.DoubleTensor)
4704*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.float16, torch.HalfTensor)
4705*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.bfloat16, torch.BFloat16Tensor)
4706*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.uint8, torch.ByteTensor)
4707*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.int8, torch.CharTensor)
4708*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.int64, torch.LongTensor)
4709*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.int, torch.IntTensor)
4710*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.int16, torch.ShortTensor)
4711*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.bool, torch.BoolTensor)
4712*da0073e9SAndroid Build Coastguard Worker
4713*da0073e9SAndroid Build Coastguard Worker    def test_nan(self):
4714*da0073e9SAndroid Build Coastguard Worker        def f(x, n):
4715*da0073e9SAndroid Build Coastguard Worker            return x * 2 + n
4716*da0073e9SAndroid Build Coastguard Worker
4717*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
4718*da0073e9SAndroid Build Coastguard Worker        n = float("nan")
4719*da0073e9SAndroid Build Coastguard Worker
4720*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
4721*da0073e9SAndroid Build Coastguard Worker        opt_f = torch._dynamo.optimize(cnts)(f)
4722*da0073e9SAndroid Build Coastguard Worker        opt_f(x, n)
4723*da0073e9SAndroid Build Coastguard Worker        opt_f(x, n)
4724*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
4725*da0073e9SAndroid Build Coastguard Worker
4726*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
4727*da0073e9SAndroid Build Coastguard Worker    def test_item(self):
4728*da0073e9SAndroid Build Coastguard Worker        class MyMod(torch.nn.Module):
4729*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
4730*da0073e9SAndroid Build Coastguard Worker                z = torch.max(x)
4731*da0073e9SAndroid Build Coastguard Worker                return z.int().item()
4732*da0073e9SAndroid Build Coastguard Worker
4733*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[10.6763, 11.7445, -2.2369]])
4734*da0073e9SAndroid Build Coastguard Worker        model = MyMod()
4735*da0073e9SAndroid Build Coastguard Worker        y = torch._dynamo.optimize("eager", nopython=True)(model)(x)
4736*da0073e9SAndroid Build Coastguard Worker
4737*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, 11)
4738*da0073e9SAndroid Build Coastguard Worker
4739*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
4740*da0073e9SAndroid Build Coastguard Worker    def test_item_changes(self):
4741*da0073e9SAndroid Build Coastguard Worker        class MyMod(torch.nn.Module):
4742*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
4743*da0073e9SAndroid Build Coastguard Worker                z = torch.max(x)
4744*da0073e9SAndroid Build Coastguard Worker                return z.int().item()
4745*da0073e9SAndroid Build Coastguard Worker
4746*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[10.6763, 11.7445, -2.2369]])
4747*da0073e9SAndroid Build Coastguard Worker        model = MyMod()
4748*da0073e9SAndroid Build Coastguard Worker        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
4749*da0073e9SAndroid Build Coastguard Worker        y = opt_model(x)
4750*da0073e9SAndroid Build Coastguard Worker        z = opt_model(torch.tensor([[y - 5, y + 10, y + 50]]))
4751*da0073e9SAndroid Build Coastguard Worker
4752*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, 11)
4753*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z, 61)
4754*da0073e9SAndroid Build Coastguard Worker
4755*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
4756*da0073e9SAndroid Build Coastguard Worker    def test_item_changes_new_shape(self):
4757*da0073e9SAndroid Build Coastguard Worker        class MyMod(torch.nn.Module):
4758*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
4759*da0073e9SAndroid Build Coastguard Worker                z = torch.max(x)
4760*da0073e9SAndroid Build Coastguard Worker                return z.int().item()
4761*da0073e9SAndroid Build Coastguard Worker
4762*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[10.6763, 11.7445, -2.2369]])
4763*da0073e9SAndroid Build Coastguard Worker        model = MyMod()
4764*da0073e9SAndroid Build Coastguard Worker        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
4765*da0073e9SAndroid Build Coastguard Worker        y = opt_model(x)
4766*da0073e9SAndroid Build Coastguard Worker        z = opt_model(torch.tensor([[y - 5, y + 50], [y + 5, y - 50]]))
4767*da0073e9SAndroid Build Coastguard Worker
4768*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, 11)
4769*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z, 61)
4770*da0073e9SAndroid Build Coastguard Worker
4771*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("https://github.com/pytorch/pytorch/issues/99726")
4772*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_loss_fancy_ctor1(self):
4773*da0073e9SAndroid Build Coastguard Worker        rand_5 = torch.randn(5)
4774*da0073e9SAndroid Build Coastguard Worker        rand_3_5 = torch.randn(3, 5)
4775*da0073e9SAndroid Build Coastguard Worker        target = torch.empty(3, dtype=torch.long).random_(5)
4776*da0073e9SAndroid Build Coastguard Worker
4777*da0073e9SAndroid Build Coastguard Worker        loss = torch.nn.CrossEntropyLoss(
4778*da0073e9SAndroid Build Coastguard Worker            weight=rand_5, reduce=False, label_smoothing=0.5
4779*da0073e9SAndroid Build Coastguard Worker        )
4780*da0073e9SAndroid Build Coastguard Worker        opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss)
4781*da0073e9SAndroid Build Coastguard Worker        input = rand_3_5
4782*da0073e9SAndroid Build Coastguard Worker        dynamo_output = opt_loss(input, target)
4783*da0073e9SAndroid Build Coastguard Worker
4784*da0073e9SAndroid Build Coastguard Worker        loss = torch.nn.CrossEntropyLoss(
4785*da0073e9SAndroid Build Coastguard Worker            weight=rand_5, reduce=False, label_smoothing=0.5
4786*da0073e9SAndroid Build Coastguard Worker        )
4787*da0073e9SAndroid Build Coastguard Worker        input = rand_3_5
4788*da0073e9SAndroid Build Coastguard Worker        output = loss(input, target)
4789*da0073e9SAndroid Build Coastguard Worker
4790*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(dynamo_output, output))
4791*da0073e9SAndroid Build Coastguard Worker
4792*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_loss_fancy_ctor2(self):
4793*da0073e9SAndroid Build Coastguard Worker        rand_3_5 = torch.randn(3, 5)
4794*da0073e9SAndroid Build Coastguard Worker        target = torch.empty(3, dtype=torch.long).random_(5)
4795*da0073e9SAndroid Build Coastguard Worker
4796*da0073e9SAndroid Build Coastguard Worker        loss = torch.nn.CrossEntropyLoss(reduce=False, label_smoothing=0.5)
4797*da0073e9SAndroid Build Coastguard Worker        opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss)
4798*da0073e9SAndroid Build Coastguard Worker        input = rand_3_5
4799*da0073e9SAndroid Build Coastguard Worker        dynamo_output = opt_loss(input, target)
4800*da0073e9SAndroid Build Coastguard Worker
4801*da0073e9SAndroid Build Coastguard Worker        loss = torch.nn.CrossEntropyLoss(reduce=False, label_smoothing=0.5)
4802*da0073e9SAndroid Build Coastguard Worker        input = rand_3_5
4803*da0073e9SAndroid Build Coastguard Worker        output = loss(input, target)
4804*da0073e9SAndroid Build Coastguard Worker
4805*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(dynamo_output, output))
4806*da0073e9SAndroid Build Coastguard Worker
4807*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_loss_simple_ctor(self):
4808*da0073e9SAndroid Build Coastguard Worker        output = None
4809*da0073e9SAndroid Build Coastguard Worker        rand_3_5 = torch.randn(3, 5)
4810*da0073e9SAndroid Build Coastguard Worker        target = torch.empty(3, dtype=torch.long).random_(5)
4811*da0073e9SAndroid Build Coastguard Worker
4812*da0073e9SAndroid Build Coastguard Worker        loss = torch.nn.CrossEntropyLoss()
4813*da0073e9SAndroid Build Coastguard Worker        opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss)
4814*da0073e9SAndroid Build Coastguard Worker        input = rand_3_5
4815*da0073e9SAndroid Build Coastguard Worker        dynamo_output = opt_loss(input, target)
4816*da0073e9SAndroid Build Coastguard Worker
4817*da0073e9SAndroid Build Coastguard Worker        loss = torch.nn.CrossEntropyLoss()
4818*da0073e9SAndroid Build Coastguard Worker        input = rand_3_5
4819*da0073e9SAndroid Build Coastguard Worker        output = loss(input, target)
4820*da0073e9SAndroid Build Coastguard Worker
4821*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(dynamo_output, output))
4822*da0073e9SAndroid Build Coastguard Worker
4823*da0073e9SAndroid Build Coastguard Worker    def test_nn_functional_reduction(self):
4824*da0073e9SAndroid Build Coastguard Worker        def fn(loss, reduction):
4825*da0073e9SAndroid Build Coastguard Worker            reduction_enum = F._Reduction.get_enum(reduction)
4826*da0073e9SAndroid Build Coastguard Worker            if reduction_enum == 0:
4827*da0073e9SAndroid Build Coastguard Worker                return loss
4828*da0073e9SAndroid Build Coastguard Worker            elif reduction_enum == 1:
4829*da0073e9SAndroid Build Coastguard Worker                return loss.mean()
4830*da0073e9SAndroid Build Coastguard Worker            elif reduction_enum == 2:
4831*da0073e9SAndroid Build Coastguard Worker                return loss.sum()
4832*da0073e9SAndroid Build Coastguard Worker
4833*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([3, 5])
4834*da0073e9SAndroid Build Coastguard Worker        y = "mean"
4835*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, y)
4836*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
4837*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, y)
4838*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
4839*da0073e9SAndroid Build Coastguard Worker
4840*da0073e9SAndroid Build Coastguard Worker    def test_large_reduction_list(self):
4841*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float32
4842*da0073e9SAndroid Build Coastguard Worker        device = "cpu"
4843*da0073e9SAndroid Build Coastguard Worker
4844*da0073e9SAndroid Build Coastguard Worker        def check_sum_all(tensor: torch.Tensor) -> None:
4845*da0073e9SAndroid Build Coastguard Worker            pylist = tensor.reshape(-1).tolist()
4846*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(tensor.sum(), torch.tensor(sum(pylist))))
4847*da0073e9SAndroid Build Coastguard Worker
4848*da0073e9SAndroid Build Coastguard Worker        check_sum_all(torch.randn(200000, dtype=dtype, device=device))
4849*da0073e9SAndroid Build Coastguard Worker
4850*da0073e9SAndroid Build Coastguard Worker    def test_raise_on_backend_error(self):
4851*da0073e9SAndroid Build Coastguard Worker        def my_compiler(gm, _):
4852*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("duck!")
4853*da0073e9SAndroid Build Coastguard Worker
4854*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize(my_compiler)
4855*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
4856*da0073e9SAndroid Build Coastguard Worker            return a + b / (a - b)
4857*da0073e9SAndroid Build Coastguard Worker
4858*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(
4859*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.exc.BackendCompilerFailed,
4860*da0073e9SAndroid Build Coastguard Worker            lambda: fn(torch.randn(10), torch.randn(10)),
4861*da0073e9SAndroid Build Coastguard Worker        )
4862*da0073e9SAndroid Build Coastguard Worker
4863*da0073e9SAndroid Build Coastguard Worker    def test_named_parameters(self):
4864*da0073e9SAndroid Build Coastguard Worker        n_embd = 768
4865*da0073e9SAndroid Build Coastguard Worker        block_size = 128
4866*da0073e9SAndroid Build Coastguard Worker        vocab_size = 65
4867*da0073e9SAndroid Build Coastguard Worker        embd_pdrop = 0.1
4868*da0073e9SAndroid Build Coastguard Worker
4869*da0073e9SAndroid Build Coastguard Worker        class MyModel2(torch.nn.Module):
4870*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
4871*da0073e9SAndroid Build Coastguard Worker                super().__init__()
4872*da0073e9SAndroid Build Coastguard Worker                self.tok_emb = torch.nn.Embedding(vocab_size, n_embd)
4873*da0073e9SAndroid Build Coastguard Worker                self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd))
4874*da0073e9SAndroid Build Coastguard Worker                self.drop = torch.nn.Dropout(embd_pdrop)
4875*da0073e9SAndroid Build Coastguard Worker
4876*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
4877*da0073e9SAndroid Build Coastguard Worker                return x
4878*da0073e9SAndroid Build Coastguard Worker
4879*da0073e9SAndroid Build Coastguard Worker        class MyModel(torch.nn.Module):
4880*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
4881*da0073e9SAndroid Build Coastguard Worker                super().__init__()
4882*da0073e9SAndroid Build Coastguard Worker                self.tok_emb = torch.nn.Embedding(vocab_size, n_embd)
4883*da0073e9SAndroid Build Coastguard Worker                self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd))
4884*da0073e9SAndroid Build Coastguard Worker                self.drop = torch.nn.Dropout(embd_pdrop)
4885*da0073e9SAndroid Build Coastguard Worker                self.submod2 = MyModel2()
4886*da0073e9SAndroid Build Coastguard Worker
4887*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
4888*da0073e9SAndroid Build Coastguard Worker                return x
4889*da0073e9SAndroid Build Coastguard Worker
4890*da0073e9SAndroid Build Coastguard Worker        # Regular
4891*da0073e9SAndroid Build Coastguard Worker        params = []
4892*da0073e9SAndroid Build Coastguard Worker        mod = MyModel()
4893*da0073e9SAndroid Build Coastguard Worker        actual_params = list(mod.named_parameters())
4894*da0073e9SAndroid Build Coastguard Worker
4895*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize("eager", nopython=True)
4896*da0073e9SAndroid Build Coastguard Worker        def fn():
4897*da0073e9SAndroid Build Coastguard Worker            return list(mod.named_parameters())
4898*da0073e9SAndroid Build Coastguard Worker
4899*da0073e9SAndroid Build Coastguard Worker        params = fn()
4900*da0073e9SAndroid Build Coastguard Worker
4901*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(actual_params), len(params))
4902*da0073e9SAndroid Build Coastguard Worker        for idx in range(len(params)):
4903*da0073e9SAndroid Build Coastguard Worker            k_a, v_a = actual_params[idx]
4904*da0073e9SAndroid Build Coastguard Worker            k, v = params[idx]
4905*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(k_a, k)
4906*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(v_a, v))
4907*da0073e9SAndroid Build Coastguard Worker
4908*da0073e9SAndroid Build Coastguard Worker        # Prefix
4909*da0073e9SAndroid Build Coastguard Worker        params = []
4910*da0073e9SAndroid Build Coastguard Worker        mod = MyModel()
4911*da0073e9SAndroid Build Coastguard Worker        actual_params = list(mod.named_parameters(prefix="foo"))
4912*da0073e9SAndroid Build Coastguard Worker
4913*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize("eager", nopython=True)
4914*da0073e9SAndroid Build Coastguard Worker        def fn1():
4915*da0073e9SAndroid Build Coastguard Worker            return list(mod.named_parameters(prefix="foo"))
4916*da0073e9SAndroid Build Coastguard Worker
4917*da0073e9SAndroid Build Coastguard Worker        params = fn1()
4918*da0073e9SAndroid Build Coastguard Worker
4919*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(actual_params), len(params))
4920*da0073e9SAndroid Build Coastguard Worker        for idx in range(len(params)):
4921*da0073e9SAndroid Build Coastguard Worker            k_a, v_a = actual_params[idx]
4922*da0073e9SAndroid Build Coastguard Worker            k, v = params[idx]
4923*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(k_a, k)
4924*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(v_a, v))
4925*da0073e9SAndroid Build Coastguard Worker
4926*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(guard_nn_modules=True)
4927*da0073e9SAndroid Build Coastguard Worker    def test_module_complex_iter(self):
4928*da0073e9SAndroid Build Coastguard Worker        n_embd = 768
4929*da0073e9SAndroid Build Coastguard Worker        block_size = 128
4930*da0073e9SAndroid Build Coastguard Worker        vocab_size = 65
4931*da0073e9SAndroid Build Coastguard Worker        embd_pdrop = 0.1
4932*da0073e9SAndroid Build Coastguard Worker
4933*da0073e9SAndroid Build Coastguard Worker        class FakeGPT(torch.nn.Module):
4934*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
4935*da0073e9SAndroid Build Coastguard Worker                super().__init__()
4936*da0073e9SAndroid Build Coastguard Worker                self.tok_emb = torch.nn.Embedding(vocab_size, n_embd)
4937*da0073e9SAndroid Build Coastguard Worker                self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd))
4938*da0073e9SAndroid Build Coastguard Worker                self.drop = torch.nn.Dropout(embd_pdrop)
4939*da0073e9SAndroid Build Coastguard Worker                self.ln_f = torch.nn.LayerNorm(n_embd)
4940*da0073e9SAndroid Build Coastguard Worker                self.head = torch.nn.Linear(n_embd, vocab_size, bias=False)
4941*da0073e9SAndroid Build Coastguard Worker
4942*da0073e9SAndroid Build Coastguard Worker                self.block_size = block_size
4943*da0073e9SAndroid Build Coastguard Worker                self.names = []
4944*da0073e9SAndroid Build Coastguard Worker
4945*da0073e9SAndroid Build Coastguard Worker            def forward(self, idx, targets=None):
4946*da0073e9SAndroid Build Coastguard Worker                b, t = idx.size()
4947*da0073e9SAndroid Build Coastguard Worker                assert (
4948*da0073e9SAndroid Build Coastguard Worker                    t <= self.block_size
4949*da0073e9SAndroid Build Coastguard Worker                ), "Cannot forward, model block size is exhausted."
4950*da0073e9SAndroid Build Coastguard Worker
4951*da0073e9SAndroid Build Coastguard Worker                # forward the GPT model
4952*da0073e9SAndroid Build Coastguard Worker                token_embeddings = self.tok_emb(
4953*da0073e9SAndroid Build Coastguard Worker                    idx
4954*da0073e9SAndroid Build Coastguard Worker                )  # each index maps to a (learnable) vector
4955*da0073e9SAndroid Build Coastguard Worker                position_embeddings = self.pos_emb[
4956*da0073e9SAndroid Build Coastguard Worker                    :, :t, :
4957*da0073e9SAndroid Build Coastguard Worker                ]  # each position maps to a (learnable) vector
4958*da0073e9SAndroid Build Coastguard Worker                x = self.drop(token_embeddings + position_embeddings)
4959*da0073e9SAndroid Build Coastguard Worker                x = self.blocks(x)
4960*da0073e9SAndroid Build Coastguard Worker                x = self.ln_f(x)
4961*da0073e9SAndroid Build Coastguard Worker                logits = self.head(x)
4962*da0073e9SAndroid Build Coastguard Worker
4963*da0073e9SAndroid Build Coastguard Worker                # if we are given some desired targets also calculate the loss
4964*da0073e9SAndroid Build Coastguard Worker                loss = None
4965*da0073e9SAndroid Build Coastguard Worker                if targets is not None:
4966*da0073e9SAndroid Build Coastguard Worker                    loss = F.cross_entropy(
4967*da0073e9SAndroid Build Coastguard Worker                        logits.view(-1, logits.size(-1)), targets.view(-1)
4968*da0073e9SAndroid Build Coastguard Worker                    )
4969*da0073e9SAndroid Build Coastguard Worker
4970*da0073e9SAndroid Build Coastguard Worker                return logits, loss
4971*da0073e9SAndroid Build Coastguard Worker
4972*da0073e9SAndroid Build Coastguard Worker            def foo(self, memo=None, prefix="", remove_duplicate=False):
4973*da0073e9SAndroid Build Coastguard Worker                for mn, m in self.named_modules(
4974*da0073e9SAndroid Build Coastguard Worker                    memo=memo, prefix=prefix, remove_duplicate=remove_duplicate
4975*da0073e9SAndroid Build Coastguard Worker                ):
4976*da0073e9SAndroid Build Coastguard Worker                    for pn, p in self.named_parameters():
4977*da0073e9SAndroid Build Coastguard Worker                        fpn = f"{mn}.{pn}" if mn else pn
4978*da0073e9SAndroid Build Coastguard Worker                        self.names.append(fpn)
4979*da0073e9SAndroid Build Coastguard Worker
4980*da0073e9SAndroid Build Coastguard Worker        # Test plain recurse
4981*da0073e9SAndroid Build Coastguard Worker        model_a = FakeGPT()
4982*da0073e9SAndroid Build Coastguard Worker        model_a.foo()
4983*da0073e9SAndroid Build Coastguard Worker        a_names = model_a.names
4984*da0073e9SAndroid Build Coastguard Worker
4985*da0073e9SAndroid Build Coastguard Worker        model_b = FakeGPT()
4986*da0073e9SAndroid Build Coastguard Worker        opt_model_b = torch._dynamo.optimize("eager", nopython=True)(model_b)
4987*da0073e9SAndroid Build Coastguard Worker        opt_model_b.foo()
4988*da0073e9SAndroid Build Coastguard Worker
4989*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a_names, model_b.names)
4990*da0073e9SAndroid Build Coastguard Worker
4991*da0073e9SAndroid Build Coastguard Worker        # Test with prefix
4992*da0073e9SAndroid Build Coastguard Worker        model_a = FakeGPT()
4993*da0073e9SAndroid Build Coastguard Worker        model_a.foo(prefix="abc")
4994*da0073e9SAndroid Build Coastguard Worker        a_names = model_a.names
4995*da0073e9SAndroid Build Coastguard Worker
4996*da0073e9SAndroid Build Coastguard Worker        model_b = FakeGPT()
4997*da0073e9SAndroid Build Coastguard Worker        opt_model_b = torch._dynamo.optimize("eager", nopython=True)(model_b)
4998*da0073e9SAndroid Build Coastguard Worker        opt_model_b.foo(prefix="abc")
4999*da0073e9SAndroid Build Coastguard Worker
5000*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a_names, model_b.names)
5001*da0073e9SAndroid Build Coastguard Worker
5002*da0073e9SAndroid Build Coastguard Worker    def test_numpy_variable_isinstance(self):
5003*da0073e9SAndroid Build Coastguard Worker        def fn(x, m):
5004*da0073e9SAndroid Build Coastguard Worker            if isinstance(m, np.ndarray):
5005*da0073e9SAndroid Build Coastguard Worker                return x + 1
5006*da0073e9SAndroid Build Coastguard Worker            else:
5007*da0073e9SAndroid Build Coastguard Worker                return x - 1
5008*da0073e9SAndroid Build Coastguard Worker
5009*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([2.3])
5010*da0073e9SAndroid Build Coastguard Worker        m = np.array([1, 2, 3])
5011*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, m)
5012*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
5013*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
5014*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, m)
5015*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
5016*da0073e9SAndroid Build Coastguard Worker
5017*da0073e9SAndroid Build Coastguard Worker        # Test now the other path
5018*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, x)
5019*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, x)
5020*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
5021*da0073e9SAndroid Build Coastguard Worker
5022*da0073e9SAndroid Build Coastguard Worker    def test_tensor_dot_grad_no_graph_break(self):
5023*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
5024*da0073e9SAndroid Build Coastguard Worker            y = 3 * a**3 - b**2
5025*da0073e9SAndroid Build Coastguard Worker            y.backward(gradient=torch.tensor([1.0, 1.0]))
5026*da0073e9SAndroid Build Coastguard Worker            b.grad.zero_()
5027*da0073e9SAndroid Build Coastguard Worker            return a.grad, b.grad
5028*da0073e9SAndroid Build Coastguard Worker
5029*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([2.0, 3.0], requires_grad=True)
5030*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor([6.0, 4.0], requires_grad=True)
5031*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
5032*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(fn)
5033*da0073e9SAndroid Build Coastguard Worker        _, b_grad = opt_fn(a, b)
5034*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(b_grad, torch.tensor([0.0, 0.0])))
5035*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
5036*da0073e9SAndroid Build Coastguard Worker
5037*da0073e9SAndroid Build Coastguard Worker    def test_torch_nn_parameter_isinstance(self):
5038*da0073e9SAndroid Build Coastguard Worker        def fn(x):
5039*da0073e9SAndroid Build Coastguard Worker            a = torch.nn.Parameter(torch.rand(2, 3))
5040*da0073e9SAndroid Build Coastguard Worker            if isinstance(a, torch.Tensor):
5041*da0073e9SAndroid Build Coastguard Worker                return x + 1
5042*da0073e9SAndroid Build Coastguard Worker            else:
5043*da0073e9SAndroid Build Coastguard Worker                return x - 1
5044*da0073e9SAndroid Build Coastguard Worker
5045*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([2.5])
5046*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
5047*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
5048*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
5049*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
5050*da0073e9SAndroid Build Coastguard Worker
5051*da0073e9SAndroid Build Coastguard Worker    def _optimize_then_check_exp(
5052*da0073e9SAndroid Build Coastguard Worker        self, foo, args, cnt, exp_out, exp_frame_count, exp_n_cached_backend
5053*da0073e9SAndroid Build Coastguard Worker    ):
5054*da0073e9SAndroid Build Coastguard Worker        opt_out = torch._dynamo.optimize(backend=cnt)(foo)(*args)
5055*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(exp_out, opt_out)
5056*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, exp_frame_count)
5057*da0073e9SAndroid Build Coastguard Worker
5058*da0073e9SAndroid Build Coastguard Worker    def test_backend_match_guard(self):
5059*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([3, 4])
5060*da0073e9SAndroid Build Coastguard Worker
5061*da0073e9SAndroid Build Coastguard Worker        def foo(x):
5062*da0073e9SAndroid Build Coastguard Worker            return x.sin() + x.cos()
5063*da0073e9SAndroid Build Coastguard Worker
5064*da0073e9SAndroid Build Coastguard Worker        def foo_graph_break(x):
5065*da0073e9SAndroid Build Coastguard Worker            a = x.sin()
5066*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
5067*da0073e9SAndroid Build Coastguard Worker            b = x.cos()
5068*da0073e9SAndroid Build Coastguard Worker            return a + b
5069*da0073e9SAndroid Build Coastguard Worker
5070*da0073e9SAndroid Build Coastguard Worker        eager_record_backend = torch._dynamo.testing.EagerAndRecordGraphs()
5071*da0073e9SAndroid Build Coastguard Worker        backends = [eager_record_backend, "eager"]
5072*da0073e9SAndroid Build Coastguard Worker
5073*da0073e9SAndroid Build Coastguard Worker        # We intentionally don't reset dynamo for each backend so that we can test
5074*da0073e9SAndroid Build Coastguard Worker        # 1. dynamo doesn't recompile when backend stays the same, i.e. frame_count doesn't increase
5075*da0073e9SAndroid Build Coastguard Worker        # 2. dynamo recompiles when backend changes, i.e. frame_count is non-zero for next backend
5076*da0073e9SAndroid Build Coastguard Worker        def test_recompile(foo, *, exp_frame_count):
5077*da0073e9SAndroid Build Coastguard Worker            eager_result = foo(x)
5078*da0073e9SAndroid Build Coastguard Worker            for i, backend in enumerate(backends):
5079*da0073e9SAndroid Build Coastguard Worker                cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
5080*da0073e9SAndroid Build Coastguard Worker                # Run opt_f multiple times to make sure dynamo doesn't recompile.
5081*da0073e9SAndroid Build Coastguard Worker                # Specifically, frame_count doesn't increase
5082*da0073e9SAndroid Build Coastguard Worker                # the number of cached backends is i + 2 because we have the optimizing backend + None
5083*da0073e9SAndroid Build Coastguard Worker                self._optimize_then_check_exp(
5084*da0073e9SAndroid Build Coastguard Worker                    foo, (x,), cnt, eager_result, exp_frame_count, i + 2
5085*da0073e9SAndroid Build Coastguard Worker                )
5086*da0073e9SAndroid Build Coastguard Worker                self._optimize_then_check_exp(
5087*da0073e9SAndroid Build Coastguard Worker                    foo, (x,), cnt, eager_result, exp_frame_count, i + 2
5088*da0073e9SAndroid Build Coastguard Worker                )
5089*da0073e9SAndroid Build Coastguard Worker                self._optimize_then_check_exp(
5090*da0073e9SAndroid Build Coastguard Worker                    foo, (x,), cnt, eager_result, exp_frame_count, i + 2
5091*da0073e9SAndroid Build Coastguard Worker                )
5092*da0073e9SAndroid Build Coastguard Worker
5093*da0073e9SAndroid Build Coastguard Worker        test_recompile(foo, exp_frame_count=1)
5094*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
5095*da0073e9SAndroid Build Coastguard Worker        test_recompile(foo_graph_break, exp_frame_count=2)
5096*da0073e9SAndroid Build Coastguard Worker
5097*da0073e9SAndroid Build Coastguard Worker    def test_backend_match_guard_multi_threads(self):
5098*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([3, 4])
5099*da0073e9SAndroid Build Coastguard Worker
5100*da0073e9SAndroid Build Coastguard Worker        def foo(x):
5101*da0073e9SAndroid Build Coastguard Worker            return x.sin() + x.cos()
5102*da0073e9SAndroid Build Coastguard Worker
5103*da0073e9SAndroid Build Coastguard Worker        def compile_then_check_exp(foo, args, cnt, eager_result, exp_frame_count):
5104*da0073e9SAndroid Build Coastguard Worker            for i in range(3):
5105*da0073e9SAndroid Build Coastguard Worker                opt_out = torch._dynamo.optimize(backend=cnt)(foo)(*args)
5106*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(opt_out, eager_result)
5107*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnt.frame_count, exp_frame_count)
5108*da0073e9SAndroid Build Coastguard Worker            thread_success[threading.current_thread()] = True
5109*da0073e9SAndroid Build Coastguard Worker
5110*da0073e9SAndroid Build Coastguard Worker        eager_record_backend = torch._dynamo.testing.EagerAndRecordGraphs()
5111*da0073e9SAndroid Build Coastguard Worker        backends = [eager_record_backend, "eager"]
5112*da0073e9SAndroid Build Coastguard Worker
5113*da0073e9SAndroid Build Coastguard Worker        # Test dynamo recompiles but only caches a single backend for each thread
5114*da0073e9SAndroid Build Coastguard Worker        eager_result = foo(x)
5115*da0073e9SAndroid Build Coastguard Worker        # cnt and None
5116*da0073e9SAndroid Build Coastguard Worker        exp_frame_count = 1
5117*da0073e9SAndroid Build Coastguard Worker        threads = []
5118*da0073e9SAndroid Build Coastguard Worker        thread_success = {}
5119*da0073e9SAndroid Build Coastguard Worker        for i, backend in enumerate(backends):
5120*da0073e9SAndroid Build Coastguard Worker            cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
5121*da0073e9SAndroid Build Coastguard Worker            thread = threading.Thread(
5122*da0073e9SAndroid Build Coastguard Worker                target=compile_then_check_exp,
5123*da0073e9SAndroid Build Coastguard Worker                args=(
5124*da0073e9SAndroid Build Coastguard Worker                    foo,
5125*da0073e9SAndroid Build Coastguard Worker                    (x,),
5126*da0073e9SAndroid Build Coastguard Worker                    cnt,
5127*da0073e9SAndroid Build Coastguard Worker                    eager_result,
5128*da0073e9SAndroid Build Coastguard Worker                    exp_frame_count,
5129*da0073e9SAndroid Build Coastguard Worker                ),
5130*da0073e9SAndroid Build Coastguard Worker            )
5131*da0073e9SAndroid Build Coastguard Worker            threads.append(thread)
5132*da0073e9SAndroid Build Coastguard Worker            thread.start()
5133*da0073e9SAndroid Build Coastguard Worker
5134*da0073e9SAndroid Build Coastguard Worker        # Wait for all threads to finish
5135*da0073e9SAndroid Build Coastguard Worker        for thread in threads:
5136*da0073e9SAndroid Build Coastguard Worker            thread.join()
5137*da0073e9SAndroid Build Coastguard Worker
5138*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(thread_success), len(threads))
5139*da0073e9SAndroid Build Coastguard Worker
5140*da0073e9SAndroid Build Coastguard Worker    def test_dynamo_min_operator_with_shape(self):
5141*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize("eager", nopython=True)
5142*da0073e9SAndroid Build Coastguard Worker        def f(x, a):
5143*da0073e9SAndroid Build Coastguard Worker            return min(x.shape[0], a)
5144*da0073e9SAndroid Build Coastguard Worker
5145*da0073e9SAndroid Build Coastguard Worker        result = f(torch.ones(6), 3)
5146*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, 3)
5147*da0073e9SAndroid Build Coastguard Worker
5148*da0073e9SAndroid Build Coastguard Worker    def test_onnx_shape_as_tensor(self):
5149*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize("eager", nopython=True)
5150*da0073e9SAndroid Build Coastguard Worker        def f(x):
5151*da0073e9SAndroid Build Coastguard Worker            return 1 + torch._shape_as_tensor(x)[0]
5152*da0073e9SAndroid Build Coastguard Worker
5153*da0073e9SAndroid Build Coastguard Worker        gm, _ = torch._dynamo.export(f)(torch.ones(6))
5154*da0073e9SAndroid Build Coastguard Worker
5155*da0073e9SAndroid Build Coastguard Worker        input_one_dim = torch.ones(6)
5156*da0073e9SAndroid Build Coastguard Worker        input_two_dims = torch.ones(7, 4)
5157*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(f(input_one_dim), 7)
5158*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(f(input_two_dims), 8)
5159*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(f(input_two_dims), 8)
5160*da0073e9SAndroid Build Coastguard Worker
5161*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize("eager", nopython=True)
5162*da0073e9SAndroid Build Coastguard Worker        def f_onnx(x):
5163*da0073e9SAndroid Build Coastguard Worker            return 1 + torch.onnx.operators.shape_as_tensor(x)[0]
5164*da0073e9SAndroid Build Coastguard Worker
5165*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(f_onnx(input_one_dim), 7)
5166*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(f_onnx(input_two_dims), 8)
5167*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(f_onnx(input_two_dims), 8)
5168*da0073e9SAndroid Build Coastguard Worker
5169*da0073e9SAndroid Build Coastguard Worker    def test_cond(self):
5170*da0073e9SAndroid Build Coastguard Worker        from functorch.experimental.control_flow import cond
5171*da0073e9SAndroid Build Coastguard Worker
5172*da0073e9SAndroid Build Coastguard Worker        def true_fn(x):
5173*da0073e9SAndroid Build Coastguard Worker            return x.sin()
5174*da0073e9SAndroid Build Coastguard Worker
5175*da0073e9SAndroid Build Coastguard Worker        def false_fn(x):
5176*da0073e9SAndroid Build Coastguard Worker            return x.cos()
5177*da0073e9SAndroid Build Coastguard Worker
5178*da0073e9SAndroid Build Coastguard Worker        def f(pred, x):
5179*da0073e9SAndroid Build Coastguard Worker            return cond(pred, true_fn, false_fn, [x])
5180*da0073e9SAndroid Build Coastguard Worker
5181*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(f)
5182*da0073e9SAndroid Build Coastguard Worker        a = opt_fn(torch.tensor(False), torch.tensor([0.25, 0.25]))
5183*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(torch.cos(torch.tensor([0.25, 0.25])), a))
5184*da0073e9SAndroid Build Coastguard Worker        b = opt_fn(torch.tensor(True), torch.tensor([0.25, 0.25]))
5185*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), b))
5186*da0073e9SAndroid Build Coastguard Worker
5187*da0073e9SAndroid Build Coastguard Worker    def test_nonzero_static(self):
5188*da0073e9SAndroid Build Coastguard Worker        # invalid size
5189*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
5190*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "nonzero_static: 'size' must be an non-negative integer"
5191*da0073e9SAndroid Build Coastguard Worker        ):
5192*da0073e9SAndroid Build Coastguard Worker            torch.nonzero_static(torch.tensor([8]), size=-2)
5193*da0073e9SAndroid Build Coastguard Worker
5194*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
5195*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "nonzero_static: 'size' must be an non-negative integer"
5196*da0073e9SAndroid Build Coastguard Worker        ):
5197*da0073e9SAndroid Build Coastguard Worker            torch.nonzero_static(torch.tensor([8]), size=-2, out=torch.tensor(0))
5198*da0073e9SAndroid Build Coastguard Worker
5199*da0073e9SAndroid Build Coastguard Worker        # nonzero_static.out: out dtype mismatch
5200*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.tensor([8])
5201*da0073e9SAndroid Build Coastguard Worker        static_size = 1
5202*da0073e9SAndroid Build Coastguard Worker        out_tensor = torch.empty((static_size, input_tensor.dim()), dtype=torch.float)
5203*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
5204*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "nonzero_static: Expected out tensor to have scalar type Long"
5205*da0073e9SAndroid Build Coastguard Worker        ):
5206*da0073e9SAndroid Build Coastguard Worker            torch.nonzero_static(input_tensor, size=static_size, out=out_tensor)
5207*da0073e9SAndroid Build Coastguard Worker
5208*da0073e9SAndroid Build Coastguard Worker        # nonzero_static.out: out resize (shrink)
5209*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.tensor([8])
5210*da0073e9SAndroid Build Coastguard Worker        static_size = 1
5211*da0073e9SAndroid Build Coastguard Worker        out_tensor = torch.empty((10, 10, 10, 10), dtype=torch.long)
5212*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5213*da0073e9SAndroid Build Coastguard Worker            same(
5214*da0073e9SAndroid Build Coastguard Worker                torch.nonzero_static(input_tensor, size=static_size, out=out_tensor),
5215*da0073e9SAndroid Build Coastguard Worker                torch.tensor([0]),
5216*da0073e9SAndroid Build Coastguard Worker            )
5217*da0073e9SAndroid Build Coastguard Worker        )
5218*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5219*da0073e9SAndroid Build Coastguard Worker            same(
5220*da0073e9SAndroid Build Coastguard Worker                out_tensor,
5221*da0073e9SAndroid Build Coastguard Worker                torch.tensor([0]),
5222*da0073e9SAndroid Build Coastguard Worker            )
5223*da0073e9SAndroid Build Coastguard Worker        )
5224*da0073e9SAndroid Build Coastguard Worker
5225*da0073e9SAndroid Build Coastguard Worker        # nonzero_static.out: out resize (enlarge)
5226*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.tensor([8])
5227*da0073e9SAndroid Build Coastguard Worker        static_size = 1
5228*da0073e9SAndroid Build Coastguard Worker        out_tensor = torch.empty((0), dtype=torch.long)
5229*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5230*da0073e9SAndroid Build Coastguard Worker            same(
5231*da0073e9SAndroid Build Coastguard Worker                torch.nonzero_static(input_tensor, size=static_size, out=out_tensor),
5232*da0073e9SAndroid Build Coastguard Worker                torch.tensor([0]),
5233*da0073e9SAndroid Build Coastguard Worker            )
5234*da0073e9SAndroid Build Coastguard Worker        )
5235*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5236*da0073e9SAndroid Build Coastguard Worker            same(
5237*da0073e9SAndroid Build Coastguard Worker                out_tensor,
5238*da0073e9SAndroid Build Coastguard Worker                torch.tensor([0]),
5239*da0073e9SAndroid Build Coastguard Worker            )
5240*da0073e9SAndroid Build Coastguard Worker        )
5241*da0073e9SAndroid Build Coastguard Worker
5242*da0073e9SAndroid Build Coastguard Worker        # 0 rank
5243*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.tensor(6)
5244*da0073e9SAndroid Build Coastguard Worker        static_size = 2
5245*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5246*da0073e9SAndroid Build Coastguard Worker            same(
5247*da0073e9SAndroid Build Coastguard Worker                torch.nonzero_static(input_tensor, size=static_size),
5248*da0073e9SAndroid Build Coastguard Worker                torch.empty((static_size, input_tensor.dim()), dtype=torch.long),
5249*da0073e9SAndroid Build Coastguard Worker            )
5250*da0073e9SAndroid Build Coastguard Worker        )
5251*da0073e9SAndroid Build Coastguard Worker
5252*da0073e9SAndroid Build Coastguard Worker        # 0 size
5253*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.tensor([[[1]]])
5254*da0073e9SAndroid Build Coastguard Worker        static_size = 0
5255*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5256*da0073e9SAndroid Build Coastguard Worker            same(
5257*da0073e9SAndroid Build Coastguard Worker                torch.nonzero_static(input_tensor, size=static_size),
5258*da0073e9SAndroid Build Coastguard Worker                torch.empty((static_size, input_tensor.dim()), dtype=torch.long),
5259*da0073e9SAndroid Build Coastguard Worker            )
5260*da0073e9SAndroid Build Coastguard Worker        )
5261*da0073e9SAndroid Build Coastguard Worker
5262*da0073e9SAndroid Build Coastguard Worker        # 1D input
5263*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.tensor([0, 8])
5264*da0073e9SAndroid Build Coastguard Worker        static_size = 1
5265*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5266*da0073e9SAndroid Build Coastguard Worker            same(
5267*da0073e9SAndroid Build Coastguard Worker                torch.nonzero_static(input_tensor, size=static_size),
5268*da0073e9SAndroid Build Coastguard Worker                torch.tensor([1]),
5269*da0073e9SAndroid Build Coastguard Worker            )
5270*da0073e9SAndroid Build Coastguard Worker        )
5271*da0073e9SAndroid Build Coastguard Worker
5272*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.tensor([8, 0])
5273*da0073e9SAndroid Build Coastguard Worker        static_size = 2
5274*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5275*da0073e9SAndroid Build Coastguard Worker            same(
5276*da0073e9SAndroid Build Coastguard Worker                torch.nonzero_static(input_tensor, size=static_size),
5277*da0073e9SAndroid Build Coastguard Worker                torch.tensor([[0], [-1]]),  # padded with default fill_value "-1"
5278*da0073e9SAndroid Build Coastguard Worker            )
5279*da0073e9SAndroid Build Coastguard Worker        )
5280*da0073e9SAndroid Build Coastguard Worker
5281*da0073e9SAndroid Build Coastguard Worker        # 2D input
5282*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]])
5283*da0073e9SAndroid Build Coastguard Worker        static_size = 5
5284*da0073e9SAndroid Build Coastguard Worker        fill_value = -100
5285*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5286*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.utils.same(
5287*da0073e9SAndroid Build Coastguard Worker                torch.nonzero_static(
5288*da0073e9SAndroid Build Coastguard Worker                    input_tensor, size=static_size, fill_value=fill_value
5289*da0073e9SAndroid Build Coastguard Worker                ),
5290*da0073e9SAndroid Build Coastguard Worker                torch.tensor(
5291*da0073e9SAndroid Build Coastguard Worker                    [
5292*da0073e9SAndroid Build Coastguard Worker                        [0, 0],
5293*da0073e9SAndroid Build Coastguard Worker                        [1, 0],
5294*da0073e9SAndroid Build Coastguard Worker                        [1, 1],
5295*da0073e9SAndroid Build Coastguard Worker                        [fill_value, fill_value],
5296*da0073e9SAndroid Build Coastguard Worker                        [fill_value, fill_value],
5297*da0073e9SAndroid Build Coastguard Worker                    ]
5298*da0073e9SAndroid Build Coastguard Worker                ),
5299*da0073e9SAndroid Build Coastguard Worker            )
5300*da0073e9SAndroid Build Coastguard Worker        )
5301*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]])
5302*da0073e9SAndroid Build Coastguard Worker        static_size = 2
5303*da0073e9SAndroid Build Coastguard Worker        fill_value = -100
5304*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5305*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.utils.same(
5306*da0073e9SAndroid Build Coastguard Worker                torch.nonzero_static(
5307*da0073e9SAndroid Build Coastguard Worker                    input_tensor, size=static_size, fill_value=fill_value
5308*da0073e9SAndroid Build Coastguard Worker                ),
5309*da0073e9SAndroid Build Coastguard Worker                torch.tensor([[0, 0], [1, 0]]),
5310*da0073e9SAndroid Build Coastguard Worker            )
5311*da0073e9SAndroid Build Coastguard Worker        )
5312*da0073e9SAndroid Build Coastguard Worker
5313*da0073e9SAndroid Build Coastguard Worker        # 3D input
5314*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.tensor([[[0, 0], [0, -3]], [[0, 0], [5, 0]]])
5315*da0073e9SAndroid Build Coastguard Worker        static_size = 4
5316*da0073e9SAndroid Build Coastguard Worker        fill_value = -999
5317*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5318*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.utils.same(
5319*da0073e9SAndroid Build Coastguard Worker                torch.nonzero_static(
5320*da0073e9SAndroid Build Coastguard Worker                    input_tensor,
5321*da0073e9SAndroid Build Coastguard Worker                    size=static_size,
5322*da0073e9SAndroid Build Coastguard Worker                    fill_value=fill_value,
5323*da0073e9SAndroid Build Coastguard Worker                ),
5324*da0073e9SAndroid Build Coastguard Worker                torch.tensor(
5325*da0073e9SAndroid Build Coastguard Worker                    [
5326*da0073e9SAndroid Build Coastguard Worker                        [0, 1, 1],
5327*da0073e9SAndroid Build Coastguard Worker                        [1, 1, 0],
5328*da0073e9SAndroid Build Coastguard Worker                        [fill_value, fill_value, fill_value],
5329*da0073e9SAndroid Build Coastguard Worker                        [fill_value, fill_value, fill_value],
5330*da0073e9SAndroid Build Coastguard Worker                    ]
5331*da0073e9SAndroid Build Coastguard Worker                ),
5332*da0073e9SAndroid Build Coastguard Worker            )
5333*da0073e9SAndroid Build Coastguard Worker        )
5334*da0073e9SAndroid Build Coastguard Worker
5335*da0073e9SAndroid Build Coastguard Worker    def test_cond_with_quantization(self):
5336*da0073e9SAndroid Build Coastguard Worker        from functorch.experimental.control_flow import cond
5337*da0073e9SAndroid Build Coastguard Worker
5338*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
5339*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
5340*da0073e9SAndroid Build Coastguard Worker                super().__init__()
5341*da0073e9SAndroid Build Coastguard Worker                example_inputs = (torch.randn(5, 5),)
5342*da0073e9SAndroid Build Coastguard Worker                self.model = torch.nn.Linear(5, 5)
5343*da0073e9SAndroid Build Coastguard Worker                self.quantized_model = prepare_qat_fx(
5344*da0073e9SAndroid Build Coastguard Worker                    self.model, qconfig_dict, example_inputs=example_inputs
5345*da0073e9SAndroid Build Coastguard Worker                )
5346*da0073e9SAndroid Build Coastguard Worker
5347*da0073e9SAndroid Build Coastguard Worker            def forward(self, pred, x):
5348*da0073e9SAndroid Build Coastguard Worker                def true_fn(x):
5349*da0073e9SAndroid Build Coastguard Worker                    return x.sin() + self.quantized_model(x)
5350*da0073e9SAndroid Build Coastguard Worker
5351*da0073e9SAndroid Build Coastguard Worker                def false_fn(x):
5352*da0073e9SAndroid Build Coastguard Worker                    return x.cos() + self.model(x)
5353*da0073e9SAndroid Build Coastguard Worker
5354*da0073e9SAndroid Build Coastguard Worker                return cond(pred, true_fn, false_fn, [x])
5355*da0073e9SAndroid Build Coastguard Worker
5356*da0073e9SAndroid Build Coastguard Worker        module = MyModule()
5357*da0073e9SAndroid Build Coastguard Worker        opt_m = torch._dynamo.optimize("eager", nopython=True)(module)
5358*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((5, 5))
5359*da0073e9SAndroid Build Coastguard Worker        pred = torch.tensor(True)
5360*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(module(pred, x), opt_m(pred, x)))
5361*da0073e9SAndroid Build Coastguard Worker        pred = torch.tensor(False)
5362*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(module(pred, x), opt_m(pred, x)))
5363*da0073e9SAndroid Build Coastguard Worker
5364*da0073e9SAndroid Build Coastguard Worker    def test_map_with_quantization(self):
5365*da0073e9SAndroid Build Coastguard Worker        from functorch.experimental.control_flow import map
5366*da0073e9SAndroid Build Coastguard Worker
5367*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
5368*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
5369*da0073e9SAndroid Build Coastguard Worker                super().__init__()
5370*da0073e9SAndroid Build Coastguard Worker                example_inputs = (torch.randn(5, 5),)
5371*da0073e9SAndroid Build Coastguard Worker                self.model = torch.nn.Linear(5, 5)
5372*da0073e9SAndroid Build Coastguard Worker                self.quantized_model = prepare_qat_fx(
5373*da0073e9SAndroid Build Coastguard Worker                    self.model, qconfig_dict, example_inputs=example_inputs
5374*da0073e9SAndroid Build Coastguard Worker                )
5375*da0073e9SAndroid Build Coastguard Worker
5376*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
5377*da0073e9SAndroid Build Coastguard Worker                def body(x):
5378*da0073e9SAndroid Build Coastguard Worker                    return x.sin() + self.quantized_model(x)
5379*da0073e9SAndroid Build Coastguard Worker
5380*da0073e9SAndroid Build Coastguard Worker                return map(body, x)
5381*da0073e9SAndroid Build Coastguard Worker
5382*da0073e9SAndroid Build Coastguard Worker        module = MyModule()
5383*da0073e9SAndroid Build Coastguard Worker        opt_m = torch._dynamo.optimize("eager", nopython=True)(module)
5384*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((5, 5))
5385*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(module(x), opt_m(x)))
5386*da0073e9SAndroid Build Coastguard Worker
5387*da0073e9SAndroid Build Coastguard Worker    def test_cond_side_effects(self):
5388*da0073e9SAndroid Build Coastguard Worker        from functorch.experimental.control_flow import cond
5389*da0073e9SAndroid Build Coastguard Worker
5390*da0073e9SAndroid Build Coastguard Worker        c = 0
5391*da0073e9SAndroid Build Coastguard Worker
5392*da0073e9SAndroid Build Coastguard Worker        def true_fn(x):
5393*da0073e9SAndroid Build Coastguard Worker            return x - c
5394*da0073e9SAndroid Build Coastguard Worker
5395*da0073e9SAndroid Build Coastguard Worker        def false_fn(x):
5396*da0073e9SAndroid Build Coastguard Worker            return x + c
5397*da0073e9SAndroid Build Coastguard Worker
5398*da0073e9SAndroid Build Coastguard Worker        def f(pred, x):
5399*da0073e9SAndroid Build Coastguard Worker            nonlocal c
5400*da0073e9SAndroid Build Coastguard Worker            c = 1
5401*da0073e9SAndroid Build Coastguard Worker            return cond(pred, true_fn, false_fn, [x])
5402*da0073e9SAndroid Build Coastguard Worker
5403*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(f)
5404*da0073e9SAndroid Build Coastguard Worker        c = 0
5405*da0073e9SAndroid Build Coastguard Worker        a = opt_fn(torch.tensor(False), torch.tensor([0.25, 0.25]))
5406*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(torch.tensor([1.25, 1.25]), a))
5407*da0073e9SAndroid Build Coastguard Worker
5408*da0073e9SAndroid Build Coastguard Worker    def test_map_side_effects(self):
5409*da0073e9SAndroid Build Coastguard Worker        from functorch.experimental.control_flow import map
5410*da0073e9SAndroid Build Coastguard Worker
5411*da0073e9SAndroid Build Coastguard Worker        class Module(torch.nn.Module):
5412*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
5413*da0073e9SAndroid Build Coastguard Worker                super().__init__()
5414*da0073e9SAndroid Build Coastguard Worker                self.w = torch.tensor(1)
5415*da0073e9SAndroid Build Coastguard Worker
5416*da0073e9SAndroid Build Coastguard Worker            def forward(self, xs):
5417*da0073e9SAndroid Build Coastguard Worker                def body(x):
5418*da0073e9SAndroid Build Coastguard Worker                    self.w += 1
5419*da0073e9SAndroid Build Coastguard Worker                    return x
5420*da0073e9SAndroid Build Coastguard Worker
5421*da0073e9SAndroid Build Coastguard Worker                return map(body, xs)
5422*da0073e9SAndroid Build Coastguard Worker
5423*da0073e9SAndroid Build Coastguard Worker        mod = Module()
5424*da0073e9SAndroid Build Coastguard Worker
5425*da0073e9SAndroid Build Coastguard Worker        error_message = ""
5426*da0073e9SAndroid Build Coastguard Worker        if torch._dynamo.config.inline_inbuilt_nn_modules:
5427*da0073e9SAndroid Build Coastguard Worker            error_message = r"HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)"
5428*da0073e9SAndroid Build Coastguard Worker        else:
5429*da0073e9SAndroid Build Coastguard Worker            error_message = "Can't inplace modify module params/buffers"
5430*da0073e9SAndroid Build Coastguard Worker
5431*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Unsupported, error_message):
5432*da0073e9SAndroid Build Coastguard Worker            opt_fn = torch._dynamo.optimize("eager", nopython=True)(mod)
5433*da0073e9SAndroid Build Coastguard Worker            opt_fn(torch.randn(3, 2))
5434*da0073e9SAndroid Build Coastguard Worker
5435*da0073e9SAndroid Build Coastguard Worker    def test_cond_nested(self):
5436*da0073e9SAndroid Build Coastguard Worker        from functorch.experimental.control_flow import cond
5437*da0073e9SAndroid Build Coastguard Worker
5438*da0073e9SAndroid Build Coastguard Worker        def true_fn_nested(x):
5439*da0073e9SAndroid Build Coastguard Worker            return x * 10
5440*da0073e9SAndroid Build Coastguard Worker
5441*da0073e9SAndroid Build Coastguard Worker        def false_fn_nested(x):
5442*da0073e9SAndroid Build Coastguard Worker            return x * -1
5443*da0073e9SAndroid Build Coastguard Worker
5444*da0073e9SAndroid Build Coastguard Worker        def true_fn(pred2, x):
5445*da0073e9SAndroid Build Coastguard Worker            return x.sin()
5446*da0073e9SAndroid Build Coastguard Worker
5447*da0073e9SAndroid Build Coastguard Worker        def false_fn(pred2, x):
5448*da0073e9SAndroid Build Coastguard Worker            return x + cond(pred2, true_fn_nested, false_fn_nested, [x])
5449*da0073e9SAndroid Build Coastguard Worker
5450*da0073e9SAndroid Build Coastguard Worker        def f(pred, pred2, x):
5451*da0073e9SAndroid Build Coastguard Worker            return cond(pred, true_fn, false_fn, [pred2, x])
5452*da0073e9SAndroid Build Coastguard Worker
5453*da0073e9SAndroid Build Coastguard Worker        cc = torch._dynamo.testing.CompileCounter()
5454*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cc)(f)
5455*da0073e9SAndroid Build Coastguard Worker        true_true_sin = opt_fn(
5456*da0073e9SAndroid Build Coastguard Worker            torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25])
5457*da0073e9SAndroid Build Coastguard Worker        )
5458*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin))
5459*da0073e9SAndroid Build Coastguard Worker
5460*da0073e9SAndroid Build Coastguard Worker        true_false_sin = opt_fn(
5461*da0073e9SAndroid Build Coastguard Worker            torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25])
5462*da0073e9SAndroid Build Coastguard Worker        )
5463*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin))
5464*da0073e9SAndroid Build Coastguard Worker
5465*da0073e9SAndroid Build Coastguard Worker        false_true_sum_mult = opt_fn(
5466*da0073e9SAndroid Build Coastguard Worker            torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25])
5467*da0073e9SAndroid Build Coastguard Worker        )
5468*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5469*da0073e9SAndroid Build Coastguard Worker            same(torch.tensor([2.75, 2.75]), false_true_sum_mult)
5470*da0073e9SAndroid Build Coastguard Worker        )  # * 10 then add x
5471*da0073e9SAndroid Build Coastguard Worker
5472*da0073e9SAndroid Build Coastguard Worker        false_false_sum_neg = opt_fn(
5473*da0073e9SAndroid Build Coastguard Worker            torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25])
5474*da0073e9SAndroid Build Coastguard Worker        )
5475*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5476*da0073e9SAndroid Build Coastguard Worker            same(torch.tensor([0.0, 0.0]), false_false_sum_neg)
5477*da0073e9SAndroid Build Coastguard Worker        )  # * -1 then add x
5478*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(cc.frame_count, 2)
5479*da0073e9SAndroid Build Coastguard Worker
5480*da0073e9SAndroid Build Coastguard Worker    def test_cond_export(self):
5481*da0073e9SAndroid Build Coastguard Worker        from functorch.experimental.control_flow import cond
5482*da0073e9SAndroid Build Coastguard Worker
5483*da0073e9SAndroid Build Coastguard Worker        def true_fn_nested(x):
5484*da0073e9SAndroid Build Coastguard Worker            return x * 10
5485*da0073e9SAndroid Build Coastguard Worker
5486*da0073e9SAndroid Build Coastguard Worker        def false_fn_nested(x):
5487*da0073e9SAndroid Build Coastguard Worker            return x * -1
5488*da0073e9SAndroid Build Coastguard Worker
5489*da0073e9SAndroid Build Coastguard Worker        def true_fn(pred2, x):
5490*da0073e9SAndroid Build Coastguard Worker            return x.sin()
5491*da0073e9SAndroid Build Coastguard Worker
5492*da0073e9SAndroid Build Coastguard Worker        def false_fn(pred2, x):
5493*da0073e9SAndroid Build Coastguard Worker            return x + cond(pred2, true_fn_nested, false_fn_nested, [x])
5494*da0073e9SAndroid Build Coastguard Worker
5495*da0073e9SAndroid Build Coastguard Worker        def f(pred, pred2, x):
5496*da0073e9SAndroid Build Coastguard Worker            return cond(pred, true_fn, false_fn, [pred2, x])
5497*da0073e9SAndroid Build Coastguard Worker
5498*da0073e9SAndroid Build Coastguard Worker        graph, guard = torch._dynamo.export(f)(
5499*da0073e9SAndroid Build Coastguard Worker            torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25])
5500*da0073e9SAndroid Build Coastguard Worker        )
5501*da0073e9SAndroid Build Coastguard Worker        true_true_sin = graph(
5502*da0073e9SAndroid Build Coastguard Worker            torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25])
5503*da0073e9SAndroid Build Coastguard Worker        )
5504*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin))
5505*da0073e9SAndroid Build Coastguard Worker
5506*da0073e9SAndroid Build Coastguard Worker        true_false_sin = graph(
5507*da0073e9SAndroid Build Coastguard Worker            torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25])
5508*da0073e9SAndroid Build Coastguard Worker        )
5509*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin))
5510*da0073e9SAndroid Build Coastguard Worker
5511*da0073e9SAndroid Build Coastguard Worker        false_true_sum_mult = graph(
5512*da0073e9SAndroid Build Coastguard Worker            torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25])
5513*da0073e9SAndroid Build Coastguard Worker        )
5514*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5515*da0073e9SAndroid Build Coastguard Worker            same(torch.tensor([2.75, 2.75]), false_true_sum_mult)
5516*da0073e9SAndroid Build Coastguard Worker        )  # * 10 then add x
5517*da0073e9SAndroid Build Coastguard Worker
5518*da0073e9SAndroid Build Coastguard Worker        false_false_sum_neg = graph(
5519*da0073e9SAndroid Build Coastguard Worker            torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25])
5520*da0073e9SAndroid Build Coastguard Worker        )
5521*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
5522*da0073e9SAndroid Build Coastguard Worker            same(torch.tensor([0.0, 0.0]), false_false_sum_neg)
5523*da0073e9SAndroid Build Coastguard Worker        )  # * -1 then add x
5524*da0073e9SAndroid Build Coastguard Worker
5525*da0073e9SAndroid Build Coastguard Worker    def test_cond_export_single_arg(self):
5526*da0073e9SAndroid Build Coastguard Worker        from functorch.experimental.control_flow import cond
5527*da0073e9SAndroid Build Coastguard Worker
5528*da0073e9SAndroid Build Coastguard Worker        def true_fn(x):
5529*da0073e9SAndroid Build Coastguard Worker            return x
5530*da0073e9SAndroid Build Coastguard Worker
5531*da0073e9SAndroid Build Coastguard Worker        def false_fn(x):
5532*da0073e9SAndroid Build Coastguard Worker            return x.sin()
5533*da0073e9SAndroid Build Coastguard Worker
5534*da0073e9SAndroid Build Coastguard Worker        def f(pred, x):
5535*da0073e9SAndroid Build Coastguard Worker            return cond(pred, true_fn, false_fn, [x])
5536*da0073e9SAndroid Build Coastguard Worker
5537*da0073e9SAndroid Build Coastguard Worker        graph, guard = torch._dynamo.export(f)(
5538*da0073e9SAndroid Build Coastguard Worker            torch.tensor(False), torch.tensor([0.25, 0.25])
5539*da0073e9SAndroid Build Coastguard Worker        )
5540*da0073e9SAndroid Build Coastguard Worker        true_mirror = graph(torch.tensor(True), torch.tensor([0.25, 0.25]))
5541*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(torch.tensor([0.25, 0.25]), true_mirror))
5542*da0073e9SAndroid Build Coastguard Worker        true_mirror_2 = graph(torch.tensor(True), torch.tensor([0.33, 0.33, 0.33]))
5543*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(torch.tensor([0.33, 0.33, 0.33]), true_mirror_2))
5544*da0073e9SAndroid Build Coastguard Worker
5545*da0073e9SAndroid Build Coastguard Worker        false_sin = graph(torch.tensor(False), torch.tensor([0.5, 0.5]))
5546*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(torch.sin(torch.tensor([0.5, 0.5])), false_sin))
5547*da0073e9SAndroid Build Coastguard Worker
5548*da0073e9SAndroid Build Coastguard Worker    def test_enum_guards(self):
5549*da0073e9SAndroid Build Coastguard Worker        class MyEnum(enum.Enum):
5550*da0073e9SAndroid Build Coastguard Worker            FOO = 10
5551*da0073e9SAndroid Build Coastguard Worker            BAR = 20
5552*da0073e9SAndroid Build Coastguard Worker
5553*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
5554*da0073e9SAndroid Build Coastguard Worker            if y == MyEnum.FOO:
5555*da0073e9SAndroid Build Coastguard Worker                return x + 1
5556*da0073e9SAndroid Build Coastguard Worker            else:
5557*da0073e9SAndroid Build Coastguard Worker                return x - 1
5558*da0073e9SAndroid Build Coastguard Worker
5559*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3)
5560*da0073e9SAndroid Build Coastguard Worker        y = MyEnum.BAR
5561*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, y)
5562*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(backend="eager")(fn)
5563*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, y)
5564*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
5565*da0073e9SAndroid Build Coastguard Worker
5566*da0073e9SAndroid Build Coastguard Worker    def test_duplicate_graph_break_log(self):
5567*da0073e9SAndroid Build Coastguard Worker        torch._logging.set_logs(graph_breaks=True)
5568*da0073e9SAndroid Build Coastguard Worker
5569*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize("eager")
5570*da0073e9SAndroid Build Coastguard Worker        def f1(a, b):
5571*da0073e9SAndroid Build Coastguard Worker            f2(a, b)
5572*da0073e9SAndroid Build Coastguard Worker
5573*da0073e9SAndroid Build Coastguard Worker        def f2(a, b):
5574*da0073e9SAndroid Build Coastguard Worker            c = a + b
5575*da0073e9SAndroid Build Coastguard Worker            print("break")
5576*da0073e9SAndroid Build Coastguard Worker            return a + b + c
5577*da0073e9SAndroid Build Coastguard Worker
5578*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize("eager")
5579*da0073e9SAndroid Build Coastguard Worker        def g1(a, b):
5580*da0073e9SAndroid Build Coastguard Worker            g2(a, b)
5581*da0073e9SAndroid Build Coastguard Worker
5582*da0073e9SAndroid Build Coastguard Worker        def g2(a, b):
5583*da0073e9SAndroid Build Coastguard Worker            c = a + b
5584*da0073e9SAndroid Build Coastguard Worker            print("break")
5585*da0073e9SAndroid Build Coastguard Worker            return a + b + c
5586*da0073e9SAndroid Build Coastguard Worker
5587*da0073e9SAndroid Build Coastguard Worker        def count_graph_break_msgs(msgs):
5588*da0073e9SAndroid Build Coastguard Worker            return sum(msg.find("Graph break") != -1 for msg in msgs)
5589*da0073e9SAndroid Build Coastguard Worker
5590*da0073e9SAndroid Build Coastguard Worker        with self.assertLogs(
5591*da0073e9SAndroid Build Coastguard Worker            logger="torch._dynamo", level=logging.DEBUG
5592*da0073e9SAndroid Build Coastguard Worker        ) as log, torch._dynamo.config.patch(verbose=True):
5593*da0073e9SAndroid Build Coastguard Worker            f1(torch.randn(10), torch.randn(10))
5594*da0073e9SAndroid Build Coastguard Worker            self.assertGreater(count_graph_break_msgs(log.output), 1)
5595*da0073e9SAndroid Build Coastguard Worker
5596*da0073e9SAndroid Build Coastguard Worker        with self.assertLogs(
5597*da0073e9SAndroid Build Coastguard Worker            logger="torch._dynamo", level=logging.DEBUG
5598*da0073e9SAndroid Build Coastguard Worker        ) as log, torch._dynamo.config.patch(verbose=False):
5599*da0073e9SAndroid Build Coastguard Worker            g1(torch.randn(10), torch.randn(10))
5600*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(count_graph_break_msgs(log.output), 1)
5601*da0073e9SAndroid Build Coastguard Worker
5602*da0073e9SAndroid Build Coastguard Worker        # reset logging state
5603*da0073e9SAndroid Build Coastguard Worker        torch._logging.set_logs()
5604*da0073e9SAndroid Build Coastguard Worker
5605*da0073e9SAndroid Build Coastguard Worker    def test_inplace_param_update(self):
5606*da0073e9SAndroid Build Coastguard Worker        def fn(param, y):
5607*da0073e9SAndroid Build Coastguard Worker            prev_grad = torch.is_grad_enabled()
5608*da0073e9SAndroid Build Coastguard Worker            try:
5609*da0073e9SAndroid Build Coastguard Worker                torch.set_grad_enabled(False)
5610*da0073e9SAndroid Build Coastguard Worker                torch.set_grad_enabled(True)
5611*da0073e9SAndroid Build Coastguard Worker                torch.set_grad_enabled(False)
5612*da0073e9SAndroid Build Coastguard Worker                param.add_(y)
5613*da0073e9SAndroid Build Coastguard Worker            finally:
5614*da0073e9SAndroid Build Coastguard Worker                torch.set_grad_enabled(prev_grad)
5615*da0073e9SAndroid Build Coastguard Worker
5616*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(4)
5617*da0073e9SAndroid Build Coastguard Worker        x = torch.nn.Parameter(torch.randn(4))
5618*da0073e9SAndroid Build Coastguard Worker        fn(x, y)
5619*da0073e9SAndroid Build Coastguard Worker
5620*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
5621*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
5622*da0073e9SAndroid Build Coastguard Worker        opt_fn(x, y)
5623*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
5624*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 3)
5625*da0073e9SAndroid Build Coastguard Worker
5626*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
5627*da0073e9SAndroid Build Coastguard Worker        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
5628*da0073e9SAndroid Build Coastguard Worker        "Can't run fused SDPA on this platform",
5629*da0073e9SAndroid Build Coastguard Worker    )
5630*da0073e9SAndroid Build Coastguard Worker    def test_parsing_sdpa(self):
5631*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
5632*da0073e9SAndroid Build Coastguard Worker            def forward(self, query, key, value):
5633*da0073e9SAndroid Build Coastguard Worker                out = F.scaled_dot_product_attention(query, key, value, None, 0, True)
5634*da0073e9SAndroid Build Coastguard Worker                out = F.scaled_dot_product_attention(
5635*da0073e9SAndroid Build Coastguard Worker                    query, key, value, None, 0, True, scale=8
5636*da0073e9SAndroid Build Coastguard Worker                )
5637*da0073e9SAndroid Build Coastguard Worker                out = F.scaled_dot_product_attention(
5638*da0073e9SAndroid Build Coastguard Worker                    query=query,
5639*da0073e9SAndroid Build Coastguard Worker                    key=key,
5640*da0073e9SAndroid Build Coastguard Worker                    value=value,
5641*da0073e9SAndroid Build Coastguard Worker                    attn_mask=None,
5642*da0073e9SAndroid Build Coastguard Worker                    dropout_p=0,
5643*da0073e9SAndroid Build Coastguard Worker                    is_causal=True,
5644*da0073e9SAndroid Build Coastguard Worker                )
5645*da0073e9SAndroid Build Coastguard Worker                out = F.scaled_dot_product_attention(
5646*da0073e9SAndroid Build Coastguard Worker                    query,
5647*da0073e9SAndroid Build Coastguard Worker                    key=key,
5648*da0073e9SAndroid Build Coastguard Worker                    value=value,
5649*da0073e9SAndroid Build Coastguard Worker                    attn_mask=None,
5650*da0073e9SAndroid Build Coastguard Worker                    dropout_p=0,
5651*da0073e9SAndroid Build Coastguard Worker                    is_causal=True,
5652*da0073e9SAndroid Build Coastguard Worker                )
5653*da0073e9SAndroid Build Coastguard Worker                out = F.scaled_dot_product_attention(
5654*da0073e9SAndroid Build Coastguard Worker                    query, key, value, None, dropout_p=0, is_causal=True
5655*da0073e9SAndroid Build Coastguard Worker                )
5656*da0073e9SAndroid Build Coastguard Worker                out = F.scaled_dot_product_attention(query, key, value, None, scale=8)
5657*da0073e9SAndroid Build Coastguard Worker                return out
5658*da0073e9SAndroid Build Coastguard Worker
5659*da0073e9SAndroid Build Coastguard Worker        device = "cuda"
5660*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float16
5661*da0073e9SAndroid Build Coastguard Worker        seq_len_q = 1
5662*da0073e9SAndroid Build Coastguard Worker        seq_len_k = 1
5663*da0073e9SAndroid Build Coastguard Worker        head_dim = 8
5664*da0073e9SAndroid Build Coastguard Worker        query = torch.ones(
5665*da0073e9SAndroid Build Coastguard Worker            1, 8, seq_len_q, head_dim, device=device, dtype=dtype, requires_grad=True
5666*da0073e9SAndroid Build Coastguard Worker        )
5667*da0073e9SAndroid Build Coastguard Worker        key = torch.ones(
5668*da0073e9SAndroid Build Coastguard Worker            1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True
5669*da0073e9SAndroid Build Coastguard Worker        )
5670*da0073e9SAndroid Build Coastguard Worker        value = torch.ones(
5671*da0073e9SAndroid Build Coastguard Worker            1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True
5672*da0073e9SAndroid Build Coastguard Worker        )
5673*da0073e9SAndroid Build Coastguard Worker        module = MyModule()
5674*da0073e9SAndroid Build Coastguard Worker        opt_mod = torch._dynamo.optimize("inductor")(module)
5675*da0073e9SAndroid Build Coastguard Worker        opt_mod(query, key, value)
5676*da0073e9SAndroid Build Coastguard Worker
5677*da0073e9SAndroid Build Coastguard Worker    def test_generate_tensor_from_list_of_numpy_primitive_type(self):
5678*da0073e9SAndroid Build Coastguard Worker        # Test sth like torch.LongTensor(list(np.int64, np.int64, ...))
5679*da0073e9SAndroid Build Coastguard Worker        def fn():
5680*da0073e9SAndroid Build Coastguard Worker            x = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64)
5681*da0073e9SAndroid Build Coastguard Worker            y = [x[0], x[2], x[4]]
5682*da0073e9SAndroid Build Coastguard Worker            return torch.LongTensor(y)
5683*da0073e9SAndroid Build Coastguard Worker
5684*da0073e9SAndroid Build Coastguard Worker        ref = fn()
5685*da0073e9SAndroid Build Coastguard Worker        res = torch.compile(fullgraph=True)(fn)()
5686*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
5687*da0073e9SAndroid Build Coastguard Worker
5688*da0073e9SAndroid Build Coastguard Worker    def test_object_classmethod(self):
5689*da0073e9SAndroid Build Coastguard Worker        class C:
5690*da0073e9SAndroid Build Coastguard Worker            @classmethod
5691*da0073e9SAndroid Build Coastguard Worker            def fn(cls, x):
5692*da0073e9SAndroid Build Coastguard Worker                return x + x
5693*da0073e9SAndroid Build Coastguard Worker
5694*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize("eager", nopython=True)
5695*da0073e9SAndroid Build Coastguard Worker        def f():
5696*da0073e9SAndroid Build Coastguard Worker            return C().fn(torch.ones(2, 3))
5697*da0073e9SAndroid Build Coastguard Worker
5698*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(f(), torch.tensor([2.0])))
5699*da0073e9SAndroid Build Coastguard Worker
5700*da0073e9SAndroid Build Coastguard Worker    def test_object_staticmethod(self):
5701*da0073e9SAndroid Build Coastguard Worker        class C:
5702*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5703*da0073e9SAndroid Build Coastguard Worker            def fn(x):
5704*da0073e9SAndroid Build Coastguard Worker                return x + x
5705*da0073e9SAndroid Build Coastguard Worker
5706*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize("eager", nopython=True)
5707*da0073e9SAndroid Build Coastguard Worker        def f():
5708*da0073e9SAndroid Build Coastguard Worker            return C().fn(torch.ones(2, 3))
5709*da0073e9SAndroid Build Coastguard Worker
5710*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(f(), torch.tensor([2.0])))
5711*da0073e9SAndroid Build Coastguard Worker
5712*da0073e9SAndroid Build Coastguard Worker    def test_user_function_variable_supports_enum_argument(self):
5713*da0073e9SAndroid Build Coastguard Worker        class Foo(enum.Enum):
5714*da0073e9SAndroid Build Coastguard Worker            FOO = 0
5715*da0073e9SAndroid Build Coastguard Worker            BAR = 1
5716*da0073e9SAndroid Build Coastguard Worker
5717*da0073e9SAndroid Build Coastguard Worker        def gn(x, y=Foo.FOO):
5718*da0073e9SAndroid Build Coastguard Worker            if y is Foo.FOO:
5719*da0073e9SAndroid Build Coastguard Worker                return x
5720*da0073e9SAndroid Build Coastguard Worker            else:
5721*da0073e9SAndroid Build Coastguard Worker                return x + 1
5722*da0073e9SAndroid Build Coastguard Worker
5723*da0073e9SAndroid Build Coastguard Worker        def fn(x):
5724*da0073e9SAndroid Build Coastguard Worker            return gn(x)
5725*da0073e9SAndroid Build Coastguard Worker
5726*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3)
5727*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
5728*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
5729*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
5730*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
5731*da0073e9SAndroid Build Coastguard Worker
5732*da0073e9SAndroid Build Coastguard Worker    def test_user_function_variable_supports_type_abcmeta_argument(self):
5733*da0073e9SAndroid Build Coastguard Worker        class Foo(metaclass=abc.ABCMeta):
5734*da0073e9SAndroid Build Coastguard Worker            @abc.abstractclassmethod
5735*da0073e9SAndroid Build Coastguard Worker            def read(self):  # noqa: B027
5736*da0073e9SAndroid Build Coastguard Worker                pass
5737*da0073e9SAndroid Build Coastguard Worker
5738*da0073e9SAndroid Build Coastguard Worker        class Bar(Foo):
5739*da0073e9SAndroid Build Coastguard Worker            def read(self):
5740*da0073e9SAndroid Build Coastguard Worker                return "Hello World!"
5741*da0073e9SAndroid Build Coastguard Worker
5742*da0073e9SAndroid Build Coastguard Worker        class Baz:
5743*da0073e9SAndroid Build Coastguard Worker            pass
5744*da0073e9SAndroid Build Coastguard Worker
5745*da0073e9SAndroid Build Coastguard Worker        def gn(x, tys=(Bar, Baz)):
5746*da0073e9SAndroid Build Coastguard Worker            if Bar in tys:
5747*da0073e9SAndroid Build Coastguard Worker                return x - 1
5748*da0073e9SAndroid Build Coastguard Worker            else:
5749*da0073e9SAndroid Build Coastguard Worker                return x + 1
5750*da0073e9SAndroid Build Coastguard Worker
5751*da0073e9SAndroid Build Coastguard Worker        def fn(x):
5752*da0073e9SAndroid Build Coastguard Worker            return gn(x)
5753*da0073e9SAndroid Build Coastguard Worker
5754*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3)
5755*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
5756*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
5757*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
5758*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
5759*da0073e9SAndroid Build Coastguard Worker
5760*da0073e9SAndroid Build Coastguard Worker    def test_user_function_variable_supports_function_argument(self):
5761*da0073e9SAndroid Build Coastguard Worker        # Test user defined function default arguments can be:
5762*da0073e9SAndroid Build Coastguard Worker        # 1, user defined functions (e.g, add1)
5763*da0073e9SAndroid Build Coastguard Worker        # 2, torch functions (e.g, torch.sin)
5764*da0073e9SAndroid Build Coastguard Worker        # 3, python builtin functions (e.g, operator.neg)
5765*da0073e9SAndroid Build Coastguard Worker        def add1(x):
5766*da0073e9SAndroid Build Coastguard Worker            return x + 1
5767*da0073e9SAndroid Build Coastguard Worker
5768*da0073e9SAndroid Build Coastguard Worker        def gn(x, f1=add1, f2=torch.sin, f3=operator.neg):
5769*da0073e9SAndroid Build Coastguard Worker            return f3(f2(f1(x)))
5770*da0073e9SAndroid Build Coastguard Worker
5771*da0073e9SAndroid Build Coastguard Worker        def fn(x):
5772*da0073e9SAndroid Build Coastguard Worker            return gn(x)
5773*da0073e9SAndroid Build Coastguard Worker
5774*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3)
5775*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
5776*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
5777*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
5778*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
5779*da0073e9SAndroid Build Coastguard Worker
5780*da0073e9SAndroid Build Coastguard Worker    def test_typing_variable_isinstance(self):
5781*da0073e9SAndroid Build Coastguard Worker        def fn(x, m):
5782*da0073e9SAndroid Build Coastguard Worker            if isinstance(m, typing.Mapping):
5783*da0073e9SAndroid Build Coastguard Worker                return x + 1
5784*da0073e9SAndroid Build Coastguard Worker            else:
5785*da0073e9SAndroid Build Coastguard Worker                return x - 1
5786*da0073e9SAndroid Build Coastguard Worker
5787*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3)
5788*da0073e9SAndroid Build Coastguard Worker        m = {"x": torch.randn(3)}
5789*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, m)
5790*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
5791*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, m)
5792*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref, res))
5793*da0073e9SAndroid Build Coastguard Worker
5794*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(guard_nn_modules=True)
5795*da0073e9SAndroid Build Coastguard Worker    def test_repro_graph_breaks_in__get_item_by_idx(self):
5796*da0073e9SAndroid Build Coastguard Worker        class Mod(torch.nn.Module):
5797*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
5798*da0073e9SAndroid Build Coastguard Worker                super().__init__()
5799*da0073e9SAndroid Build Coastguard Worker                self.mod = torch.nn.Sequential(
5800*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Linear(3, 3), torch.nn.Linear(3, 3)
5801*da0073e9SAndroid Build Coastguard Worker                )
5802*da0073e9SAndroid Build Coastguard Worker
5803*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
5804*da0073e9SAndroid Build Coastguard Worker                return self.mod[0](x)
5805*da0073e9SAndroid Build Coastguard Worker
5806*da0073e9SAndroid Build Coastguard Worker        m = Mod()
5807*da0073e9SAndroid Build Coastguard Worker        graph, _ = torch._dynamo.export(m)(torch.randn(3, 3))
5808*da0073e9SAndroid Build Coastguard Worker
5809*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(guard_nn_modules=True)
5810*da0073e9SAndroid Build Coastguard Worker    def test_nn_sequential_invocation(self):
5811*da0073e9SAndroid Build Coastguard Worker        with freeze_rng_state():
5812*da0073e9SAndroid Build Coastguard Worker
5813*da0073e9SAndroid Build Coastguard Worker            class TestModel(torch.nn.Module):
5814*da0073e9SAndroid Build Coastguard Worker                def __init__(self) -> None:
5815*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
5816*da0073e9SAndroid Build Coastguard Worker                    self.linears = torch.nn.Sequential(
5817*da0073e9SAndroid Build Coastguard Worker                        torch.nn.Linear(2, 2),
5818*da0073e9SAndroid Build Coastguard Worker                        torch.nn.Linear(2, 2),
5819*da0073e9SAndroid Build Coastguard Worker                        torch.nn.Linear(2, 2),
5820*da0073e9SAndroid Build Coastguard Worker                        torch.nn.Linear(2, 2),
5821*da0073e9SAndroid Build Coastguard Worker                    )
5822*da0073e9SAndroid Build Coastguard Worker
5823*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
5824*da0073e9SAndroid Build Coastguard Worker                    all_but_last = self.linears[:-1]
5825*da0073e9SAndroid Build Coastguard Worker                    return all_but_last(x)
5826*da0073e9SAndroid Build Coastguard Worker
5827*da0073e9SAndroid Build Coastguard Worker            m = TestModel()
5828*da0073e9SAndroid Build Coastguard Worker            x = torch.rand((2, 2))
5829*da0073e9SAndroid Build Coastguard Worker            real = m(x)
5830*da0073e9SAndroid Build Coastguard Worker            graph, _ = torch._dynamo.export(m)(x)
5831*da0073e9SAndroid Build Coastguard Worker            dynamo_result = graph(x)
5832*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(real, dynamo_result))
5833*da0073e9SAndroid Build Coastguard Worker
5834*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(guard_nn_modules=True)
5835*da0073e9SAndroid Build Coastguard Worker    def test_nn_sequential_invocation_reposition_indices(self):
5836*da0073e9SAndroid Build Coastguard Worker        with freeze_rng_state():
5837*da0073e9SAndroid Build Coastguard Worker
5838*da0073e9SAndroid Build Coastguard Worker            class TestModel(torch.nn.Module):
5839*da0073e9SAndroid Build Coastguard Worker                def __init__(self) -> None:
5840*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
5841*da0073e9SAndroid Build Coastguard Worker                    self.linears = torch.nn.Sequential(
5842*da0073e9SAndroid Build Coastguard Worker                        torch.nn.Linear(2, 2),
5843*da0073e9SAndroid Build Coastguard Worker                        torch.nn.Linear(2, 2),
5844*da0073e9SAndroid Build Coastguard Worker                        torch.nn.Linear(2, 2),
5845*da0073e9SAndroid Build Coastguard Worker                        torch.nn.Linear(2, 2),
5846*da0073e9SAndroid Build Coastguard Worker                    )
5847*da0073e9SAndroid Build Coastguard Worker
5848*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
5849*da0073e9SAndroid Build Coastguard Worker                    all_but_last = self.linears[1:3]
5850*da0073e9SAndroid Build Coastguard Worker                    return all_but_last(x)
5851*da0073e9SAndroid Build Coastguard Worker
5852*da0073e9SAndroid Build Coastguard Worker            m = TestModel()
5853*da0073e9SAndroid Build Coastguard Worker            x = torch.rand((2, 2))
5854*da0073e9SAndroid Build Coastguard Worker            real = m(x)
5855*da0073e9SAndroid Build Coastguard Worker            graph, _ = torch._dynamo.export(m)(x)
5856*da0073e9SAndroid Build Coastguard Worker            dynamo_result = graph(x)
5857*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(real, dynamo_result))
5858*da0073e9SAndroid Build Coastguard Worker
5859*da0073e9SAndroid Build Coastguard Worker    def test_error_on_nested_fx_trace(self):
5860*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(2, 3)
5861*da0073e9SAndroid Build Coastguard Worker
5862*da0073e9SAndroid Build Coastguard Worker        def f(x):
5863*da0073e9SAndroid Build Coastguard Worker            x + x
5864*da0073e9SAndroid Build Coastguard Worker
5865*da0073e9SAndroid Build Coastguard Worker        real = f(input)
5866*da0073e9SAndroid Build Coastguard Worker
5867*da0073e9SAndroid Build Coastguard Worker        optimized = torch._dynamo.optimize("eager")(f)
5868*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(optimized(input), real))
5869*da0073e9SAndroid Build Coastguard Worker
5870*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Detected that you are using FX"):
5871*da0073e9SAndroid Build Coastguard Worker            gm = torch.fx.symbolic_trace(optimized)
5872*da0073e9SAndroid Build Coastguard Worker
5873*da0073e9SAndroid Build Coastguard Worker    @patch.object(torch._dynamo.config, "error_on_nested_fx_trace", False)
5874*da0073e9SAndroid Build Coastguard Worker    def test_no_error_on_nested_fx_trace(self):
5875*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(2, 3)
5876*da0073e9SAndroid Build Coastguard Worker
5877*da0073e9SAndroid Build Coastguard Worker        def f(x):
5878*da0073e9SAndroid Build Coastguard Worker            x + x
5879*da0073e9SAndroid Build Coastguard Worker
5880*da0073e9SAndroid Build Coastguard Worker        real = f(input)
5881*da0073e9SAndroid Build Coastguard Worker
5882*da0073e9SAndroid Build Coastguard Worker        optimized = torch._dynamo.optimize("eager")(f)
5883*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(optimized(input), real))
5884*da0073e9SAndroid Build Coastguard Worker
5885*da0073e9SAndroid Build Coastguard Worker        # should not error
5886*da0073e9SAndroid Build Coastguard Worker        gm = torch.fx.symbolic_trace(optimized)
5887*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(gm(input), real))
5888*da0073e9SAndroid Build Coastguard Worker
5889*da0073e9SAndroid Build Coastguard Worker    def test_not_dynamic_scope(self):
5890*da0073e9SAndroid Build Coastguard Worker        def f(y):
5891*da0073e9SAndroid Build Coastguard Worker            x = 1
5892*da0073e9SAndroid Build Coastguard Worker
5893*da0073e9SAndroid Build Coastguard Worker            def g():
5894*da0073e9SAndroid Build Coastguard Worker                x = 2
5895*da0073e9SAndroid Build Coastguard Worker                return lambda: x
5896*da0073e9SAndroid Build Coastguard Worker
5897*da0073e9SAndroid Build Coastguard Worker            return y + g()()
5898*da0073e9SAndroid Build Coastguard Worker
5899*da0073e9SAndroid Build Coastguard Worker        input = torch.zeros(1)
5900*da0073e9SAndroid Build Coastguard Worker        real = f(input)
5901*da0073e9SAndroid Build Coastguard Worker        optimized = torch._dynamo.optimize("eager")(f)
5902*da0073e9SAndroid Build Coastguard Worker        opt = optimized(input)
5903*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(opt, real))
5904*da0073e9SAndroid Build Coastguard Worker
5905*da0073e9SAndroid Build Coastguard Worker    def test_inference_mode(self):
5906*da0073e9SAndroid Build Coastguard Worker        @torch.inference_mode()
5907*da0073e9SAndroid Build Coastguard Worker        def func(x, y):
5908*da0073e9SAndroid Build Coastguard Worker            return x.add(1.0) + y
5909*da0073e9SAndroid Build Coastguard Worker
5910*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(4, requires_grad=True)
5911*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(4, requires_grad=True)
5912*da0073e9SAndroid Build Coastguard Worker        ref = func(x, y)
5913*da0073e9SAndroid Build Coastguard Worker        opt_func = torch._dynamo.optimize("eager")(func)
5914*da0073e9SAndroid Build Coastguard Worker
5915*da0073e9SAndroid Build Coastguard Worker        x1 = torch.ones(4, requires_grad=True)
5916*da0073e9SAndroid Build Coastguard Worker        res = opt_func(x1, y)
5917*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
5918*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(x, x1))
5919*da0073e9SAndroid Build Coastguard Worker
5920*da0073e9SAndroid Build Coastguard Worker    def test_if_cond_nn_mod1(self):
5921*da0073e9SAndroid Build Coastguard Worker        class MockModule(torch.nn.Module):
5922*da0073e9SAndroid Build Coastguard Worker            def __init__(self, output_relu=True):
5923*da0073e9SAndroid Build Coastguard Worker                super().__init__()
5924*da0073e9SAndroid Build Coastguard Worker                self.relu = torch.nn.ReLU() if output_relu else None
5925*da0073e9SAndroid Build Coastguard Worker
5926*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
5927*da0073e9SAndroid Build Coastguard Worker                x = torch.sin(x)
5928*da0073e9SAndroid Build Coastguard Worker                if self.relu:
5929*da0073e9SAndroid Build Coastguard Worker                    x = self.relu(x)
5930*da0073e9SAndroid Build Coastguard Worker                return x
5931*da0073e9SAndroid Build Coastguard Worker
5932*da0073e9SAndroid Build Coastguard Worker        model = MockModule()
5933*da0073e9SAndroid Build Coastguard Worker        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
5934*da0073e9SAndroid Build Coastguard Worker
5935*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(4)
5936*da0073e9SAndroid Build Coastguard Worker        ref = model(x)
5937*da0073e9SAndroid Build Coastguard Worker        res = opt_model(x)
5938*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
5939*da0073e9SAndroid Build Coastguard Worker
5940*da0073e9SAndroid Build Coastguard Worker        model = MockModule(output_relu=False)
5941*da0073e9SAndroid Build Coastguard Worker        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
5942*da0073e9SAndroid Build Coastguard Worker
5943*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(4)
5944*da0073e9SAndroid Build Coastguard Worker        ref = model(x)
5945*da0073e9SAndroid Build Coastguard Worker        res = opt_model(x)
5946*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
5947*da0073e9SAndroid Build Coastguard Worker
5948*da0073e9SAndroid Build Coastguard Worker    def test_if_cond_nn_mod2(self):
5949*da0073e9SAndroid Build Coastguard Worker        class MockModule(torch.nn.Module):
5950*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
5951*da0073e9SAndroid Build Coastguard Worker                super().__init__()
5952*da0073e9SAndroid Build Coastguard Worker                self.layer = torch.nn.Sequential()
5953*da0073e9SAndroid Build Coastguard Worker
5954*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
5955*da0073e9SAndroid Build Coastguard Worker                if self.layer:
5956*da0073e9SAndroid Build Coastguard Worker                    return x + 1
5957*da0073e9SAndroid Build Coastguard Worker                else:
5958*da0073e9SAndroid Build Coastguard Worker                    return x - 1
5959*da0073e9SAndroid Build Coastguard Worker
5960*da0073e9SAndroid Build Coastguard Worker        model = MockModule()
5961*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(4)
5962*da0073e9SAndroid Build Coastguard Worker        ref = model(x)
5963*da0073e9SAndroid Build Coastguard Worker        opt_model = torch.compile(backend="eager")(model)
5964*da0073e9SAndroid Build Coastguard Worker        res = opt_model(x)
5965*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
5966*da0073e9SAndroid Build Coastguard Worker
5967*da0073e9SAndroid Build Coastguard Worker    def test_if_cond_nn_mod3(self):
5968*da0073e9SAndroid Build Coastguard Worker        def fn(x):
5969*da0073e9SAndroid Build Coastguard Worker            if torch.nn.ModuleList():
5970*da0073e9SAndroid Build Coastguard Worker                return x + 1
5971*da0073e9SAndroid Build Coastguard Worker            else:
5972*da0073e9SAndroid Build Coastguard Worker                return x - 1
5973*da0073e9SAndroid Build Coastguard Worker
5974*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(4)
5975*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
5976*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(backend="eager")(fn)
5977*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
5978*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
5979*da0073e9SAndroid Build Coastguard Worker
5980*da0073e9SAndroid Build Coastguard Worker    def test_if_cond_user_defined_object(self):
5981*da0073e9SAndroid Build Coastguard Worker        # obj.__bool__ is not existed
5982*da0073e9SAndroid Build Coastguard Worker        class A:  # noqa: B903
5983*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x):
5984*da0073e9SAndroid Build Coastguard Worker                self.x = x
5985*da0073e9SAndroid Build Coastguard Worker
5986*da0073e9SAndroid Build Coastguard Worker        # obj.__bool__ is function and returns bool type
5987*da0073e9SAndroid Build Coastguard Worker        class B:
5988*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x):
5989*da0073e9SAndroid Build Coastguard Worker                self.x = x
5990*da0073e9SAndroid Build Coastguard Worker
5991*da0073e9SAndroid Build Coastguard Worker            def __bool__(self):
5992*da0073e9SAndroid Build Coastguard Worker                return self.x > 0
5993*da0073e9SAndroid Build Coastguard Worker
5994*da0073e9SAndroid Build Coastguard Worker        # obj.__bool__ is non-function
5995*da0073e9SAndroid Build Coastguard Worker        class C:
5996*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x):
5997*da0073e9SAndroid Build Coastguard Worker                self.x = x
5998*da0073e9SAndroid Build Coastguard Worker                self.__bool__ = False
5999*da0073e9SAndroid Build Coastguard Worker
6000*da0073e9SAndroid Build Coastguard Worker        def fn(x, obj):
6001*da0073e9SAndroid Build Coastguard Worker            if not obj:
6002*da0073e9SAndroid Build Coastguard Worker                return x + 1
6003*da0073e9SAndroid Build Coastguard Worker            else:
6004*da0073e9SAndroid Build Coastguard Worker                return x - 1
6005*da0073e9SAndroid Build Coastguard Worker
6006*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(4)
6007*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
6008*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
6009*da0073e9SAndroid Build Coastguard Worker        obj1 = A(0.5)
6010*da0073e9SAndroid Build Coastguard Worker        obj2 = B(0.5)
6011*da0073e9SAndroid Build Coastguard Worker        obj3 = B(-0.5)
6012*da0073e9SAndroid Build Coastguard Worker        obj4 = C(0.5)
6013*da0073e9SAndroid Build Coastguard Worker        for obj in [obj1, obj2, obj3, obj4, obj3, obj2]:
6014*da0073e9SAndroid Build Coastguard Worker            ref = fn(x, obj)
6015*da0073e9SAndroid Build Coastguard Worker            res = opt_fn(x, obj)
6016*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(ref, res))
6017*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 4)
6018*da0073e9SAndroid Build Coastguard Worker
6019*da0073e9SAndroid Build Coastguard Worker    def test_if_cond_user_defined_object2(self):
6020*da0073e9SAndroid Build Coastguard Worker        # obj.__bool__ is function and returns non-bool type
6021*da0073e9SAndroid Build Coastguard Worker        class MyObj:
6022*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x):
6023*da0073e9SAndroid Build Coastguard Worker                self.x = x
6024*da0073e9SAndroid Build Coastguard Worker
6025*da0073e9SAndroid Build Coastguard Worker            def __bool__(self):
6026*da0073e9SAndroid Build Coastguard Worker                self.x = 1.2
6027*da0073e9SAndroid Build Coastguard Worker                return self.x
6028*da0073e9SAndroid Build Coastguard Worker
6029*da0073e9SAndroid Build Coastguard Worker        def fn(a, obj):
6030*da0073e9SAndroid Build Coastguard Worker            if not obj:
6031*da0073e9SAndroid Build Coastguard Worker                return a + obj.x
6032*da0073e9SAndroid Build Coastguard Worker            else:
6033*da0073e9SAndroid Build Coastguard Worker                return a - obj.x
6034*da0073e9SAndroid Build Coastguard Worker
6035*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(4)
6036*da0073e9SAndroid Build Coastguard Worker        obj = MyObj(0.5)
6037*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
6038*da0073e9SAndroid Build Coastguard Worker        try:
6039*da0073e9SAndroid Build Coastguard Worker            opt_fn(x, obj)
6040*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(True)
6041*da0073e9SAndroid Build Coastguard Worker        except TypeError as e:
6042*da0073e9SAndroid Build Coastguard Worker            self.assertIn("__bool__ should return bool, returned float", str(e))
6043*da0073e9SAndroid Build Coastguard Worker
6044*da0073e9SAndroid Build Coastguard Worker    def test_if_cond_user_defined_object3(self):
6045*da0073e9SAndroid Build Coastguard Worker        # obj.__bool__ is not existed, but obj.__len__ exists
6046*da0073e9SAndroid Build Coastguard Worker        class A:  # noqa: B903
6047*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x):
6048*da0073e9SAndroid Build Coastguard Worker                self.x = x
6049*da0073e9SAndroid Build Coastguard Worker
6050*da0073e9SAndroid Build Coastguard Worker            def __len__(self):
6051*da0073e9SAndroid Build Coastguard Worker                return len(self.x)
6052*da0073e9SAndroid Build Coastguard Worker
6053*da0073e9SAndroid Build Coastguard Worker        # obj.__bool__ takes precedence over obj.__len__
6054*da0073e9SAndroid Build Coastguard Worker        class B:
6055*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x):
6056*da0073e9SAndroid Build Coastguard Worker                self.x = x
6057*da0073e9SAndroid Build Coastguard Worker
6058*da0073e9SAndroid Build Coastguard Worker            def __bool__(self):
6059*da0073e9SAndroid Build Coastguard Worker                return False
6060*da0073e9SAndroid Build Coastguard Worker
6061*da0073e9SAndroid Build Coastguard Worker            def __len__(self):
6062*da0073e9SAndroid Build Coastguard Worker                return len(self.x)
6063*da0073e9SAndroid Build Coastguard Worker
6064*da0073e9SAndroid Build Coastguard Worker        def fn(x, obj):
6065*da0073e9SAndroid Build Coastguard Worker            if not obj:
6066*da0073e9SAndroid Build Coastguard Worker                return x + 1
6067*da0073e9SAndroid Build Coastguard Worker            else:
6068*da0073e9SAndroid Build Coastguard Worker                return x - 1
6069*da0073e9SAndroid Build Coastguard Worker
6070*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(4)
6071*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
6072*da0073e9SAndroid Build Coastguard Worker        obj1 = A([1, 2, 3])
6073*da0073e9SAndroid Build Coastguard Worker        obj2 = A([])
6074*da0073e9SAndroid Build Coastguard Worker        obj3 = B([1, 2, 3])
6075*da0073e9SAndroid Build Coastguard Worker        obj4 = B([])
6076*da0073e9SAndroid Build Coastguard Worker        for obj in [obj1, obj2, obj3, obj4]:
6077*da0073e9SAndroid Build Coastguard Worker            ref = fn(x, obj)
6078*da0073e9SAndroid Build Coastguard Worker            res = opt_fn(x, obj)
6079*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(ref, res))
6080*da0073e9SAndroid Build Coastguard Worker
6081*da0073e9SAndroid Build Coastguard Worker    def test_class_has_instancecheck_method(self):
6082*da0073e9SAndroid Build Coastguard Worker        class A:
6083*da0073e9SAndroid Build Coastguard Worker            pass
6084*da0073e9SAndroid Build Coastguard Worker
6085*da0073e9SAndroid Build Coastguard Worker        class ExampleMeta(type):
6086*da0073e9SAndroid Build Coastguard Worker            def __instancecheck__(cls, instance):
6087*da0073e9SAndroid Build Coastguard Worker                return True
6088*da0073e9SAndroid Build Coastguard Worker
6089*da0073e9SAndroid Build Coastguard Worker        class B(metaclass=ExampleMeta):
6090*da0073e9SAndroid Build Coastguard Worker            pass
6091*da0073e9SAndroid Build Coastguard Worker
6092*da0073e9SAndroid Build Coastguard Worker        def fn(x, obj):
6093*da0073e9SAndroid Build Coastguard Worker            if isinstance(obj, B):
6094*da0073e9SAndroid Build Coastguard Worker                return x + 1
6095*da0073e9SAndroid Build Coastguard Worker            else:
6096*da0073e9SAndroid Build Coastguard Worker                return x - 1
6097*da0073e9SAndroid Build Coastguard Worker
6098*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(4)
6099*da0073e9SAndroid Build Coastguard Worker        obj = A()
6100*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, obj)
6101*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
6102*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, obj)
6103*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
6104*da0073e9SAndroid Build Coastguard Worker
6105*da0073e9SAndroid Build Coastguard Worker    def test_torch_cuda_is_available(self):
6106*da0073e9SAndroid Build Coastguard Worker        def fn(x):
6107*da0073e9SAndroid Build Coastguard Worker            if torch.cuda.is_available():
6108*da0073e9SAndroid Build Coastguard Worker                return x + 1
6109*da0073e9SAndroid Build Coastguard Worker            else:
6110*da0073e9SAndroid Build Coastguard Worker                return x - 1
6111*da0073e9SAndroid Build Coastguard Worker
6112*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(4)
6113*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
6114*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
6115*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
6116*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
6117*da0073e9SAndroid Build Coastguard Worker
6118*da0073e9SAndroid Build Coastguard Worker    def test_variable_tracker_recursively_contains(self):
6119*da0073e9SAndroid Build Coastguard Worker        # VariableTracker.recursively_contains should be updated correctly when mutation happens
6120*da0073e9SAndroid Build Coastguard Worker        def fn(x):
6121*da0073e9SAndroid Build Coastguard Worker            data = [[None] * 3] * 3
6122*da0073e9SAndroid Build Coastguard Worker            for i in range(3):
6123*da0073e9SAndroid Build Coastguard Worker                if i == 0:
6124*da0073e9SAndroid Build Coastguard Worker                    data[0][i] = x
6125*da0073e9SAndroid Build Coastguard Worker                else:
6126*da0073e9SAndroid Build Coastguard Worker                    data[0][i] = data[0][i - 1] + 1
6127*da0073e9SAndroid Build Coastguard Worker            return data[0][-1]
6128*da0073e9SAndroid Build Coastguard Worker
6129*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(4)
6130*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
6131*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
6132*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
6133*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
6134*da0073e9SAndroid Build Coastguard Worker
6135*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "requires cuda")
6136*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn")
6137*da0073e9SAndroid Build Coastguard Worker    def test_torch_cudnn_is_acceptable(self):
6138*da0073e9SAndroid Build Coastguard Worker        def fn(x):
6139*da0073e9SAndroid Build Coastguard Worker            if torch.backends.cudnn.is_acceptable(tensor=x):
6140*da0073e9SAndroid Build Coastguard Worker                return x + 1
6141*da0073e9SAndroid Build Coastguard Worker            return x
6142*da0073e9SAndroid Build Coastguard Worker
6143*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(4).cuda()
6144*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
6145*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
6146*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
6147*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
6148*da0073e9SAndroid Build Coastguard Worker
6149*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "requires cuda")
6150*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn")
6151*da0073e9SAndroid Build Coastguard Worker    def test_torch_cudnn_is_acceptable_bad_inputs(self):
6152*da0073e9SAndroid Build Coastguard Worker        def fn1(x):
6153*da0073e9SAndroid Build Coastguard Worker            if torch.backends.cudnn.is_acceptable("invalid"):
6154*da0073e9SAndroid Build Coastguard Worker                return x + 1
6155*da0073e9SAndroid Build Coastguard Worker            return x
6156*da0073e9SAndroid Build Coastguard Worker
6157*da0073e9SAndroid Build Coastguard Worker        def fn2(x):
6158*da0073e9SAndroid Build Coastguard Worker            if torch.backends.cudnn.is_acceptable(x, 3.14):
6159*da0073e9SAndroid Build Coastguard Worker                return x + 1
6160*da0073e9SAndroid Build Coastguard Worker            return x
6161*da0073e9SAndroid Build Coastguard Worker
6162*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
6163*da0073e9SAndroid Build Coastguard Worker            AssertionError, "Expect input to cudnn.is_acceptable to be a tensor"
6164*da0073e9SAndroid Build Coastguard Worker        ):
6165*da0073e9SAndroid Build Coastguard Worker            x1 = torch.rand(4).cuda()
6166*da0073e9SAndroid Build Coastguard Worker            opt_fn1 = torch._dynamo.optimize("eager", nopython=True)(fn1)
6167*da0073e9SAndroid Build Coastguard Worker            res1 = opt_fn1(x1)
6168*da0073e9SAndroid Build Coastguard Worker
6169*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
6170*da0073e9SAndroid Build Coastguard Worker            AssertionError, "Expect 1 input to cudnn.is_acceptable"
6171*da0073e9SAndroid Build Coastguard Worker        ):
6172*da0073e9SAndroid Build Coastguard Worker            x2 = torch.rand(4).cuda()
6173*da0073e9SAndroid Build Coastguard Worker            opt_fn2 = torch._dynamo.optimize("eager", nopython=True)(fn2)
6174*da0073e9SAndroid Build Coastguard Worker            res = opt_fn2(x2)
6175*da0073e9SAndroid Build Coastguard Worker
6176*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "requires cuda")
6177*da0073e9SAndroid Build Coastguard Worker    def test_get_device(self):
6178*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
6179*da0073e9SAndroid Build Coastguard Worker            x = x + 1
6180*da0073e9SAndroid Build Coastguard Worker            y = y + 1
6181*da0073e9SAndroid Build Coastguard Worker            return x.get_device(), y.get_device()
6182*da0073e9SAndroid Build Coastguard Worker
6183*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(4, device="cuda")
6184*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(4, device="cpu")
6185*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, y)
6186*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
6187*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, y)
6188*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
6189*da0073e9SAndroid Build Coastguard Worker
6190*da0073e9SAndroid Build Coastguard Worker    def test_disable_flag(self):
6191*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
6192*da0073e9SAndroid Build Coastguard Worker
6193*da0073e9SAndroid Build Coastguard Worker        with patch.dict(os.environ, {"TORCH_COMPILE_DISABLE": "1"}):
6194*da0073e9SAndroid Build Coastguard Worker
6195*da0073e9SAndroid Build Coastguard Worker            def fn(x, y):
6196*da0073e9SAndroid Build Coastguard Worker                x = x + 1
6197*da0073e9SAndroid Build Coastguard Worker                y = y + 1
6198*da0073e9SAndroid Build Coastguard Worker
6199*da0073e9SAndroid Build Coastguard Worker            opt_fn = torch._dynamo.optimize(cnt)
6200*da0073e9SAndroid Build Coastguard Worker
6201*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 0)
6202*da0073e9SAndroid Build Coastguard Worker
6203*da0073e9SAndroid Build Coastguard Worker    def test_is_compiling(self):
6204*da0073e9SAndroid Build Coastguard Worker        def f1():
6205*da0073e9SAndroid Build Coastguard Worker            if torch._dynamo.is_compiling():
6206*da0073e9SAndroid Build Coastguard Worker                return torch.ones(2, 2)
6207*da0073e9SAndroid Build Coastguard Worker            else:
6208*da0073e9SAndroid Build Coastguard Worker                return torch.zeros(2, 2)
6209*da0073e9SAndroid Build Coastguard Worker
6210*da0073e9SAndroid Build Coastguard Worker        def f2():
6211*da0073e9SAndroid Build Coastguard Worker            if torch._utils.is_compiling():
6212*da0073e9SAndroid Build Coastguard Worker                return torch.ones(2, 2)
6213*da0073e9SAndroid Build Coastguard Worker            else:
6214*da0073e9SAndroid Build Coastguard Worker                return torch.zeros(2, 2)
6215*da0073e9SAndroid Build Coastguard Worker
6216*da0073e9SAndroid Build Coastguard Worker        def f3():
6217*da0073e9SAndroid Build Coastguard Worker            if torch.compiler.is_compiling():
6218*da0073e9SAndroid Build Coastguard Worker                return torch.ones(2, 2)
6219*da0073e9SAndroid Build Coastguard Worker            else:
6220*da0073e9SAndroid Build Coastguard Worker                return torch.zeros(2, 2)
6221*da0073e9SAndroid Build Coastguard Worker
6222*da0073e9SAndroid Build Coastguard Worker        def f4():
6223*da0073e9SAndroid Build Coastguard Worker            if torch.compiler.is_dynamo_compiling():
6224*da0073e9SAndroid Build Coastguard Worker                return torch.ones(2, 2)
6225*da0073e9SAndroid Build Coastguard Worker            else:
6226*da0073e9SAndroid Build Coastguard Worker                return torch.zeros(2, 2)
6227*da0073e9SAndroid Build Coastguard Worker
6228*da0073e9SAndroid Build Coastguard Worker        for f in [f1, f2, f3, f4]:
6229*da0073e9SAndroid Build Coastguard Worker            opt_f = torch._dynamo.optimize("eager")(f)
6230*da0073e9SAndroid Build Coastguard Worker
6231*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(f(), torch.zeros(2, 2))
6232*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(opt_f(), torch.ones(2, 2))
6233*da0073e9SAndroid Build Coastguard Worker
6234*da0073e9SAndroid Build Coastguard Worker    def test_torch_generator_set_state(self):
6235*da0073e9SAndroid Build Coastguard Worker        def fn():
6236*da0073e9SAndroid Build Coastguard Worker            default_state = torch.default_generator.get_state()
6237*da0073e9SAndroid Build Coastguard Worker            x = torch.rand([2, 3])
6238*da0073e9SAndroid Build Coastguard Worker            if default_state.dtype != "float32":
6239*da0073e9SAndroid Build Coastguard Worker                x = x * 2
6240*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
6241*da0073e9SAndroid Build Coastguard Worker            torch.default_generator.set_state(default_state)
6242*da0073e9SAndroid Build Coastguard Worker            y = torch.rand([2, 3])
6243*da0073e9SAndroid Build Coastguard Worker            return x, y
6244*da0073e9SAndroid Build Coastguard Worker
6245*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
6246*da0073e9SAndroid Build Coastguard Worker        x, y = opt_fn()
6247*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, y * 2)
6248*da0073e9SAndroid Build Coastguard Worker
6249*da0073e9SAndroid Build Coastguard Worker    def test_torch_distributions_lazy_property(self):
6250*da0073e9SAndroid Build Coastguard Worker        def fn(x):
6251*da0073e9SAndroid Build Coastguard Worker            return torch.distributions.Categorical(probs=x).entropy()
6252*da0073e9SAndroid Build Coastguard Worker
6253*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
6254*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([4, 4])
6255*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x), fn(x))
6256*da0073e9SAndroid Build Coastguard Worker
6257*da0073e9SAndroid Build Coastguard Worker    def test_guard_failure_fn(self):
6258*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, k):
6259*da0073e9SAndroid Build Coastguard Worker            x = x + 1
6260*da0073e9SAndroid Build Coastguard Worker            y = y + 1
6261*da0073e9SAndroid Build Coastguard Worker            return x * y * k
6262*da0073e9SAndroid Build Coastguard Worker
6263*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([0.5, 0.5])
6264*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor([1.0, 1.0])
6265*da0073e9SAndroid Build Coastguard Worker
6266*da0073e9SAndroid Build Coastguard Worker        guard_failure = None
6267*da0073e9SAndroid Build Coastguard Worker
6268*da0073e9SAndroid Build Coastguard Worker        def guard_failures(failure):
6269*da0073e9SAndroid Build Coastguard Worker            nonlocal guard_failure
6270*da0073e9SAndroid Build Coastguard Worker            guard_failure = failure
6271*da0073e9SAndroid Build Coastguard Worker
6272*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(
6273*da0073e9SAndroid Build Coastguard Worker            "eager", nopython=True, guard_fail_fn=guard_failures
6274*da0073e9SAndroid Build Coastguard Worker        )(fn)
6275*da0073e9SAndroid Build Coastguard Worker
6276*da0073e9SAndroid Build Coastguard Worker        x2 = torch.tensor([0.5, 0.5, 1.0])
6277*da0073e9SAndroid Build Coastguard Worker        y2 = torch.tensor([0.5, 0.5, 0.5])
6278*da0073e9SAndroid Build Coastguard Worker
6279*da0073e9SAndroid Build Coastguard Worker        opt_fn(x, y, 3)
6280*da0073e9SAndroid Build Coastguard Worker        opt_fn(x2, y2, 5)
6281*da0073e9SAndroid Build Coastguard Worker
6282*da0073e9SAndroid Build Coastguard Worker        if (
6283*da0073e9SAndroid Build Coastguard Worker            not torch._dynamo.config.specialize_int
6284*da0073e9SAndroid Build Coastguard Worker            and not torch._dynamo.config.assume_static_by_default
6285*da0073e9SAndroid Build Coastguard Worker        ):
6286*da0073e9SAndroid Build Coastguard Worker            # we didn't actually test guard_failure_fn here but whatever,
6287*da0073e9SAndroid Build Coastguard Worker            # nice to see no guard failure on the test
6288*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(guard_failure is None)
6289*da0073e9SAndroid Build Coastguard Worker        else:
6290*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(guard_failure is not None)
6291*da0073e9SAndroid Build Coastguard Worker
6292*da0073e9SAndroid Build Coastguard Worker    def test_guard_failure_fn_shape_control(self):
6293*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
6294*da0073e9SAndroid Build Coastguard Worker            if x.shape[0] < 3:
6295*da0073e9SAndroid Build Coastguard Worker                if y.shape[0] < 3:
6296*da0073e9SAndroid Build Coastguard Worker                    return x * y
6297*da0073e9SAndroid Build Coastguard Worker                else:
6298*da0073e9SAndroid Build Coastguard Worker                    return x + y
6299*da0073e9SAndroid Build Coastguard Worker            else:
6300*da0073e9SAndroid Build Coastguard Worker                return -1
6301*da0073e9SAndroid Build Coastguard Worker
6302*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([2, 2])
6303*da0073e9SAndroid Build Coastguard Worker        y = torch.randn([2, 2])
6304*da0073e9SAndroid Build Coastguard Worker
6305*da0073e9SAndroid Build Coastguard Worker        guard_failure = None
6306*da0073e9SAndroid Build Coastguard Worker
6307*da0073e9SAndroid Build Coastguard Worker        def guard_failures(failure):
6308*da0073e9SAndroid Build Coastguard Worker            nonlocal guard_failure
6309*da0073e9SAndroid Build Coastguard Worker            guard_failure = failure
6310*da0073e9SAndroid Build Coastguard Worker
6311*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(
6312*da0073e9SAndroid Build Coastguard Worker            "eager", nopython=True, guard_fail_fn=guard_failures
6313*da0073e9SAndroid Build Coastguard Worker        )(fn)
6314*da0073e9SAndroid Build Coastguard Worker
6315*da0073e9SAndroid Build Coastguard Worker        x2 = torch.randn([5, 5])
6316*da0073e9SAndroid Build Coastguard Worker        y2 = torch.randn([5, 5])
6317*da0073e9SAndroid Build Coastguard Worker
6318*da0073e9SAndroid Build Coastguard Worker        opt_fn(x, y)
6319*da0073e9SAndroid Build Coastguard Worker        opt_fn(x2, y2)
6320*da0073e9SAndroid Build Coastguard Worker
6321*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(guard_failure is not None)
6322*da0073e9SAndroid Build Coastguard Worker        first_guard_failure = guard_failure[0].partition("\n")[0]
6323*da0073e9SAndroid Build Coastguard Worker        if torch._dynamo.config.assume_static_by_default:
6324*da0073e9SAndroid Build Coastguard Worker            self.assertIn(
6325*da0073e9SAndroid Build Coastguard Worker                """tensor 'L['x']' size mismatch at index 0. expected 2, actual 5""",
6326*da0073e9SAndroid Build Coastguard Worker                first_guard_failure,
6327*da0073e9SAndroid Build Coastguard Worker            )
6328*da0073e9SAndroid Build Coastguard Worker        else:
6329*da0073e9SAndroid Build Coastguard Worker            self.assertIn("""2 <= L['x'].size()[0] <= 2""", first_guard_failure)
6330*da0073e9SAndroid Build Coastguard Worker
6331*da0073e9SAndroid Build Coastguard Worker    def test_guard_failure_fn2(self):
6332*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
6333*da0073e9SAndroid Build Coastguard Worker            x = x + 1
6334*da0073e9SAndroid Build Coastguard Worker            y = y + 1
6335*da0073e9SAndroid Build Coastguard Worker            return x * y
6336*da0073e9SAndroid Build Coastguard Worker
6337*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([0.5, 0.5])
6338*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor([1.0, 1.0])
6339*da0073e9SAndroid Build Coastguard Worker
6340*da0073e9SAndroid Build Coastguard Worker        guard_failure = None
6341*da0073e9SAndroid Build Coastguard Worker
6342*da0073e9SAndroid Build Coastguard Worker        def guard_failures(failure):
6343*da0073e9SAndroid Build Coastguard Worker            nonlocal guard_failure
6344*da0073e9SAndroid Build Coastguard Worker            guard_failure = failure
6345*da0073e9SAndroid Build Coastguard Worker
6346*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(
6347*da0073e9SAndroid Build Coastguard Worker            "eager", nopython=True, guard_fail_fn=guard_failures
6348*da0073e9SAndroid Build Coastguard Worker        )(fn)
6349*da0073e9SAndroid Build Coastguard Worker
6350*da0073e9SAndroid Build Coastguard Worker        x2 = torch.tensor([0.5, 0.5, 1.0])
6351*da0073e9SAndroid Build Coastguard Worker        y2 = torch.tensor([0.5, 0.5, 0.5])
6352*da0073e9SAndroid Build Coastguard Worker
6353*da0073e9SAndroid Build Coastguard Worker        opt_fn(x, y)
6354*da0073e9SAndroid Build Coastguard Worker        opt_fn(x2, y2)
6355*da0073e9SAndroid Build Coastguard Worker
6356*da0073e9SAndroid Build Coastguard Worker        if torch._dynamo.config.assume_static_by_default:
6357*da0073e9SAndroid Build Coastguard Worker            self.assertIn(
6358*da0073e9SAndroid Build Coastguard Worker                """tensor 'L['x']' size mismatch at index 0. expected 2, actual 3""",
6359*da0073e9SAndroid Build Coastguard Worker                guard_failure[0],
6360*da0073e9SAndroid Build Coastguard Worker            )
6361*da0073e9SAndroid Build Coastguard Worker        else:
6362*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(guard_failure is None)
6363*da0073e9SAndroid Build Coastguard Worker
6364*da0073e9SAndroid Build Coastguard Worker    def test_guard_failure_fn_tensor_iter(self):
6365*da0073e9SAndroid Build Coastguard Worker        def fn(x):
6366*da0073e9SAndroid Build Coastguard Worker            for y in x:
6367*da0073e9SAndroid Build Coastguard Worker                y.add_(1.0)
6368*da0073e9SAndroid Build Coastguard Worker            return y
6369*da0073e9SAndroid Build Coastguard Worker
6370*da0073e9SAndroid Build Coastguard Worker        guard_failure = None
6371*da0073e9SAndroid Build Coastguard Worker
6372*da0073e9SAndroid Build Coastguard Worker        def guard_failures(failure):
6373*da0073e9SAndroid Build Coastguard Worker            nonlocal guard_failure
6374*da0073e9SAndroid Build Coastguard Worker            guard_failure = failure
6375*da0073e9SAndroid Build Coastguard Worker
6376*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(
6377*da0073e9SAndroid Build Coastguard Worker            "eager", nopython=True, guard_fail_fn=guard_failures
6378*da0073e9SAndroid Build Coastguard Worker        )(fn)
6379*da0073e9SAndroid Build Coastguard Worker
6380*da0073e9SAndroid Build Coastguard Worker        args1 = torch.randn(10, 10)
6381*da0073e9SAndroid Build Coastguard Worker        out = fn(args1)
6382*da0073e9SAndroid Build Coastguard Worker        opt_out = opt_fn(args1)
6383*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(out, opt_out))
6384*da0073e9SAndroid Build Coastguard Worker
6385*da0073e9SAndroid Build Coastguard Worker        args2 = torch.randn(9, 10)
6386*da0073e9SAndroid Build Coastguard Worker        out = fn(args2)
6387*da0073e9SAndroid Build Coastguard Worker        opt_out = opt_fn(args2)
6388*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(out, opt_out))
6389*da0073e9SAndroid Build Coastguard Worker
6390*da0073e9SAndroid Build Coastguard Worker        # guard is expected for both static and dynamic shapes
6391*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(guard_failure is not None)
6392*da0073e9SAndroid Build Coastguard Worker        self.assertIn(
6393*da0073e9SAndroid Build Coastguard Worker            """len(L['x']) == 10""",
6394*da0073e9SAndroid Build Coastguard Worker            guard_failure[0],
6395*da0073e9SAndroid Build Coastguard Worker        )
6396*da0073e9SAndroid Build Coastguard Worker
6397*da0073e9SAndroid Build Coastguard Worker    def test_restore_graphstate(self):
6398*da0073e9SAndroid Build Coastguard Worker        # This function does some guard accumulation,
6399*da0073e9SAndroid Build Coastguard Worker        # and then rolls back due to control flow.
6400*da0073e9SAndroid Build Coastguard Worker        # The idea is that if one were printing guards as they appear,
6401*da0073e9SAndroid Build Coastguard Worker        # they would see this insert a guard that does not show up in the final set of
6402*da0073e9SAndroid Build Coastguard Worker        # guards as we rolled back from it.
6403*da0073e9SAndroid Build Coastguard Worker        def nested_fn(s):
6404*da0073e9SAndroid Build Coastguard Worker            if x[0] < 10:
6405*da0073e9SAndroid Build Coastguard Worker                return s * s
6406*da0073e9SAndroid Build Coastguard Worker            return s
6407*da0073e9SAndroid Build Coastguard Worker
6408*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
6409*da0073e9SAndroid Build Coastguard Worker            x = x + 1
6410*da0073e9SAndroid Build Coastguard Worker            y = nested_fn(y)
6411*da0073e9SAndroid Build Coastguard Worker            y = y + 10
6412*da0073e9SAndroid Build Coastguard Worker            return x * y
6413*da0073e9SAndroid Build Coastguard Worker
6414*da0073e9SAndroid Build Coastguard Worker        all_guards = []
6415*da0073e9SAndroid Build Coastguard Worker
6416*da0073e9SAndroid Build Coastguard Worker        def guard_export_print(guards):
6417*da0073e9SAndroid Build Coastguard Worker            nonlocal all_guards
6418*da0073e9SAndroid Build Coastguard Worker            all_guards.extend(guards)
6419*da0073e9SAndroid Build Coastguard Worker
6420*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager", guard_export_fn=guard_export_print)(fn)
6421*da0073e9SAndroid Build Coastguard Worker
6422*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([0.5, 0.5])
6423*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor([1.0, 1.0])
6424*da0073e9SAndroid Build Coastguard Worker        opt_fn(x, y)
6425*da0073e9SAndroid Build Coastguard Worker
6426*da0073e9SAndroid Build Coastguard Worker        for guard in all_guards:
6427*da0073e9SAndroid Build Coastguard Worker            # This guard was created
6428*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(guard.name != "nested_fn.__closure__[0].cell_contents")
6429*da0073e9SAndroid Build Coastguard Worker
    def test_call_parent_non_class_methods_from_child(self):
        # Compiles C().add, which reaches parent-class methods and attributes
        # through several different spellings (super(), explicit class calls,
        # two-argument super), and checks the result against eager. It then
        # changes a class attribute and expects exactly one recompilation.
        class A:
            # Class attribute read through super().a in C.add; changed later
            # in the test to trigger recompilation.
            a = 4

            def add(self, x):
                return x + 10

            def mul(self, x):
                return x * 0.1

        class B(A):
            coeff = 4

            def add(self, x):
                return x + 20

            @classmethod
            def cube(cls, x):
                return cls.coeff * x * x * x

            def mul(self, x):
                return super().mul(x) * x * 0.2

        class C(B):
            def add(self, x):
                # classmethod defined on B, reached via zero-argument super()
                b = super().cube(x)
                # explicit unbound calls on the grandparent and the parent
                c = A.add(self, x)
                d = B.mul(self, x)
                # two-argument super: lookup starts after B in the MRO
                e = super(B, self).add(x)
                # class attribute looked up through super()
                f = super().a * x
                return b + c + d + e + f

        x = torch.rand(4)
        fn = C().add
        ref = fn(x)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))
        self.assertEqual(cnt.frame_count, 1)

        # Check recompilation
        A.a = 5
        ref = fn(x)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))
        # Ensure that super guard checks are working as expected
        # (one more call with nothing changed; the frame count is expected to
        # stay at 2).
        res = opt_fn(x)
        self.assertEqual(cnt.frame_count, 2)
6479*da0073e9SAndroid Build Coastguard Worker
6480*da0073e9SAndroid Build Coastguard Worker    def test_builder_for_class_with_metaclass(self):
6481*da0073e9SAndroid Build Coastguard Worker        class ExampleMeta(type):
6482*da0073e9SAndroid Build Coastguard Worker            pass
6483*da0073e9SAndroid Build Coastguard Worker
6484*da0073e9SAndroid Build Coastguard Worker        class MyClass(metaclass=ExampleMeta):
6485*da0073e9SAndroid Build Coastguard Worker            pass
6486*da0073e9SAndroid Build Coastguard Worker
6487*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
6488*da0073e9SAndroid Build Coastguard Worker            if isinstance(y, MyClass):
6489*da0073e9SAndroid Build Coastguard Worker                return x + 1
6490*da0073e9SAndroid Build Coastguard Worker            else:
6491*da0073e9SAndroid Build Coastguard Worker                return x - 1
6492*da0073e9SAndroid Build Coastguard Worker
6493*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([4, 4])
6494*da0073e9SAndroid Build Coastguard Worker        y = MyClass()
6495*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, y)
6496*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
6497*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, y)
6498*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
6499*da0073e9SAndroid Build Coastguard Worker
6500*da0073e9SAndroid Build Coastguard Worker    def test_tuple_from_tuple_iter(self):
6501*da0073e9SAndroid Build Coastguard Worker        def inner_fn(*args):
6502*da0073e9SAndroid Build Coastguard Worker            acc = torch.ones(10, 10)
6503*da0073e9SAndroid Build Coastguard Worker            for arg in args:
6504*da0073e9SAndroid Build Coastguard Worker                acc.add_(arg)
6505*da0073e9SAndroid Build Coastguard Worker
6506*da0073e9SAndroid Build Coastguard Worker            return acc
6507*da0073e9SAndroid Build Coastguard Worker
6508*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize("eager")
6509*da0073e9SAndroid Build Coastguard Worker        def fn(inputs, params):
6510*da0073e9SAndroid Build Coastguard Worker            y = tuple(inputs) + tuple(params)
6511*da0073e9SAndroid Build Coastguard Worker            return inner_fn(*y)
6512*da0073e9SAndroid Build Coastguard Worker
6513*da0073e9SAndroid Build Coastguard Worker        inputs = [torch.randn(10, 10) for _ in range(3)]
6514*da0073e9SAndroid Build Coastguard Worker
6515*da0073e9SAndroid Build Coastguard Worker        fn(inputs, iter(tuple(inputs)))
6516*da0073e9SAndroid Build Coastguard Worker
6517*da0073e9SAndroid Build Coastguard Worker        def fn(params):
6518*da0073e9SAndroid Build Coastguard Worker            y = tuple(params)
6519*da0073e9SAndroid Build Coastguard Worker            return inner_fn(*y)
6520*da0073e9SAndroid Build Coastguard Worker
6521*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
6522*da0073e9SAndroid Build Coastguard Worker        inputs = [torch.randn(10, 10) for _ in range(3)]
6523*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(fn(iter(tuple(inputs))), opt_fn(iter(tuple(inputs)))))
6524*da0073e9SAndroid Build Coastguard Worker
6525*da0073e9SAndroid Build Coastguard Worker        # Force recompilation
6526*da0073e9SAndroid Build Coastguard Worker        inputs = [torch.randn(10, 10) for _ in range(4)]
6527*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(fn(iter(tuple(inputs))), opt_fn(iter(tuple(inputs)))))
6528*da0073e9SAndroid Build Coastguard Worker
6529*da0073e9SAndroid Build Coastguard Worker    def test_torch_package_working_with_trace(self):
6530*da0073e9SAndroid Build Coastguard Worker        # from torch._dynamo.test_case import run_tests
6531*da0073e9SAndroid Build Coastguard Worker
6532*da0073e9SAndroid Build Coastguard Worker        inputs = [torch.randn([2, 2]), torch.randn([2, 2])]
6533*da0073e9SAndroid Build Coastguard Worker
6534*da0073e9SAndroid Build Coastguard Worker        optimized_model = torch._dynamo.optimize(backend="eager")(
6535*da0073e9SAndroid Build Coastguard Worker            MyPickledModule(torch.randn([2, 2]))
6536*da0073e9SAndroid Build Coastguard Worker        )
6537*da0073e9SAndroid Build Coastguard Worker        from torch import package
6538*da0073e9SAndroid Build Coastguard Worker
6539*da0073e9SAndroid Build Coastguard Worker        path = "/tmp/MyPickledModule.pt"
6540*da0073e9SAndroid Build Coastguard Worker        package_name = "MyPickledModule"
6541*da0073e9SAndroid Build Coastguard Worker        resource_name = "MyPickledModule.pkl"
6542*da0073e9SAndroid Build Coastguard Worker
6543*da0073e9SAndroid Build Coastguard Worker        model = MyPickledModule(torch.randn([2, 2]))
6544*da0073e9SAndroid Build Coastguard Worker
6545*da0073e9SAndroid Build Coastguard Worker        with package.PackageExporter(path) as exp:
6546*da0073e9SAndroid Build Coastguard Worker            exp.extern("**")
6547*da0073e9SAndroid Build Coastguard Worker            exp.save_pickle(package_name, resource_name, model)
6548*da0073e9SAndroid Build Coastguard Worker
6549*da0073e9SAndroid Build Coastguard Worker        imp = package.PackageImporter(path)
6550*da0073e9SAndroid Build Coastguard Worker        loaded_model = imp.load_pickle(package_name, resource_name)
6551*da0073e9SAndroid Build Coastguard Worker
6552*da0073e9SAndroid Build Coastguard Worker        optimized_loaded_model = torch._dynamo.optimize("eager")(loaded_model)(*inputs)
6553*da0073e9SAndroid Build Coastguard Worker
6554*da0073e9SAndroid Build Coastguard Worker    def test_shape_and_tuple_equality(self):
6555*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, t):
6556*da0073e9SAndroid Build Coastguard Worker            z = x * y
6557*da0073e9SAndroid Build Coastguard Worker            if x.size() == t:
6558*da0073e9SAndroid Build Coastguard Worker                return z.cos()
6559*da0073e9SAndroid Build Coastguard Worker            return z.sin()
6560*da0073e9SAndroid Build Coastguard Worker
6561*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.optimize("eager", nopython=True)(fn)(
6562*da0073e9SAndroid Build Coastguard Worker            torch.randn([4, 4]), torch.randn([4, 4]), (4, 4)
6563*da0073e9SAndroid Build Coastguard Worker        )
6564*da0073e9SAndroid Build Coastguard Worker
6565*da0073e9SAndroid Build Coastguard Worker    def test_int_list(self):
6566*da0073e9SAndroid Build Coastguard Worker        # if assume_static_by_default == True: spec int list
6567*da0073e9SAndroid Build Coastguard Worker        # otherwise: unspec int list
6568*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
6569*da0073e9SAndroid Build Coastguard Worker            return torch.sin(x + y[1] % 2)
6570*da0073e9SAndroid Build Coastguard Worker
6571*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(6)
6572*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
6573*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnt)(fn)
6574*da0073e9SAndroid Build Coastguard Worker        for i in range(10, 25, 3):
6575*da0073e9SAndroid Build Coastguard Worker            y = [i, i + 1, i + 2]
6576*da0073e9SAndroid Build Coastguard Worker            ref = fn(x, y)
6577*da0073e9SAndroid Build Coastguard Worker            res = opt_fn(x, y)
6578*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(ref, res))
6579*da0073e9SAndroid Build Coastguard Worker        if torch._dynamo.config.assume_static_by_default:
6580*da0073e9SAndroid Build Coastguard Worker            if torch._dynamo.config.automatic_dynamic_shapes:
6581*da0073e9SAndroid Build Coastguard Worker                self.assertExpectedInline(cnt.frame_count, """2""")
6582*da0073e9SAndroid Build Coastguard Worker            else:
6583*da0073e9SAndroid Build Coastguard Worker                self.assertExpectedInline(cnt.frame_count, """5""")
6584*da0073e9SAndroid Build Coastguard Worker        else:
6585*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(cnt.frame_count, """1""")
6586*da0073e9SAndroid Build Coastguard Worker
    def test_patched_builtin_functions(self):
        # Checks that a compiled function observes a monkey-patched
        # builtins.isinstance, and observes the original again once the patch
        # has been removed.
        import builtins

        # Cache the original builtin function ids
        torch._dynamo.trace_rules._builtin_function_ids()

        class MyClass:
            pass

        # Keep a handle on the real builtin so it can be called from the
        # patch and restored afterwards.
        builtin_isinstance = builtins.isinstance

        # Replacement that answers False for any MyClass instance and
        # otherwise defers to the real isinstance.
        def patched_isinstance(obj, classinfo) -> bool:
            if builtin_isinstance(obj, MyClass):
                return False
            else:
                return builtin_isinstance(obj, classinfo)

        def fn(x, y):
            if isinstance(y, MyClass):
                return x + 1
            else:
                return x - 1

        x = torch.ones(2, 3)
        y = MyClass()

        try:
            # Eager reference is taken before patching, so it takes the
            # `x + 1` branch.
            ref = fn(x, y)
            # Monkey patch builtin function
            builtins.isinstance = patched_isinstance
            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
            res = opt_fn(x, y)
            self.assertTrue(same(ref, x + 1))
            # Under the patch the compiled function must take the else branch.
            self.assertTrue(same(res, x - 1))
        finally:
            # Always restore the real builtin, even if an assertion failed.
            builtins.isinstance = builtin_isinstance

        # check recompilation because builtins is now unpatched
        opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
        res = opt_fn(x, y)
        self.assertTrue(same(res, x + 1))
6628*da0073e9SAndroid Build Coastguard Worker
6629*da0073e9SAndroid Build Coastguard Worker    # specifically test for tensor.attribute -> torch.something()
6630*da0073e9SAndroid Build Coastguard Worker    def test_real_imag_tensor_attribute(self):
6631*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
6632*da0073e9SAndroid Build Coastguard Worker            a = x.real
6633*da0073e9SAndroid Build Coastguard Worker            b = x.imag
6634*da0073e9SAndroid Build Coastguard Worker            return torch.mul(torch.add(a, y), b)
6635*da0073e9SAndroid Build Coastguard Worker
6636*da0073e9SAndroid Build Coastguard Worker        x_real = torch.rand((4, 4))
6637*da0073e9SAndroid Build Coastguard Worker        x_imag = torch.rand((4, 4))
6638*da0073e9SAndroid Build Coastguard Worker        x = torch.complex(x_real, x_imag)
6639*da0073e9SAndroid Build Coastguard Worker        y = torch.rand((4, 4))
6640*da0073e9SAndroid Build Coastguard Worker
6641*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, y)
6642*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
6643*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, y)
6644*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
6645*da0073e9SAndroid Build Coastguard Worker
6646*da0073e9SAndroid Build Coastguard Worker    def test_cast(self):
6647*da0073e9SAndroid Build Coastguard Worker        from typing import cast
6648*da0073e9SAndroid Build Coastguard Worker
6649*da0073e9SAndroid Build Coastguard Worker        def fn(x):
6650*da0073e9SAndroid Build Coastguard Worker            return cast(torch.Tensor, torch.add(x, 1.0))
6651*da0073e9SAndroid Build Coastguard Worker
6652*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
6653*da0073e9SAndroid Build Coastguard Worker
6654*da0073e9SAndroid Build Coastguard Worker        ref = fn(torch.ones(2, 2))
6655*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(torch.ones(2, 2))
6656*da0073e9SAndroid Build Coastguard Worker
6657*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
6658*da0073e9SAndroid Build Coastguard Worker
6659*da0073e9SAndroid Build Coastguard Worker    def test_T_tensor_attribute(self):
6660*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
6661*da0073e9SAndroid Build Coastguard Worker            a = x.T
6662*da0073e9SAndroid Build Coastguard Worker            return torch.add(a, y)
6663*da0073e9SAndroid Build Coastguard Worker
6664*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((4, 4))
6665*da0073e9SAndroid Build Coastguard Worker        y = torch.rand((4, 4))
6666*da0073e9SAndroid Build Coastguard Worker
6667*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, y)
6668*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
6669*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, y)
6670*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
6671*da0073e9SAndroid Build Coastguard Worker
6672*da0073e9SAndroid Build Coastguard Worker    def test_recursive_tensor_attribute(self):
6673*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
6674*da0073e9SAndroid Build Coastguard Worker            a = x.real.T
6675*da0073e9SAndroid Build Coastguard Worker            b = x.imag
6676*da0073e9SAndroid Build Coastguard Worker            return torch.mul(torch.add(a, y), b)
6677*da0073e9SAndroid Build Coastguard Worker
6678*da0073e9SAndroid Build Coastguard Worker        x_real = torch.rand((4, 4))
6679*da0073e9SAndroid Build Coastguard Worker        x_imag = torch.rand((4, 4))
6680*da0073e9SAndroid Build Coastguard Worker        x = torch.complex(x_real, x_imag)
6681*da0073e9SAndroid Build Coastguard Worker        y = torch.rand((4, 4))
6682*da0073e9SAndroid Build Coastguard Worker
6683*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, y)
6684*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
6685*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, y)
6686*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
6687*da0073e9SAndroid Build Coastguard Worker
6688*da0073e9SAndroid Build Coastguard Worker    def test_assigning_function_to_object_attribute(self):
6689*da0073e9SAndroid Build Coastguard Worker        # user-defined functions which are object's attributes are not converted to bound methods
6690*da0073e9SAndroid Build Coastguard Worker        def my_add(*args):
6691*da0073e9SAndroid Build Coastguard Worker            a, b = args
6692*da0073e9SAndroid Build Coastguard Worker            return a + b
6693*da0073e9SAndroid Build Coastguard Worker
6694*da0073e9SAndroid Build Coastguard Worker        class MyClass:
6695*da0073e9SAndroid Build Coastguard Worker            def __init__(self, func):
6696*da0073e9SAndroid Build Coastguard Worker                self.add = func
6697*da0073e9SAndroid Build Coastguard Worker
6698*da0073e9SAndroid Build Coastguard Worker        obj = MyClass(my_add)
6699*da0073e9SAndroid Build Coastguard Worker
6700*da0073e9SAndroid Build Coastguard Worker        def fn(x):
6701*da0073e9SAndroid Build Coastguard Worker            return obj.add(x, 2)
6702*da0073e9SAndroid Build Coastguard Worker
6703*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, 3)
6704*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
6705*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(backend="eager")(fn)
6706*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
6707*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
6708*da0073e9SAndroid Build Coastguard Worker
6709*da0073e9SAndroid Build Coastguard Worker    def test_assigning_function_to_class_attribute(self):
6710*da0073e9SAndroid Build Coastguard Worker        # user-defined functions which are class's attributes are converted to bound methods
6711*da0073e9SAndroid Build Coastguard Worker        def my_add(*args):
6712*da0073e9SAndroid Build Coastguard Worker            obj, a, b = args
6713*da0073e9SAndroid Build Coastguard Worker            return obj.x + a + b
6714*da0073e9SAndroid Build Coastguard Worker
6715*da0073e9SAndroid Build Coastguard Worker        class MyClass:
6716*da0073e9SAndroid Build Coastguard Worker            add = my_add
6717*da0073e9SAndroid Build Coastguard Worker
6718*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x):
6719*da0073e9SAndroid Build Coastguard Worker                self.x = x
6720*da0073e9SAndroid Build Coastguard Worker
6721*da0073e9SAndroid Build Coastguard Worker        obj = MyClass(0.5)
6722*da0073e9SAndroid Build Coastguard Worker
6723*da0073e9SAndroid Build Coastguard Worker        def fn(x):
6724*da0073e9SAndroid Build Coastguard Worker            return obj.add(x, 2)
6725*da0073e9SAndroid Build Coastguard Worker
6726*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, 3)
6727*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
6728*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(backend="eager")(fn)
6729*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
6730*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
6731*da0073e9SAndroid Build Coastguard Worker
6732*da0073e9SAndroid Build Coastguard Worker    def test_tagging_tensors_simple(self):
6733*da0073e9SAndroid Build Coastguard Worker        def foo(x, y):
6734*da0073e9SAndroid Build Coastguard Worker            return x * y, x, y
6735*da0073e9SAndroid Build Coastguard Worker
6736*da0073e9SAndroid Build Coastguard Worker        a = torch.randn([3, 3])
6737*da0073e9SAndroid Build Coastguard Worker        a.tag = "a"
6738*da0073e9SAndroid Build Coastguard Worker        a.frog = "ribbity ribbit"
6739*da0073e9SAndroid Build Coastguard Worker        b = torch.randn([3, 3])
6740*da0073e9SAndroid Build Coastguard Worker        b.tag = "b"
6741*da0073e9SAndroid Build Coastguard Worker        b.frog = "ribbit"
6742*da0073e9SAndroid Build Coastguard Worker
6743*da0073e9SAndroid Build Coastguard Worker        exported = torch._dynamo.export(foo)(a, b)
6744*da0073e9SAndroid Build Coastguard Worker        out_graph = exported[0]
6745*da0073e9SAndroid Build Coastguard Worker
6746*da0073e9SAndroid Build Coastguard Worker        nodes = list(out_graph.graph.nodes)
6747*da0073e9SAndroid Build Coastguard Worker        placeholders = [node for node in nodes if node.op == "placeholder"]
6748*da0073e9SAndroid Build Coastguard Worker        all_tags = []
6749*da0073e9SAndroid Build Coastguard Worker        all_frogs = []
6750*da0073e9SAndroid Build Coastguard Worker        for placeholder in placeholders:
6751*da0073e9SAndroid Build Coastguard Worker            if "tensor_dict" in placeholder.meta:
6752*da0073e9SAndroid Build Coastguard Worker                all_tags.append(placeholder.meta["tensor_dict"]["tag"])
6753*da0073e9SAndroid Build Coastguard Worker                all_frogs.append(placeholder.meta["tensor_dict"]["frog"])
6754*da0073e9SAndroid Build Coastguard Worker
6755*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(all_tags, ["a", "b"])
6756*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(all_frogs, ["ribbity ribbit", "ribbit"])
6757*da0073e9SAndroid Build Coastguard Worker
6758*da0073e9SAndroid Build Coastguard Worker    def test_tagging_tensors_mix_used_unused_structure(self):
6759*da0073e9SAndroid Build Coastguard Worker        def pre_attention_state_ops(input, mems, state):
6760*da0073e9SAndroid Build Coastguard Worker            lc_key = state[0]
6761*da0073e9SAndroid Build Coastguard Worker            lc_val = state[1]
6762*da0073e9SAndroid Build Coastguard Worker            bar = []
6763*da0073e9SAndroid Build Coastguard Worker            for i in range(0, 4):
6764*da0073e9SAndroid Build Coastguard Worker                bar2 = []
6765*da0073e9SAndroid Build Coastguard Worker                for j in range(0, 3):
6766*da0073e9SAndroid Build Coastguard Worker                    bar2.append(
6767*da0073e9SAndroid Build Coastguard Worker                        lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
6768*da0073e9SAndroid Build Coastguard Worker                    )
6769*da0073e9SAndroid Build Coastguard Worker                bar.append(bar2)
6770*da0073e9SAndroid Build Coastguard Worker
6771*da0073e9SAndroid Build Coastguard Worker            return bar
6772*da0073e9SAndroid Build Coastguard Worker
6773*da0073e9SAndroid Build Coastguard Worker        mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
6774*da0073e9SAndroid Build Coastguard Worker        state = [
6775*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
6776*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
6777*da0073e9SAndroid Build Coastguard Worker        ]
6778*da0073e9SAndroid Build Coastguard Worker        i = torch.tensor(
6779*da0073e9SAndroid Build Coastguard Worker            [
6780*da0073e9SAndroid Build Coastguard Worker                [0.0313, -0.1487, -0.3846, -0.5321],
6781*da0073e9SAndroid Build Coastguard Worker                [-1.7073, 1.3331, -0.0890, -1.4935],
6782*da0073e9SAndroid Build Coastguard Worker                [-0.8314, -0.1862, -0.5935, 1.5232],
6783*da0073e9SAndroid Build Coastguard Worker            ]
6784*da0073e9SAndroid Build Coastguard Worker        )
6785*da0073e9SAndroid Build Coastguard Worker
6786*da0073e9SAndroid Build Coastguard Worker        mems.tag = "MEMS"
6787*da0073e9SAndroid Build Coastguard Worker        i.tag = "FOO"
6788*da0073e9SAndroid Build Coastguard Worker        state[0].tag = "STATE_0"
6789*da0073e9SAndroid Build Coastguard Worker        state[1].tag = "HMMM"
6790*da0073e9SAndroid Build Coastguard Worker
6791*da0073e9SAndroid Build Coastguard Worker        exported = torch._dynamo.export(pre_attention_state_ops)(i, mems, state)
6792*da0073e9SAndroid Build Coastguard Worker        out_graph = exported[0]
6793*da0073e9SAndroid Build Coastguard Worker
6794*da0073e9SAndroid Build Coastguard Worker        nodes = list(out_graph.graph.nodes)
6795*da0073e9SAndroid Build Coastguard Worker        placeholders = [node for node in nodes if node.op == "placeholder"]
6796*da0073e9SAndroid Build Coastguard Worker        all_tags = []
6797*da0073e9SAndroid Build Coastguard Worker        for placeholder in placeholders:
6798*da0073e9SAndroid Build Coastguard Worker            if "tensor_dict" in placeholder.meta:
6799*da0073e9SAndroid Build Coastguard Worker                all_tags.append(placeholder.meta["tensor_dict"]["tag"])
6800*da0073e9SAndroid Build Coastguard Worker
6801*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(all_tags, ["STATE_0", "HMMM"])
6802*da0073e9SAndroid Build Coastguard Worker
6803*da0073e9SAndroid Build Coastguard Worker    def test_get_custom_tensor_attribute(self):
6804*da0073e9SAndroid Build Coastguard Worker        def fn(x):
6805*da0073e9SAndroid Build Coastguard Worker            return x.custom_attr * x
6806*da0073e9SAndroid Build Coastguard Worker
6807*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((2, 2))
6808*da0073e9SAndroid Build Coastguard Worker        x.custom_attr = 3.14
6809*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
6810*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
6811*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
6812*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
6813*da0073e9SAndroid Build Coastguard Worker
6814*da0073e9SAndroid Build Coastguard Worker    def test_set_custom_tensor_attribute(self):
6815*da0073e9SAndroid Build Coastguard Worker        def fn(x):
6816*da0073e9SAndroid Build Coastguard Worker            x.custom_attr = 3.14
6817*da0073e9SAndroid Build Coastguard Worker            return x.custom_attr * x
6818*da0073e9SAndroid Build Coastguard Worker
6819*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((2, 2))
6820*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
6821*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
6822*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
6823*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
6824*da0073e9SAndroid Build Coastguard Worker
6825*da0073e9SAndroid Build Coastguard Worker    def test_unhandled_exception_in_dynamo(self):
6826*da0073e9SAndroid Build Coastguard Worker        # traceback.format_exc() approximates an unhandled exception
6827*da0073e9SAndroid Build Coastguard Worker        def f(a):
6828*da0073e9SAndroid Build Coastguard Worker            a += 1
6829*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("smoge")
6830*da0073e9SAndroid Build Coastguard Worker            return a
6831*da0073e9SAndroid Build Coastguard Worker
6832*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(f)
6833*da0073e9SAndroid Build Coastguard Worker        try:
6834*da0073e9SAndroid Build Coastguard Worker            opt_fn(torch.ones(2))
6835*da0073e9SAndroid Build Coastguard Worker        except RuntimeError as e:
6836*da0073e9SAndroid Build Coastguard Worker            self.assertIn("smoge", traceback.format_exc())
6837*da0073e9SAndroid Build Coastguard Worker
    def test_unhandled_exception_in_dynamo2(self):
        # Regression test: an exception raised from inside nested try/except
        # handlers in a compiled function must reach the caller, and its
        # traceback must be formattable.
        # segfaults in python 3.11 if shadow frame is freed improperly
        from torch.testing import make_tensor

        def fn():
            # test that the errors are the same for dense and sparse versions
            def test1(*, is_sparse):
                # shapes must be compatible for matrix multiplication
                # NOTE(review): `a` is (2, 3), so addmm(a, a, a) multiplies
                # (2, 3) by (2, 3) -- presumably a deliberate mismatch so that
                # both the dense and the sparse call raise; confirm.
                a = make_tensor((2, 3), dtype=torch.float32, device="cpu")
                if is_sparse:
                    a_sparse = a.to_sparse_csr()
                    return torch.addmm(a, a_sparse, a)
                else:
                    return torch.addmm(a, a, a)

            # "smoge" is only raised when BOTH calls raise RuntimeError:
            # the sparse call is attempted from inside the dense call's handler.
            try:
                test1(is_sparse=False)
            except RuntimeError as msg:
                try:
                    test1(is_sparse=True)
                except RuntimeError as msg2:
                    raise RuntimeError("smoge")

        opt_fn = torch._dynamo.optimize("eager")(fn)
        # NOTE(review): if opt_fn() raises nothing, no assertion runs and the
        # test passes vacuously.
        try:
            opt_fn()
        except RuntimeError:
            self.assertIn("smoge", traceback.format_exc())
6866*da0073e9SAndroid Build Coastguard Worker
6867*da0073e9SAndroid Build Coastguard Worker    def test_variable_access_in_exception(self):
6868*da0073e9SAndroid Build Coastguard Worker        def fn():
6869*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(1)
6870*da0073e9SAndroid Build Coastguard Worker            try:
6871*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError("bad")
6872*da0073e9SAndroid Build Coastguard Worker            except RuntimeError:
6873*da0073e9SAndroid Build Coastguard Worker                x += 1
6874*da0073e9SAndroid Build Coastguard Worker            return x
6875*da0073e9SAndroid Build Coastguard Worker
6876*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
6877*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(), torch.tensor([2.0]))
6878*da0073e9SAndroid Build Coastguard Worker
6879*da0073e9SAndroid Build Coastguard Worker    def test_nested_sequential_with(self):
6880*da0073e9SAndroid Build Coastguard Worker        def fn(x):
6881*da0073e9SAndroid Build Coastguard Worker            with torch.set_grad_enabled(True):
6882*da0073e9SAndroid Build Coastguard Worker                with torch.set_grad_enabled(False):
6883*da0073e9SAndroid Build Coastguard Worker                    x = x + 1
6884*da0073e9SAndroid Build Coastguard Worker                with torch.set_grad_enabled(True):
6885*da0073e9SAndroid Build Coastguard Worker                    x = x + 1
6886*da0073e9SAndroid Build Coastguard Worker                return x
6887*da0073e9SAndroid Build Coastguard Worker
6888*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
6889*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0]))
6890*da0073e9SAndroid Build Coastguard Worker
6891*da0073e9SAndroid Build Coastguard Worker    def test_nested_sequential_try(self):
6892*da0073e9SAndroid Build Coastguard Worker        def fn(x):
6893*da0073e9SAndroid Build Coastguard Worker            try:
6894*da0073e9SAndroid Build Coastguard Worker                try:
6895*da0073e9SAndroid Build Coastguard Worker                    x = x + 1
6896*da0073e9SAndroid Build Coastguard Worker                except:
6897*da0073e9SAndroid Build Coastguard Worker                    pass
6898*da0073e9SAndroid Build Coastguard Worker                try:
6899*da0073e9SAndroid Build Coastguard Worker                    try:
6900*da0073e9SAndroid Build Coastguard Worker                        x = x + 1
6901*da0073e9SAndroid Build Coastguard Worker                    except:
6902*da0073e9SAndroid Build Coastguard Worker                        pass
6903*da0073e9SAndroid Build Coastguard Worker                except:
6904*da0073e9SAndroid Build Coastguard Worker                    pass
6905*da0073e9SAndroid Build Coastguard Worker            except:
6906*da0073e9SAndroid Build Coastguard Worker                pass
6907*da0073e9SAndroid Build Coastguard Worker            return x
6908*da0073e9SAndroid Build Coastguard Worker
6909*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
6910*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0]))
6911*da0073e9SAndroid Build Coastguard Worker
6912*da0073e9SAndroid Build Coastguard Worker    def test_nested_sequential_try_with(self):
6913*da0073e9SAndroid Build Coastguard Worker        def fn(x):
6914*da0073e9SAndroid Build Coastguard Worker            with torch.set_grad_enabled(True):
6915*da0073e9SAndroid Build Coastguard Worker                try:
6916*da0073e9SAndroid Build Coastguard Worker                    x = x + 1
6917*da0073e9SAndroid Build Coastguard Worker                except:
6918*da0073e9SAndroid Build Coastguard Worker                    pass
6919*da0073e9SAndroid Build Coastguard Worker                try:
6920*da0073e9SAndroid Build Coastguard Worker                    with torch.set_grad_enabled(False):
6921*da0073e9SAndroid Build Coastguard Worker                        x = x + 1
6922*da0073e9SAndroid Build Coastguard Worker                except:
6923*da0073e9SAndroid Build Coastguard Worker                    pass
6924*da0073e9SAndroid Build Coastguard Worker            return x
6925*da0073e9SAndroid Build Coastguard Worker
6926*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("eager")(fn)
6927*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0]))
6928*da0073e9SAndroid Build Coastguard Worker
6929*da0073e9SAndroid Build Coastguard Worker    def test_nested_sequential_try_with_graph_break(self):
6930*da0073e9SAndroid Build Coastguard Worker        def fn(x, n):
6931*da0073e9SAndroid Build Coastguard Worker            with torch.set_grad_enabled(True):
6932*da0073e9SAndroid Build Coastguard Worker                with torch.set_grad_enabled(False):
6933*da0073e9SAndroid Build Coastguard Worker                    x = x + 1
6934*da0073e9SAndroid Build Coastguard Worker                    torch._dynamo.graph_break()
6935*da0073e9SAndroid Build Coastguard Worker                try:
6936*da0073e9SAndroid Build Coastguard Worker                    with torch.set_grad_enabled(False):
6937*da0073e9SAndroid Build Coastguard Worker                        x = x + 1
6938*da0073e9SAndroid Build Coastguard Worker                        if n == 0:
6939*da0073e9SAndroid Build Coastguard Worker                            torch._dynamo.graph_break()
6940*da0073e9SAndroid Build Coastguard Worker                except:
6941*da0073e9SAndroid Build Coastguard Worker                    pass
6942*da0073e9SAndroid Build Coastguard Worker                with torch.set_grad_enabled(False):
6943*da0073e9SAndroid Build Coastguard Worker                    x = x + 1
6944*da0073e9SAndroid Build Coastguard Worker                    torch._dynamo.graph_break()
6945*da0073e9SAndroid Build Coastguard Worker                x = x + 1
6946*da0073e9SAndroid Build Coastguard Worker            return x
6947*da0073e9SAndroid Build Coastguard Worker
6948*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
6949*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(counter)(fn)
6950*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(torch.ones(1), 0), torch.tensor([5.0]))
6951*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 1)
6952*da0073e9SAndroid Build Coastguard Worker
6953*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
6954*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
6955*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(counter)(fn)
6956*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(torch.ones(1), 1), torch.tensor([5.0]))
6957*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 3)
6958*da0073e9SAndroid Build Coastguard Worker
6959*da0073e9SAndroid Build Coastguard Worker    def test_ordered_dict_alias_reconstruct(self):
6960*da0073e9SAndroid Build Coastguard Worker        od = collections.OrderedDict
6961*da0073e9SAndroid Build Coastguard Worker
6962*da0073e9SAndroid Build Coastguard Worker        def fn():
6963*da0073e9SAndroid Build Coastguard Worker            d1 = dict()
6964*da0073e9SAndroid Build Coastguard Worker            d1["a"] = 1
6965*da0073e9SAndroid Build Coastguard Worker            d2 = od(d1)
6966*da0073e9SAndroid Build Coastguard Worker            d2["b"] = 2
6967*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
6968*da0073e9SAndroid Build Coastguard Worker            if isinstance(d2, od):
6969*da0073e9SAndroid Build Coastguard Worker                return d2["a"] + d2["b"]
6970*da0073e9SAndroid Build Coastguard Worker            else:
6971*da0073e9SAndroid Build Coastguard Worker                return 0
6972*da0073e9SAndroid Build Coastguard Worker
6973*da0073e9SAndroid Build Coastguard Worker        dis.dis(fn)
6974*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch._dynamo.optimize("eager")(fn)(), 3)
6975*da0073e9SAndroid Build Coastguard Worker
    # Checks the source text and ~/^ marker lines that
    # get_instruction_source_311 returns for selected instructions of `f`.
    # NOTE this test can be removed once multiline errors are in Python.
    # See https://github.com/python/cpython/issues/106922
    @skipIfNotPy311
    def test_get_instruction_source_311(self):
        # `f` is never called; only its code object and source lines are
        # inspected. Its layout is compared character-for-character against
        # the expected strings below, so it must not be reformatted.
        def f():
            # flake8: noqa
            # fmt: off
            # test binary ops
            a = ( b   )   +   c
            a = (a + b) // (c - d)
            a = b    \
         +\
               c  # test
            a = (
                (b  # test +
                    )  \
                # +
            << (

                c  # test
                \
            )  # test
            )

            # test slice
            a = bbb   [  ccc    ]
            b = bbbbb \
                [  ccc # test

                 + ddd  \

                ] # test
            a = bbb[ccc][ddd][eee]

            # test nested and multiline function calls
            a = g(g(g(b)))
            a = g(h(
                g(b),
                c
            ))

            # test chained function calls
            a = (g(x).y)(
                z
            )(1)(2)

            # test unicode (match traceback behavior)
            a = ("������" +
                + "����") + b

        from torch._dynamo.utils import get_instruction_source_311

        # `offsets` are indices into `insts` (the instruction list of `f`),
        # one per expected snippet below; they differ between 3.11 and 3.12+.
        if sys.version_info >= (3, 12):
            # Offsets changed in 3.12, e.g. due to removal of PRECALL inst
            offsets = (3, 11, 15, 19, 23, 29, 35, 44, 53, 65)
        else:
            offsets = (3, 11, 15, 19, 23, 29, 35, 46, 58, 74)
        insts = list(dis.get_instructions(f))
        # uncomment to determine offsets
        # print(*enumerate(insts), sep="\n")
        # Join the rendered snippets with newlines and compare them to the
        # expected inline text in one go.
        all_sources = "\n".join(
            get_instruction_source_311(f.__code__, insts[offset]) for offset in offsets
        )
        self.assertExpectedInline(
            all_sources,
            """\
            a = ( b   )   +   c
                ~~~~~~~~~~^~~~~

            a = (a + b) // (c - d)
                ~~~~~~~~^^~~~~~~~~

            a = b    \\
                ~~~~~~
         +\\
         ^~
               c  # test
               ~

                (b  # test +
                ~~~~~~~~~~~~
                    )  \\
                    ~~~~
                # +
                ~~~
            << (
            ^^~~


                c  # test
                ~~~~~~~~~
                \\
                ~
            )  # test
            ~

            a = bbb   [  ccc    ]
                ~~~~~~^^^^^^^^^^^

            b = bbbbb \\
                ~~~~~~~
                [  ccc # test
                ^^^^^^^^^^^^^


                 + ddd  \\
                 ^^^^^^^^


                ] # test
                ^

            a = bbb[ccc][ddd][eee]
                ~~~~~~~~^^^^^

            a = g(g(g(b)))
                  ~^^^^^^

            a = g(h(
                  ~^
                g(b),
                ^^^^^
                c
                ^
            ))
            ^

            a = (g(x).y)(
                ~~~~~~~~~
                z
                ~
            )(1)(2)
            ~^^^
""",
        )
        # NOTE(review): the non-ASCII characters in the string literals of `f`
        # and in the expected text below show up as replacement characters in
        # this copy. The marker widths suggest the originals were three and two
        # characters long -- presumably emoji; confirm against upstream.
        # test unicode (since assertExpectedInline doesn't support unicode)
        op_offset = 74 if sys.version_info >= (3, 12) else 84
        self.assertEqual(
            get_instruction_source_311(f.__code__, insts[op_offset]),
            """\
            a = ("������" +
                ~~~~~~~~
                + "����") + b
                ~~~~~~~~^~~
""",
        )
7122*da0073e9SAndroid Build Coastguard Worker
7123*da0073e9SAndroid Build Coastguard Worker    def test_raise_guard_full_constraint(self):
7124*da0073e9SAndroid Build Coastguard Worker        y = torch.randn([3, 3, 3])
7125*da0073e9SAndroid Build Coastguard Worker
7126*da0073e9SAndroid Build Coastguard Worker        def my_dyn_fn(x):
7127*da0073e9SAndroid Build Coastguard Worker            if x.shape[0] == 3:
7128*da0073e9SAndroid Build Coastguard Worker                return x.sin()
7129*da0073e9SAndroid Build Coastguard Worker            return x.cos()
7130*da0073e9SAndroid Build Coastguard Worker
7131*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.mark_dynamic(y, 0)
7132*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ConstraintViolationError):
7133*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.optimize("eager")(my_dyn_fn)(y)
7134*da0073e9SAndroid Build Coastguard Worker
7135*da0073e9SAndroid Build Coastguard Worker    # Translation validation changes the exception type, don't run with it
7136*da0073e9SAndroid Build Coastguard Worker    @torch.fx.experimental._config.patch(translation_validation=False)
7137*da0073e9SAndroid Build Coastguard Worker    def test_mark_dynamic_with_ranges(self):
7138*da0073e9SAndroid Build Coastguard Worker        y = torch.randn([8, 3, 3])
7139*da0073e9SAndroid Build Coastguard Worker
7140*da0073e9SAndroid Build Coastguard Worker        def my_dyn_fn(x):
7141*da0073e9SAndroid Build Coastguard Worker            if x.shape[0] == 3:
7142*da0073e9SAndroid Build Coastguard Worker                return x.sin()
7143*da0073e9SAndroid Build Coastguard Worker            return x.cos()
7144*da0073e9SAndroid Build Coastguard Worker
7145*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.mark_dynamic(y, 0, min=2, max=5)
7146*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ConstraintViolationError):
7147*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.optimize("eager")(my_dyn_fn)(y)
7148*da0073e9SAndroid Build Coastguard Worker
7149*da0073e9SAndroid Build Coastguard Worker    def test_mark_static(self):
7150*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
7151*da0073e9SAndroid Build Coastguard Worker
7152*da0073e9SAndroid Build Coastguard Worker        def my_dyn_fn(x):
7153*da0073e9SAndroid Build Coastguard Worker            return x.cos()
7154*da0073e9SAndroid Build Coastguard Worker
7155*da0073e9SAndroid Build Coastguard Worker        y = torch.randn([3])
7156*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.mark_static(y, 0)
7157*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.optimize(counter)(my_dyn_fn)(y)
7158*da0073e9SAndroid Build Coastguard Worker
7159*da0073e9SAndroid Build Coastguard Worker        z = torch.randn([4])
7160*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.optimize(counter)(my_dyn_fn)(z)
7161*da0073e9SAndroid Build Coastguard Worker
7162*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 2)
7163*da0073e9SAndroid Build Coastguard Worker
7164*da0073e9SAndroid Build Coastguard Worker    def test_no_raise_guard_partial_constraint(self):
7165*da0073e9SAndroid Build Coastguard Worker        y = torch.randn([3, 3, 3])
7166*da0073e9SAndroid Build Coastguard Worker
7167*da0073e9SAndroid Build Coastguard Worker        def my_dyn_fn(x):
7168*da0073e9SAndroid Build Coastguard Worker            if x.shape[0] > 3:
7169*da0073e9SAndroid Build Coastguard Worker                return x.sin()
7170*da0073e9SAndroid Build Coastguard Worker            return x.cos()
7171*da0073e9SAndroid Build Coastguard Worker
7172*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.optimize("eager")(my_dyn_fn)(y)
7173*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.mark_dynamic(y, 0)
7174*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
7175*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.optimize("eager")(my_dyn_fn)(y)
7176*da0073e9SAndroid Build Coastguard Worker
7177*da0073e9SAndroid Build Coastguard Worker    def test_no_raise_guard_partial_constraint_across_break(self):
7178*da0073e9SAndroid Build Coastguard Worker        y = torch.randn([3, 3, 3])
7179*da0073e9SAndroid Build Coastguard Worker
7180*da0073e9SAndroid Build Coastguard Worker        def my_dyn_fn(x, y):
7181*da0073e9SAndroid Build Coastguard Worker            z = x * y
7182*da0073e9SAndroid Build Coastguard Worker
7183*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
7184*da0073e9SAndroid Build Coastguard Worker            if z.shape[0] > 2:
7185*da0073e9SAndroid Build Coastguard Worker                return z.cos()
7186*da0073e9SAndroid Build Coastguard Worker
7187*da0073e9SAndroid Build Coastguard Worker            return x.cos()
7188*da0073e9SAndroid Build Coastguard Worker
7189*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.optimize("eager")(my_dyn_fn)(y, y)
7190*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.mark_dynamic(y, 0)
7191*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
7192*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.optimize("eager")(my_dyn_fn)(y, y)
7193*da0073e9SAndroid Build Coastguard Worker
7194*da0073e9SAndroid Build Coastguard Worker    # Sadly, this does not throw - we do not prop correctly across the graph break
7195*da0073e9SAndroid Build Coastguard Worker    @unittest.expectedFailure
7196*da0073e9SAndroid Build Coastguard Worker    def test_raise_guard_partial_constraint_across_break(self):
7197*da0073e9SAndroid Build Coastguard Worker        y = torch.randn([3, 3, 3])
7198*da0073e9SAndroid Build Coastguard Worker
7199*da0073e9SAndroid Build Coastguard Worker        def my_dyn_fn(x, y):
7200*da0073e9SAndroid Build Coastguard Worker            z = x * y
7201*da0073e9SAndroid Build Coastguard Worker
7202*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
7203*da0073e9SAndroid Build Coastguard Worker            if z.shape[0] == 3:
7204*da0073e9SAndroid Build Coastguard Worker                return z.cos()
7205*da0073e9SAndroid Build Coastguard Worker
7206*da0073e9SAndroid Build Coastguard Worker            return x.cos()
7207*da0073e9SAndroid Build Coastguard Worker
7208*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.optimize("eager")(my_dyn_fn)(y, y)
7209*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.mark_dynamic(y, 0)
7210*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
7211*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
7212*da0073e9SAndroid Build Coastguard Worker            Exception,
7213*da0073e9SAndroid Build Coastguard Worker        ):
7214*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.optimize("eager")(my_dyn_fn)(y, y)
7215*da0073e9SAndroid Build Coastguard Worker
7216*da0073e9SAndroid Build Coastguard Worker    def test_raise_guard_partial_constraint_no_graph_break(self):
7217*da0073e9SAndroid Build Coastguard Worker        y = torch.randn([3, 3, 3])
7218*da0073e9SAndroid Build Coastguard Worker
7219*da0073e9SAndroid Build Coastguard Worker        def my_dyn_fn(x, y):
7220*da0073e9SAndroid Build Coastguard Worker            z = x * y
7221*da0073e9SAndroid Build Coastguard Worker
7222*da0073e9SAndroid Build Coastguard Worker            if z.shape[0] == 3:
7223*da0073e9SAndroid Build Coastguard Worker                return z.cos()
7224*da0073e9SAndroid Build Coastguard Worker
7225*da0073e9SAndroid Build Coastguard Worker            return x.cos()
7226*da0073e9SAndroid Build Coastguard Worker
7227*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.mark_dynamic(y, 0)
7228*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ConstraintViolationError):
7229*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.optimize("eager")(my_dyn_fn)(y, y)
7230*da0073e9SAndroid Build Coastguard Worker
7231*da0073e9SAndroid Build Coastguard Worker    def test_cannot_trace_mark_dynamic(self):
7232*da0073e9SAndroid Build Coastguard Worker        y = torch.randn([3, 3, 3])
7233*da0073e9SAndroid Build Coastguard Worker
7234*da0073e9SAndroid Build Coastguard Worker        def my_dyn_fn(x):
7235*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.mark_dynamic(x, 0)
7236*da0073e9SAndroid Build Coastguard Worker            return x * x
7237*da0073e9SAndroid Build Coastguard Worker
7238*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
7239*da0073e9SAndroid Build Coastguard Worker            AssertionError, "Attempt to trace forbidden callable"
7240*da0073e9SAndroid Build Coastguard Worker        ):
7241*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.optimize("eager")(my_dyn_fn)(y)
7242*da0073e9SAndroid Build Coastguard Worker
    def test_cannot_trace_mark_dynamic_safe_unreached(self):
        """A ``mark_dynamic`` call on a path that is not taken must not error.

        The input has ``shape[0] == 3``, so the function returns before the
        ``mark_dynamic`` call; the test passes if compiling and running it
        raises nothing.
        """
        y = torch.randn([3, 3, 3])

        def my_dyn_fn(x):
            # Taken for this input, so the lines below are never reached.
            if x.shape[0] == 3:
                return x
            print("Running", torch._dynamo.mark_dynamic(x, 0))
            return x * x

        torch._dynamo.optimize("eager")(my_dyn_fn)(y)
7253*da0073e9SAndroid Build Coastguard Worker
7254*da0073e9SAndroid Build Coastguard Worker    def test_anomaly_aot_autograd(self):
7255*da0073e9SAndroid Build Coastguard Worker        def fail():
7256*da0073e9SAndroid Build Coastguard Worker            raise AssertionError("fail")
7257*da0073e9SAndroid Build Coastguard Worker
7258*da0073e9SAndroid Build Coastguard Worker        @allow_in_graph
7259*da0073e9SAndroid Build Coastguard Worker        def h(a):
7260*da0073e9SAndroid Build Coastguard Worker            r = a.sum()
7261*da0073e9SAndroid Build Coastguard Worker            # Trigger an exception in backwards
7262*da0073e9SAndroid Build Coastguard Worker            r.register_hook(lambda x: fail())
7263*da0073e9SAndroid Build Coastguard Worker            return r
7264*da0073e9SAndroid Build Coastguard Worker
7265*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="aot_eager")
7266*da0073e9SAndroid Build Coastguard Worker        def f(a):
7267*da0073e9SAndroid Build Coastguard Worker            return h(a)
7268*da0073e9SAndroid Build Coastguard Worker
7269*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w, self.assertRaises(
7270*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.exc.BackendCompilerFailed
7271*da0073e9SAndroid Build Coastguard Worker        ):
7272*da0073e9SAndroid Build Coastguard Worker            f(torch.randn(2, 2, requires_grad=True))
7273*da0073e9SAndroid Build Coastguard Worker
7274*da0073e9SAndroid Build Coastguard Worker        # Suppress unrelated pkg_resources warnings
7275*da0073e9SAndroid Build Coastguard Worker        self.assertIn("forward call that caused the error", str(w[-1].message))
7276*da0073e9SAndroid Build Coastguard Worker
    def test_py_guards_mark_dynamic(self):
        """Check recompile counts as inputs with different dynamic-marked dims are run.

        The same function is run on several [3, 3, 3] tensors that differ only
        in which dims were passed to ``mark_dynamic``; ``frame_count`` on the
        compile counter is asserted after each run. The steps are
        order-dependent.
        """
        def my_dyn_fn(a):
            if a.shape[0] > 2:
                return a.cos()
            return a.sin()

        counter = CompileCounter()

        # Run with dynamic
        x0 = torch.randn([3, 3, 3])
        torch._dynamo.mark_dynamic(x0, 0)
        torch._dynamo.optimize(counter)(my_dyn_fn)(x0)
        self.assertEqual(counter.frame_count, 1)

        # Run without dynamic, no recompile
        x = torch.randn([3, 3, 3])
        torch._dynamo.optimize(counter)(my_dyn_fn)(x)
        self.assertEqual(counter.frame_count, 1)

        # Mark a new dim, 1, as dynamic
        x1 = torch.randn([3, 3, 3])
        torch._dynamo.mark_dynamic(x1, 1)
        torch._dynamo.optimize(counter)(my_dyn_fn)(x1)
        # Recompile triggered because we marked a new dim as dynamic
        self.assertEqual(counter.frame_count, 2)

        # Reset
        torch._dynamo.reset()
        # Reset counter
        counter = CompileCounter()

        # Run with dynamic 1
        torch._dynamo.optimize(counter)(my_dyn_fn)(x1)
        self.assertEqual(counter.frame_count, 1)

        # Run with dynamic 0, not subset
        torch._dynamo.optimize(counter)(my_dyn_fn)(x0)
        self.assertEqual(counter.frame_count, 2)

        # Run with dynamic 0, 1, 2, not subset
        x012 = torch.randn([3, 3, 3])
        torch._dynamo.mark_dynamic(x012, 0)
        torch._dynamo.mark_dynamic(x012, 1)
        torch._dynamo.mark_dynamic(x012, 2)
        torch._dynamo.optimize(counter)(my_dyn_fn)(x012)
        self.assertEqual(counter.frame_count, 3)
7323*da0073e9SAndroid Build Coastguard Worker
7324*da0073e9SAndroid Build Coastguard Worker    def test_recompile_on_global_state_change(self):
7325*da0073e9SAndroid Build Coastguard Worker        last_state = []
7326*da0073e9SAndroid Build Coastguard Worker        cnt = 0
7327*da0073e9SAndroid Build Coastguard Worker
7328*da0073e9SAndroid Build Coastguard Worker        def my_compiler(gm, _):
7329*da0073e9SAndroid Build Coastguard Worker            nonlocal cnt
7330*da0073e9SAndroid Build Coastguard Worker            cnt += 1
7331*da0073e9SAndroid Build Coastguard Worker            state = read_state()
7332*da0073e9SAndroid Build Coastguard Worker
7333*da0073e9SAndroid Build Coastguard Worker            def inner(*args):
7334*da0073e9SAndroid Build Coastguard Worker                last_state[:] = state
7335*da0073e9SAndroid Build Coastguard Worker                return gm(*args)
7336*da0073e9SAndroid Build Coastguard Worker
7337*da0073e9SAndroid Build Coastguard Worker            return inner
7338*da0073e9SAndroid Build Coastguard Worker
7339*da0073e9SAndroid Build Coastguard Worker        def read_state():
7340*da0073e9SAndroid Build Coastguard Worker            return [
7341*da0073e9SAndroid Build Coastguard Worker                torch.is_grad_enabled(),
7342*da0073e9SAndroid Build Coastguard Worker                torch.are_deterministic_algorithms_enabled(),
7343*da0073e9SAndroid Build Coastguard Worker                torch._C._get_cublas_allow_tf32(),
7344*da0073e9SAndroid Build Coastguard Worker            ]
7345*da0073e9SAndroid Build Coastguard Worker
7346*da0073e9SAndroid Build Coastguard Worker        def write_state(state):
7347*da0073e9SAndroid Build Coastguard Worker            torch.set_grad_enabled(state[0]),
7348*da0073e9SAndroid Build Coastguard Worker            torch.use_deterministic_algorithms(state[1])
7349*da0073e9SAndroid Build Coastguard Worker            torch._C._set_cublas_allow_tf32(state[2]),
7350*da0073e9SAndroid Build Coastguard Worker
7351*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend=my_compiler)
7352*da0073e9SAndroid Build Coastguard Worker        def fn(x):
7353*da0073e9SAndroid Build Coastguard Worker            return x + 1
7354*da0073e9SAndroid Build Coastguard Worker
7355*da0073e9SAndroid Build Coastguard Worker        initial_state = read_state()
7356*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(10)
7357*da0073e9SAndroid Build Coastguard Worker        try:
7358*da0073e9SAndroid Build Coastguard Worker            for round in range(3):
7359*da0073e9SAndroid Build Coastguard Worker                for i in range(len(initial_state)):
7360*da0073e9SAndroid Build Coastguard Worker                    new_state = [False] * len(initial_state)
7361*da0073e9SAndroid Build Coastguard Worker                    new_state[i] = True
7362*da0073e9SAndroid Build Coastguard Worker                    write_state(new_state)
7363*da0073e9SAndroid Build Coastguard Worker                    assert read_state() == new_state
7364*da0073e9SAndroid Build Coastguard Worker                    last_state.clear()
7365*da0073e9SAndroid Build Coastguard Worker                    fn(y)
7366*da0073e9SAndroid Build Coastguard Worker                    assert last_state == new_state
7367*da0073e9SAndroid Build Coastguard Worker                    if round == 0:
7368*da0073e9SAndroid Build Coastguard Worker                        assert cnt == i + 1
7369*da0073e9SAndroid Build Coastguard Worker                    else:
7370*da0073e9SAndroid Build Coastguard Worker                        assert cnt == len(initial_state)
7371*da0073e9SAndroid Build Coastguard Worker        finally:
7372*da0073e9SAndroid Build Coastguard Worker            write_state(initial_state)
7373*da0073e9SAndroid Build Coastguard Worker
7374*da0073e9SAndroid Build Coastguard Worker    def test_grad_state_mutated(self):
7375*da0073e9SAndroid Build Coastguard Worker        prior = torch.is_grad_enabled()
7376*da0073e9SAndroid Build Coastguard Worker        value = None
7377*da0073e9SAndroid Build Coastguard Worker        cnt = CompileCounter()
7378*da0073e9SAndroid Build Coastguard Worker
7379*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.allow_in_graph
7380*da0073e9SAndroid Build Coastguard Worker        def check_state():
7381*da0073e9SAndroid Build Coastguard Worker            nonlocal value
7382*da0073e9SAndroid Build Coastguard Worker            value = torch.is_grad_enabled()
7383*da0073e9SAndroid Build Coastguard Worker
7384*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend=cnt, fullgraph=True)
7385*da0073e9SAndroid Build Coastguard Worker        def fn(x):
7386*da0073e9SAndroid Build Coastguard Worker            check_state()
7387*da0073e9SAndroid Build Coastguard Worker            torch.set_grad_enabled(False)
7388*da0073e9SAndroid Build Coastguard Worker            return x + 1
7389*da0073e9SAndroid Build Coastguard Worker
7390*da0073e9SAndroid Build Coastguard Worker        try:
7391*da0073e9SAndroid Build Coastguard Worker            torch.set_grad_enabled(True)
7392*da0073e9SAndroid Build Coastguard Worker            fn(torch.randn(10))
7393*da0073e9SAndroid Build Coastguard Worker            assert value is True
7394*da0073e9SAndroid Build Coastguard Worker            assert torch.is_grad_enabled() is False
7395*da0073e9SAndroid Build Coastguard Worker
7396*da0073e9SAndroid Build Coastguard Worker            value = None
7397*da0073e9SAndroid Build Coastguard Worker            torch.set_grad_enabled(True)
7398*da0073e9SAndroid Build Coastguard Worker            fn(torch.randn(10))
7399*da0073e9SAndroid Build Coastguard Worker            assert value is True
7400*da0073e9SAndroid Build Coastguard Worker            assert torch.is_grad_enabled() is False
7401*da0073e9SAndroid Build Coastguard Worker
7402*da0073e9SAndroid Build Coastguard Worker            assert cnt.frame_count == 1
7403*da0073e9SAndroid Build Coastguard Worker        finally:
7404*da0073e9SAndroid Build Coastguard Worker            torch.set_grad_enabled(prior)
7405*da0073e9SAndroid Build Coastguard Worker
7406*da0073e9SAndroid Build Coastguard Worker    def test_deterministic_algorithms_mutated(self):
7407*da0073e9SAndroid Build Coastguard Worker        prior = torch.are_deterministic_algorithms_enabled()
7408*da0073e9SAndroid Build Coastguard Worker        prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
7409*da0073e9SAndroid Build Coastguard Worker        value = None
7410*da0073e9SAndroid Build Coastguard Worker        warn_only = None
7411*da0073e9SAndroid Build Coastguard Worker        cnt = CompileCounter()
7412*da0073e9SAndroid Build Coastguard Worker
7413*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.allow_in_graph
7414*da0073e9SAndroid Build Coastguard Worker        def check_state():
7415*da0073e9SAndroid Build Coastguard Worker            nonlocal value
7416*da0073e9SAndroid Build Coastguard Worker            nonlocal warn_only
7417*da0073e9SAndroid Build Coastguard Worker            value = torch.are_deterministic_algorithms_enabled()
7418*da0073e9SAndroid Build Coastguard Worker            warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
7419*da0073e9SAndroid Build Coastguard Worker
7420*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend=cnt, fullgraph=True)
7421*da0073e9SAndroid Build Coastguard Worker        def fn(x):
7422*da0073e9SAndroid Build Coastguard Worker            check_state()
7423*da0073e9SAndroid Build Coastguard Worker            torch.use_deterministic_algorithms(False, warn_only=False)
7424*da0073e9SAndroid Build Coastguard Worker            return x + 1
7425*da0073e9SAndroid Build Coastguard Worker
7426*da0073e9SAndroid Build Coastguard Worker        def run_fn():
7427*da0073e9SAndroid Build Coastguard Worker            torch.use_deterministic_algorithms(True, warn_only=True)
7428*da0073e9SAndroid Build Coastguard Worker            fn(torch.randn(10))
7429*da0073e9SAndroid Build Coastguard Worker            assert value is True
7430*da0073e9SAndroid Build Coastguard Worker            assert warn_only is True
7431*da0073e9SAndroid Build Coastguard Worker            assert torch.are_deterministic_algorithms_enabled() is False
7432*da0073e9SAndroid Build Coastguard Worker            assert torch.is_deterministic_algorithms_warn_only_enabled() is False
7433*da0073e9SAndroid Build Coastguard Worker
7434*da0073e9SAndroid Build Coastguard Worker        try:
7435*da0073e9SAndroid Build Coastguard Worker            run_fn()
7436*da0073e9SAndroid Build Coastguard Worker            value, warn_only = None, None
7437*da0073e9SAndroid Build Coastguard Worker            run_fn()
7438*da0073e9SAndroid Build Coastguard Worker            assert cnt.frame_count == 1
7439*da0073e9SAndroid Build Coastguard Worker        finally:
7440*da0073e9SAndroid Build Coastguard Worker            torch.use_deterministic_algorithms(prior, warn_only=prior_warn_only)
7441*da0073e9SAndroid Build Coastguard Worker
7442*da0073e9SAndroid Build Coastguard Worker    def test_torch_compile_ctx_on_forward_and_training_step(self):
7443*da0073e9SAndroid Build Coastguard Worker        class MyModel(torch.nn.Module):
7444*da0073e9SAndroid Build Coastguard Worker            def forward(self):
7445*da0073e9SAndroid Build Coastguard Worker                ...
7446*da0073e9SAndroid Build Coastguard Worker
7447*da0073e9SAndroid Build Coastguard Worker            def training_step(self):
7448*da0073e9SAndroid Build Coastguard Worker                self()
7449*da0073e9SAndroid Build Coastguard Worker
7450*da0073e9SAndroid Build Coastguard Worker        model = MyModel()
7451*da0073e9SAndroid Build Coastguard Worker        compiled_model = torch.compile(model)
7452*da0073e9SAndroid Build Coastguard Worker
7453*da0073e9SAndroid Build Coastguard Worker        model.forward = compiled_model.dynamo_ctx(model.forward)
7454*da0073e9SAndroid Build Coastguard Worker        model.training_step = compiled_model.dynamo_ctx(model.training_step)
7455*da0073e9SAndroid Build Coastguard Worker
7456*da0073e9SAndroid Build Coastguard Worker        model.training_step()
7457*da0073e9SAndroid Build Coastguard Worker
    def test_torch_guards_stack_frame_register_inlining(self):
        """Record the frame summaries handed to ``TracingContext.current_frame``.

        ``fn`` calls one helper. With ``current_frame`` patched, exactly one
        non-None frame summary is expected: the call site in ``fn``.
        """
        x = torch.tensor([0.5, 0.5])
        y = torch.tensor([0.75, 0.75, 0.75, 0.75])
        z = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25])

        def uwu_inline_me(x, y, z):
            r = torch.cat((x, x)) + y
            r2 = torch.cat((y, y)) + z
            return r, r2

        def fn(x, y, z):
            # NOTE: the text of the next line is asserted verbatim below.
            r, r2 = uwu_inline_me(x, y, z)
            return torch.mul(r, r), torch.mul(r2, r2)

        seen_frames = []
        import contextlib

        # Stand-in for current_frame: collects every non-None frame summary.
        @contextlib.contextmanager
        def global_context_capture_fn(frame_summary):
            if frame_summary is not None:
                seen_frames.append(frame_summary)
            yield

        with mock.patch(
            "torch._guards.TracingContext.current_frame",
            side_effect=global_context_capture_fn,
        ):
            torch._dynamo.optimize("eager")(fn)(x, y, z)

        self.assertEqual(len(seen_frames), 1)
        self.assertEqual(seen_frames[0].name, "fn")
        self.assertEqual(seen_frames[0].line, "r, r2 = uwu_inline_me(x, y, z)")
7490*da0073e9SAndroid Build Coastguard Worker
    def test_torch_guards_stack_frame_register_inlining_deep(self):
        """Two-level variant of the frame-summary recording test.

        ``fn`` calls ``uwu_inline_me``, which calls ``uwu_inline_me_deep``
        twice. Three non-None frame summaries are expected; the checks cover
        the names of the first two and the source line of the third.
        """
        x = torch.tensor([0.5, 0.5])
        y = torch.tensor([0.75, 0.75, 0.75, 0.75])
        z = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25])

        def uwu_inline_me_deep(x, y):
            return torch.cat((x, x)) + y

        def uwu_inline_me(x, y, z):
            r = uwu_inline_me_deep(x, y)
            # NOTE: the text of the next line is asserted verbatim below.
            r2 = uwu_inline_me_deep(y, z)
            return r, r2

        def fn(x, y, z):
            r, r2 = uwu_inline_me(x, y, z)
            return torch.mul(r, r), torch.mul(r2, r2)

        seen_frames = []
        import contextlib

        # Stand-in for current_frame: collects every non-None frame summary.
        @contextlib.contextmanager
        def global_context_capture_fn(frame_summary):
            if frame_summary is not None:
                seen_frames.append(frame_summary)
            yield

        with mock.patch(
            "torch._guards.TracingContext.current_frame",
            side_effect=global_context_capture_fn,
        ):
            torch._dynamo.optimize("eager")(fn)(x, y, z)

        self.assertEqual(len(seen_frames), 3)
        self.assertEqual(seen_frames[0].name, "fn")
        self.assertEqual(seen_frames[1].name, "uwu_inline_me")
        self.assertEqual(seen_frames[2].line, "r2 = uwu_inline_me_deep(y, z)")
7527*da0073e9SAndroid Build Coastguard Worker
7528*da0073e9SAndroid Build Coastguard Worker    def test_error_on_recompile(self):
7529*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize("eager")
7530*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
7531*da0073e9SAndroid Build Coastguard Worker            return a + b
7532*da0073e9SAndroid Build Coastguard Worker
7533*da0073e9SAndroid Build Coastguard Worker        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
7534*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(torch._dynamo.exc.RecompileError):
7535*da0073e9SAndroid Build Coastguard Worker                fn(torch.rand(2, 3), torch.rand(2, 3))
7536*da0073e9SAndroid Build Coastguard Worker                fn(torch.rand(2, 3), (1, 2, 3))
7537*da0073e9SAndroid Build Coastguard Worker
7538*da0073e9SAndroid Build Coastguard Worker    @expectedFailureDynamic
7539*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(automatic_dynamic_shapes=False)
7540*da0073e9SAndroid Build Coastguard Worker    def test_compile_profiler(self):
7541*da0073e9SAndroid Build Coastguard Worker        class Model(torch.nn.Module):
7542*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
7543*da0073e9SAndroid Build Coastguard Worker                return input + input
7544*da0073e9SAndroid Build Coastguard Worker
7545*da0073e9SAndroid Build Coastguard Worker        model = Model()
7546*da0073e9SAndroid Build Coastguard Worker        prof = CompileProfiler()
7547*da0073e9SAndroid Build Coastguard Worker        compiled = torch.compile(model, backend=prof)
7548*da0073e9SAndroid Build Coastguard Worker        base_checker = (
7549*da0073e9SAndroid Build Coastguard Worker            lambda: FileCheck()
7550*da0073e9SAndroid Build Coastguard Worker            .check("Torchdynamo Profiler Report")
7551*da0073e9SAndroid Build Coastguard Worker            .check("Graph Breaks")
7552*da0073e9SAndroid Build Coastguard Worker            .check("No graph breaks detected.")
7553*da0073e9SAndroid Build Coastguard Worker            .check("Recompilation")
7554*da0073e9SAndroid Build Coastguard Worker        )
7555*da0073e9SAndroid Build Coastguard Worker        input = torch.rand((2, 3, 4))
7556*da0073e9SAndroid Build Coastguard Worker        _ = compiled(input)
7557*da0073e9SAndroid Build Coastguard Worker        base_checker().check("No recompilation detected.").run(prof.report())
7558*da0073e9SAndroid Build Coastguard Worker
7559*da0073e9SAndroid Build Coastguard Worker        new_shape_input = torch.rand((3, 3, 4))
7560*da0073e9SAndroid Build Coastguard Worker        _ = compiled(new_shape_input)
7561*da0073e9SAndroid Build Coastguard Worker
7562*da0073e9SAndroid Build Coastguard Worker        # Not an exhaustive test of dynamic shapes behavior, but some sanity
7563*da0073e9SAndroid Build Coastguard Worker        if torch._dynamo.config.assume_static_by_default:
7564*da0073e9SAndroid Build Coastguard Worker            base_checker().check("Recompile Reasons").check("'forward'").check(
7565*da0073e9SAndroid Build Coastguard Worker                "cache_size_limit to 1"
7566*da0073e9SAndroid Build Coastguard Worker            ).run(prof.report())
7567*da0073e9SAndroid Build Coastguard Worker        else:
7568*da0073e9SAndroid Build Coastguard Worker            base_checker().check("No recompilation detected.").run(prof.report())
7569*da0073e9SAndroid Build Coastguard Worker
7570*da0073e9SAndroid Build Coastguard Worker        new_shape_input = torch.rand((4, 3, 4))
7571*da0073e9SAndroid Build Coastguard Worker        _ = compiled(new_shape_input)
7572*da0073e9SAndroid Build Coastguard Worker
7573*da0073e9SAndroid Build Coastguard Worker        base_checker().check("Recompile Reasons").check("'forward'").check(
7574*da0073e9SAndroid Build Coastguard Worker            "tensor 'L['input']' size mismatch at index 0. expected 2, actual 3"
7575*da0073e9SAndroid Build Coastguard Worker        ).check(
7576*da0073e9SAndroid Build Coastguard Worker            "tensor 'L['input']' size mismatch at index 0. expected 3, actual 4"
7577*da0073e9SAndroid Build Coastguard Worker        ).run(
7578*da0073e9SAndroid Build Coastguard Worker            prof.report()
7579*da0073e9SAndroid Build Coastguard Worker        )
7580*da0073e9SAndroid Build Coastguard Worker
7581*da0073e9SAndroid Build Coastguard Worker    def test_guards_strip_function_call(self):
7582*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.guards import strip_function_call
7583*da0073e9SAndroid Build Coastguard Worker
7584*da0073e9SAndroid Build Coastguard Worker        test_case = [
7585*da0073e9SAndroid Build Coastguard Worker            ("___odict_getitem(a, 1)", "a"),
7586*da0073e9SAndroid Build Coastguard Worker            ("a.layers[slice(2)][0]._xyz", "a"),
7587*da0073e9SAndroid Build Coastguard Worker            ("getattr(a.layers[slice(2)][0]._abc, '0')", "a"),
7588*da0073e9SAndroid Build Coastguard Worker            ("getattr(getattr(a.x[3], '0'), '3')", "a"),
7589*da0073e9SAndroid Build Coastguard Worker            ("a.layers[slice(None, -1, None)][0]._xyz", "a"),
7590*da0073e9SAndroid Build Coastguard Worker            ("a.layers[func('offset', -1, None)][0]._xyz", "a"),
7591*da0073e9SAndroid Build Coastguard Worker        ]
7592*da0073e9SAndroid Build Coastguard Worker        # strip_function_call should extract the object from the string.
7593*da0073e9SAndroid Build Coastguard Worker        for name, expect_obj in test_case:
7594*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(strip_function_call(name), expect_obj)
7595*da0073e9SAndroid Build Coastguard Worker
7596*da0073e9SAndroid Build Coastguard Worker    def test_int_neg(self):
7597*da0073e9SAndroid Build Coastguard Worker        def int_neg(a, b):
7598*da0073e9SAndroid Build Coastguard Worker            x = a.shape[0]
7599*da0073e9SAndroid Build Coastguard Worker            y = b.shape[0]
7600*da0073e9SAndroid Build Coastguard Worker            return -x * -y * a * b
7601*da0073e9SAndroid Build Coastguard Worker
7602*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.testing.standard_test(self, int_neg, 2)
7603*da0073e9SAndroid Build Coastguard Worker
7604*da0073e9SAndroid Build Coastguard Worker    def test_hash_getitem_slice(self):
7605*da0073e9SAndroid Build Coastguard Worker        s = GetItemSource(LocalSource("foo"), slice(None, -1, None))
7606*da0073e9SAndroid Build Coastguard Worker        s2 = GetItemSource(LocalSource("foo"), slice(None, -1, None))
7607*da0073e9SAndroid Build Coastguard Worker        s3 = GetItemSource(LocalSource("foo"), slice(None, -1, 2))
7608*da0073e9SAndroid Build Coastguard Worker        some_set = set()
7609*da0073e9SAndroid Build Coastguard Worker
7610*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s not in some_set)
7611*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s2 not in some_set)
7612*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s3 not in some_set)
7613*da0073e9SAndroid Build Coastguard Worker
7614*da0073e9SAndroid Build Coastguard Worker        some_set.add(s)
7615*da0073e9SAndroid Build Coastguard Worker
7616*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s in some_set)
7617*da0073e9SAndroid Build Coastguard Worker        # s and s2 should hash the  same
7618*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s2 in some_set)
7619*da0073e9SAndroid Build Coastguard Worker        # s3 should be different
7620*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s3 not in some_set)
7621*da0073e9SAndroid Build Coastguard Worker
7622*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s == s2)
7623*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s != s3)
7624*da0073e9SAndroid Build Coastguard Worker
7625*da0073e9SAndroid Build Coastguard Worker    def test_inline_dict_function(self):
7626*da0073e9SAndroid Build Coastguard Worker        def _result_type_dict(dtype):
7627*da0073e9SAndroid Build Coastguard Worker            return {bool: torch.float32}[dtype]
7628*da0073e9SAndroid Build Coastguard Worker
7629*da0073e9SAndroid Build Coastguard Worker        @torch.compile
7630*da0073e9SAndroid Build Coastguard Worker        def f():
7631*da0073e9SAndroid Build Coastguard Worker            return torch.ones(3, dtype=_result_type_dict(bool))
7632*da0073e9SAndroid Build Coastguard Worker
7633*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(f(), torch.ones(3, dtype=torch.float32))
7634*da0073e9SAndroid Build Coastguard Worker
7635*da0073e9SAndroid Build Coastguard Worker    def test_inline_dict_function_passed_as_arg(self):
7636*da0073e9SAndroid Build Coastguard Worker        @torch.compile
7637*da0073e9SAndroid Build Coastguard Worker        def fn(d, x, y):
7638*da0073e9SAndroid Build Coastguard Worker            if d[x] is torch.float32:
7639*da0073e9SAndroid Build Coastguard Worker                return y.cos()
7640*da0073e9SAndroid Build Coastguard Worker            else:
7641*da0073e9SAndroid Build Coastguard Worker                return y.sin()
7642*da0073e9SAndroid Build Coastguard Worker
7643*da0073e9SAndroid Build Coastguard Worker        dd = {bool: torch.float32, int: torch.int64}
7644*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(dd, bool, torch.ones(4)), torch.ones(4).cos())
7645*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(dd, int, torch.ones(4)), torch.ones(4).sin())
7646*da0073e9SAndroid Build Coastguard Worker
7647*da0073e9SAndroid Build Coastguard Worker    def test_add_sizes(self):
7648*da0073e9SAndroid Build Coastguard Worker        def func(x):
7649*da0073e9SAndroid Build Coastguard Worker            y = x.size()
7650*da0073e9SAndroid Build Coastguard Worker            return y + y
7651*da0073e9SAndroid Build Coastguard Worker
7652*da0073e9SAndroid Build Coastguard Worker        eager_out = func(torch.ones(10, 10, 3))
7653*da0073e9SAndroid Build Coastguard Worker        compile_out = torch._dynamo.optimize("eager")(func)(torch.ones(10, 10, 3))
7654*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(compile_out, torch.Size))
7655*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_out, compile_out)
7656*da0073e9SAndroid Build Coastguard Worker
7657*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU")
7658*da0073e9SAndroid Build Coastguard Worker    def test_cuda_set_device(self):
7659*da0073e9SAndroid Build Coastguard Worker        def fn():
7660*da0073e9SAndroid Build Coastguard Worker            a = torch.ones(2, device="cuda")
7661*da0073e9SAndroid Build Coastguard Worker            torch.cuda.set_device(1)
7662*da0073e9SAndroid Build Coastguard Worker            return a + 1
7663*da0073e9SAndroid Build Coastguard Worker
7664*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.device(0):
7665*da0073e9SAndroid Build Coastguard Worker            counter = CompileCounter()
7666*da0073e9SAndroid Build Coastguard Worker            opt_fn = torch._dynamo.optimize(counter)(fn)
7667*da0073e9SAndroid Build Coastguard Worker            res = opt_fn()
7668*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.device.type, "cuda")
7669*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.device.index, 0)
7670*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(counter.frame_count, 2)
7671*da0073e9SAndroid Build Coastguard Worker
7672*da0073e9SAndroid Build Coastguard Worker    def test_nested_function_resuming_with_correct_globals(self):
7673*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/99665
7674*da0073e9SAndroid Build Coastguard Worker        try:
7675*da0073e9SAndroid Build Coastguard Worker            from .utils import outer_func
7676*da0073e9SAndroid Build Coastguard Worker        except ImportError:
7677*da0073e9SAndroid Build Coastguard Worker            from utils import outer_func
7678*da0073e9SAndroid Build Coastguard Worker
7679*da0073e9SAndroid Build Coastguard Worker        def gn(x, y):
7680*da0073e9SAndroid Build Coastguard Worker            return x + y
7681*da0073e9SAndroid Build Coastguard Worker
7682*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
7683*da0073e9SAndroid Build Coastguard Worker            return outer_func(gn)(x, y)
7684*da0073e9SAndroid Build Coastguard Worker
7685*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([3])
7686*da0073e9SAndroid Build Coastguard Worker        y = torch.rand([3])
7687*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(backend="eager")(fn)
7688*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, y)
7689*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, y)
7690*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
7691*da0073e9SAndroid Build Coastguard Worker
7692*da0073e9SAndroid Build Coastguard Worker    @dataclasses.dataclass
7693*da0073e9SAndroid Build Coastguard Worker    class CSETestCase:
7694*da0073e9SAndroid Build Coastguard Worker        expr: str
7695*da0073e9SAndroid Build Coastguard Worker        preface: typing.List[str] = dataclasses.field(default_factory=list)
7696*da0073e9SAndroid Build Coastguard Worker        expected: typing.Optional[str] = None
7697*da0073e9SAndroid Build Coastguard Worker        expected_py38: typing.Optional[str] = None
7698*da0073e9SAndroid Build Coastguard Worker
7699*da0073e9SAndroid Build Coastguard Worker    def _is_py38(self) -> bool:
7700*da0073e9SAndroid Build Coastguard Worker        return sys.version_info[:2] <= (3, 8)
7701*da0073e9SAndroid Build Coastguard Worker
7702*da0073e9SAndroid Build Coastguard Worker    def _has_ast_unparse(self) -> bool:
7703*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.guards import HAS_UNPARSE_FUNCTIONS
7704*da0073e9SAndroid Build Coastguard Worker
7705*da0073e9SAndroid Build Coastguard Worker        return HAS_UNPARSE_FUNCTIONS
7706*da0073e9SAndroid Build Coastguard Worker
7707*da0073e9SAndroid Build Coastguard Worker    def test_guards_cse_pass_single(self):
7708*da0073e9SAndroid Build Coastguard Worker        if not self._has_ast_unparse():
7709*da0073e9SAndroid Build Coastguard Worker            if IS_FBCODE:
7710*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError("Needs astunparse or Python-3.9+")
7711*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("Needs astunparse or Python-3.9+")
7712*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.guards import PyExprCSEPass
7713*da0073e9SAndroid Build Coastguard Worker
7714*da0073e9SAndroid Build Coastguard Worker        testcase = self.CSETestCase
7715*da0073e9SAndroid Build Coastguard Worker        testcases = [
7716*da0073e9SAndroid Build Coastguard Worker            # Nothing gets CSE-d, since the only repeated sub-expression is 'x'.
7717*da0073e9SAndroid Build Coastguard Worker            # i.e. not a node type we are interested on.
7718*da0073e9SAndroid Build Coastguard Worker            testcase(expr="x[0].a"),
7719*da0073e9SAndroid Build Coastguard Worker            testcase(expr="x[1].a"),
7720*da0073e9SAndroid Build Coastguard Worker            testcase(expr="x[2].a"),
7721*da0073e9SAndroid Build Coastguard Worker            # 'a.b.c' gets CSE-d, since it's a sub-expression used more than 'PyExprCSEPass.USE_THRESHOLD'.
7722*da0073e9SAndroid Build Coastguard Worker            testcase(
7723*da0073e9SAndroid Build Coastguard Worker                expr="a.b.c[0].d.e",
7724*da0073e9SAndroid Build Coastguard Worker                preface=["_var0 = a.b", "_var1 = _var0.c"],
7725*da0073e9SAndroid Build Coastguard Worker                expected="_var1[0].d.e",
7726*da0073e9SAndroid Build Coastguard Worker            ),
7727*da0073e9SAndroid Build Coastguard Worker            testcase(expr="a.b.c[1].d.e", expected="_var1[1].d.e"),
7728*da0073e9SAndroid Build Coastguard Worker            testcase(expr="a.b.c[2].d.e", expected="_var1[2].d.e"),
7729*da0073e9SAndroid Build Coastguard Worker            # 'm.n[0]' gets CSE-d, since it is a sub-expression used more than 'PyExprCSEPass.USE_THRESHOLD'.
7730*da0073e9SAndroid Build Coastguard Worker            testcase(
7731*da0073e9SAndroid Build Coastguard Worker                expr="f(m.n[0], '0').x.y.z",
7732*da0073e9SAndroid Build Coastguard Worker                preface=["_var2 = m.n", "_var3 = _var2[0]"],
7733*da0073e9SAndroid Build Coastguard Worker                expected="f(_var3, '0').x.y.z",
7734*da0073e9SAndroid Build Coastguard Worker            ),
7735*da0073e9SAndroid Build Coastguard Worker            testcase(expr="f(m.n[0], '1').x.y.z", expected="f(_var3, '1').x.y.z"),
7736*da0073e9SAndroid Build Coastguard Worker            testcase(expr="f(m.n[0], '2').x.y.z", expected="f(_var3, '2').x.y.z"),
7737*da0073e9SAndroid Build Coastguard Worker            # The whole expressiong gets CSE-d, as well as all of its sub-expressions.
7738*da0073e9SAndroid Build Coastguard Worker            testcase(
7739*da0073e9SAndroid Build Coastguard Worker                expr="self.g(a, b).k",
7740*da0073e9SAndroid Build Coastguard Worker                preface=["_var4 = self.g", "_var5 = _var4(a, b)", "_var6 = _var5.k"],
7741*da0073e9SAndroid Build Coastguard Worker                expected="_var6",
7742*da0073e9SAndroid Build Coastguard Worker            ),
7743*da0073e9SAndroid Build Coastguard Worker            testcase(expr="self.g(a, b).k", expected="_var6"),
7744*da0073e9SAndroid Build Coastguard Worker            testcase(expr="self.g(a, b).k", expected="_var6"),
7745*da0073e9SAndroid Build Coastguard Worker        ]
7746*da0073e9SAndroid Build Coastguard Worker        csepass = PyExprCSEPass()
7747*da0073e9SAndroid Build Coastguard Worker        csepass.count([t.expr for t in testcases])
7748*da0073e9SAndroid Build Coastguard Worker
7749*da0073e9SAndroid Build Coastguard Worker        for t in testcases:
7750*da0073e9SAndroid Build Coastguard Worker            preface, expr = csepass.replace(t.expr)
7751*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(preface, t.preface)
7752*da0073e9SAndroid Build Coastguard Worker            expected = t.expected if t.expected is not None else t.expr
7753*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expr, expected)
7754*da0073e9SAndroid Build Coastguard Worker
7755*da0073e9SAndroid Build Coastguard Worker    def test_guards_cse_pass_multiple(self):
7756*da0073e9SAndroid Build Coastguard Worker        if not self._has_ast_unparse():
7757*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("Needs astunparse or Python-3.9+")
7758*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.guards import PyExprCSEPass
7759*da0073e9SAndroid Build Coastguard Worker
7760*da0073e9SAndroid Build Coastguard Worker        testcase = self.CSETestCase
7761*da0073e9SAndroid Build Coastguard Worker        testcases = [
7762*da0073e9SAndroid Build Coastguard Worker            testcase(
7763*da0073e9SAndroid Build Coastguard Worker                expr="x[0].a < x[1].a * (3 - x[2].a)",
7764*da0073e9SAndroid Build Coastguard Worker                expected="x[0].a < x[1].a * (3 - x[2].a)",
7765*da0073e9SAndroid Build Coastguard Worker                expected_py38="(x[0].a < (x[1].a * (3 - x[2].a)))",
7766*da0073e9SAndroid Build Coastguard Worker            ),
7767*da0073e9SAndroid Build Coastguard Worker            testcase(
7768*da0073e9SAndroid Build Coastguard Worker                expr="a.b.c[0].d.e + a.b.c[1].d.e * a.b.c[2].d.e > 0",
7769*da0073e9SAndroid Build Coastguard Worker                preface=["_var0 = a.b", "_var1 = _var0.c"],
7770*da0073e9SAndroid Build Coastguard Worker                expected="_var1[0].d.e + _var1[1].d.e * _var1[2].d.e > 0",
7771*da0073e9SAndroid Build Coastguard Worker                expected_py38="((_var1[0].d.e + (_var1[1].d.e * _var1[2].d.e)) > 0)",
7772*da0073e9SAndroid Build Coastguard Worker            ),
7773*da0073e9SAndroid Build Coastguard Worker            testcase(
7774*da0073e9SAndroid Build Coastguard Worker                expr="f(m.n[0], '0').x.y.z * f(m.n[0], '1').x.y.z * f(m.n[0], '2').x.y.z < 512",
7775*da0073e9SAndroid Build Coastguard Worker                preface=["_var2 = m.n", "_var3 = _var2[0]"],
7776*da0073e9SAndroid Build Coastguard Worker                expected="f(_var3, '0').x.y.z * f(_var3, '1').x.y.z * f(_var3, '2').x.y.z < 512",
7777*da0073e9SAndroid Build Coastguard Worker                expected_py38="(((f(_var3, '0').x.y.z * f(_var3, '1').x.y.z) * f(_var3, '2').x.y.z) < 512)",
7778*da0073e9SAndroid Build Coastguard Worker            ),
7779*da0073e9SAndroid Build Coastguard Worker            testcase(
7780*da0073e9SAndroid Build Coastguard Worker                expr="self.g(a, b).k + (1 - self.g(a, b).k) <= m[0].a + self.g(a, b).k",
7781*da0073e9SAndroid Build Coastguard Worker                preface=["_var4 = self.g", "_var5 = _var4(a, b)", "_var6 = _var5.k"],
7782*da0073e9SAndroid Build Coastguard Worker                expected="_var6 + (1 - _var6) <= m[0].a + _var6",
7783*da0073e9SAndroid Build Coastguard Worker                expected_py38="((_var6 + (1 - _var6)) <= (m[0].a + _var6))",
7784*da0073e9SAndroid Build Coastguard Worker            ),
7785*da0073e9SAndroid Build Coastguard Worker        ]
7786*da0073e9SAndroid Build Coastguard Worker
7787*da0073e9SAndroid Build Coastguard Worker        csepass = PyExprCSEPass()
7788*da0073e9SAndroid Build Coastguard Worker        csepass.count([t.expr for t in testcases])
7789*da0073e9SAndroid Build Coastguard Worker
7790*da0073e9SAndroid Build Coastguard Worker        for t in testcases:
7791*da0073e9SAndroid Build Coastguard Worker            preface, expr = csepass.replace(t.expr)
7792*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(preface, t.preface)
7793*da0073e9SAndroid Build Coastguard Worker            expected = t.expected_py38 if self._is_py38() else t.expected
7794*da0073e9SAndroid Build Coastguard Worker            expected = expected if expected is not None else t.expr
7795*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expr, expected)
7796*da0073e9SAndroid Build Coastguard Worker
7797*da0073e9SAndroid Build Coastguard Worker    def test_guard_function_builder_with_cse(self):
7798*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.guards import build_guard_function
7799*da0073e9SAndroid Build Coastguard Worker
7800*da0073e9SAndroid Build Coastguard Worker        exprs = [
7801*da0073e9SAndroid Build Coastguard Worker            "x[0].a < x[1].a * (3 - x[2].a)",
7802*da0073e9SAndroid Build Coastguard Worker            "a.b.c[0].d.e + a.b.c[1].d.e * a.b.c[2].d.e > 0",
7803*da0073e9SAndroid Build Coastguard Worker            "f(m.n[0], '0').x.y.z * f(m.n[0], '1').x.y.z * f(m.n[0], '2').x.y.z < 512",
7804*da0073e9SAndroid Build Coastguard Worker            "self.g(a, b).k + (1 - self.g(a, b).k) <= m[0].a + self.g(a, b).k",
7805*da0073e9SAndroid Build Coastguard Worker        ]
7806*da0073e9SAndroid Build Coastguard Worker
7807*da0073e9SAndroid Build Coastguard Worker        _, pycode = build_guard_function(exprs, "")
7808*da0073e9SAndroid Build Coastguard Worker        expected = """\
7809*da0073e9SAndroid Build Coastguard Workerdef ___make_guard_fn():
7810*da0073e9SAndroid Build Coastguard Worker    def guard(L):
7811*da0073e9SAndroid Build Coastguard Worker        if not (x[0].a < x[1].a * (3 - x[2].a)):
7812*da0073e9SAndroid Build Coastguard Worker            return False
7813*da0073e9SAndroid Build Coastguard Worker        _var0 = a.b
7814*da0073e9SAndroid Build Coastguard Worker        _var1 = _var0.c
7815*da0073e9SAndroid Build Coastguard Worker        if not (_var1[0].d.e + _var1[1].d.e * _var1[2].d.e > 0):
7816*da0073e9SAndroid Build Coastguard Worker            return False
7817*da0073e9SAndroid Build Coastguard Worker        _var2 = m.n
7818*da0073e9SAndroid Build Coastguard Worker        _var3 = _var2[0]
7819*da0073e9SAndroid Build Coastguard Worker        if not (f(_var3, '0').x.y.z * f(_var3, '1').x.y.z * f(_var3, '2').x.y.z < 512):
7820*da0073e9SAndroid Build Coastguard Worker            return False
7821*da0073e9SAndroid Build Coastguard Worker        _var4 = self.g
7822*da0073e9SAndroid Build Coastguard Worker        _var5 = _var4(a, b)
7823*da0073e9SAndroid Build Coastguard Worker        _var6 = _var5.k
7824*da0073e9SAndroid Build Coastguard Worker        if not (_var6 + (1 - _var6) <= m[0].a + _var6):
7825*da0073e9SAndroid Build Coastguard Worker            return False
7826*da0073e9SAndroid Build Coastguard Worker        return True
7827*da0073e9SAndroid Build Coastguard Worker    return guard
7828*da0073e9SAndroid Build Coastguard Worker"""
7829*da0073e9SAndroid Build Coastguard Worker        expected_38 = """\
7830*da0073e9SAndroid Build Coastguard Workerdef ___make_guard_fn():
7831*da0073e9SAndroid Build Coastguard Worker    def guard(L):
7832*da0073e9SAndroid Build Coastguard Worker        if not ((x[0].a < (x[1].a * (3 - x[2].a)))):
7833*da0073e9SAndroid Build Coastguard Worker            return False
7834*da0073e9SAndroid Build Coastguard Worker        _var0 = a.b
7835*da0073e9SAndroid Build Coastguard Worker        _var1 = _var0.c
7836*da0073e9SAndroid Build Coastguard Worker        if not (((_var1[0].d.e + (_var1[1].d.e * _var1[2].d.e)) > 0)):
7837*da0073e9SAndroid Build Coastguard Worker            return False
7838*da0073e9SAndroid Build Coastguard Worker        _var2 = m.n
7839*da0073e9SAndroid Build Coastguard Worker        _var3 = _var2[0]
7840*da0073e9SAndroid Build Coastguard Worker        if not ((((f(_var3, '0').x.y.z * f(_var3, '1').x.y.z) * f(_var3, '2').x.y.z) < 512)):
7841*da0073e9SAndroid Build Coastguard Worker            return False
7842*da0073e9SAndroid Build Coastguard Worker        _var4 = self.g
7843*da0073e9SAndroid Build Coastguard Worker        _var5 = _var4(a, b)
7844*da0073e9SAndroid Build Coastguard Worker        _var6 = _var5.k
7845*da0073e9SAndroid Build Coastguard Worker        if not (((_var6 + (1 - _var6)) <= (m[0].a + _var6))):
7846*da0073e9SAndroid Build Coastguard Worker            return False
7847*da0073e9SAndroid Build Coastguard Worker        return True
7848*da0073e9SAndroid Build Coastguard Worker    return guard
7849*da0073e9SAndroid Build Coastguard Worker"""
7850*da0073e9SAndroid Build Coastguard Worker        expected_38_no_astunparse = """\
7851*da0073e9SAndroid Build Coastguard Workerdef ___make_guard_fn():
7852*da0073e9SAndroid Build Coastguard Worker    def guard(L):
7853*da0073e9SAndroid Build Coastguard Worker        if not (x[0].a < x[1].a * (3 - x[2].a)):
7854*da0073e9SAndroid Build Coastguard Worker            return False
7855*da0073e9SAndroid Build Coastguard Worker        if not (a.b.c[0].d.e + a.b.c[1].d.e * a.b.c[2].d.e > 0):
7856*da0073e9SAndroid Build Coastguard Worker            return False
7857*da0073e9SAndroid Build Coastguard Worker        if not (f(m.n[0], '0').x.y.z * f(m.n[0], '1').x.y.z * f(m.n[0], '2').x.y.z < 512):
7858*da0073e9SAndroid Build Coastguard Worker            return False
7859*da0073e9SAndroid Build Coastguard Worker        if not (self.g(a, b).k + (1 - self.g(a, b).k) <= m[0].a + self.g(a, b).k):
7860*da0073e9SAndroid Build Coastguard Worker            return False
7861*da0073e9SAndroid Build Coastguard Worker        return True
7862*da0073e9SAndroid Build Coastguard Worker    return guard
7863*da0073e9SAndroid Build Coastguard Worker"""
7864*da0073e9SAndroid Build Coastguard Worker
7865*da0073e9SAndroid Build Coastguard Worker        if self._is_py38():
7866*da0073e9SAndroid Build Coastguard Worker            expected = (
7867*da0073e9SAndroid Build Coastguard Worker                expected_38 if self._has_ast_unparse() else expected_38_no_astunparse
7868*da0073e9SAndroid Build Coastguard Worker            )
7869*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, pycode)
7870*da0073e9SAndroid Build Coastguard Worker
7871*da0073e9SAndroid Build Coastguard Worker    def test_dynamo_compiling_fake_tensor_to_vararg_int(self):
7872*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
7873*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
7874*da0073e9SAndroid Build Coastguard Worker                super().__init__()
7875*da0073e9SAndroid Build Coastguard Worker
7876*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
7877*da0073e9SAndroid Build Coastguard Worker                # use numpy int so it's wrapped as fake tensor in dynamo
7878*da0073e9SAndroid Build Coastguard Worker                shape = np.int_(16)
7879*da0073e9SAndroid Build Coastguard Worker                # test shape as fake tensor, which param type is
7880*da0073e9SAndroid Build Coastguard Worker                # Sequence[Union[_int, SymInt]]
7881*da0073e9SAndroid Build Coastguard Worker                return x.reshape(shape)
7882*da0073e9SAndroid Build Coastguard Worker
7883*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([4, 4])
7884*da0073e9SAndroid Build Coastguard Worker        model = MyModule()
7885*da0073e9SAndroid Build Coastguard Worker        orig_out = model(x)
7886*da0073e9SAndroid Build Coastguard Worker        opt_model = torch._dynamo.optimize("eager")(MyModule())
7887*da0073e9SAndroid Build Coastguard Worker        opt_out = opt_model(x)
7888*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(orig_out, opt_out))
7889*da0073e9SAndroid Build Coastguard Worker
7890*da0073e9SAndroid Build Coastguard Worker    def test_scalar_tensor_is_equivalent_to_symint_argument(self):
7891*da0073e9SAndroid Build Coastguard Worker        class GumbelTopKSampler(torch.nn.Module):
7892*da0073e9SAndroid Build Coastguard Worker            def __init__(self, T, k):
7893*da0073e9SAndroid Build Coastguard Worker                super().__init__()
7894*da0073e9SAndroid Build Coastguard Worker                self.T = torch.nn.Parameter(
7895*da0073e9SAndroid Build Coastguard Worker                    torch.tensor(T, dtype=torch.float32), requires_grad=False
7896*da0073e9SAndroid Build Coastguard Worker                )
7897*da0073e9SAndroid Build Coastguard Worker                self.k = torch.nn.Parameter(
7898*da0073e9SAndroid Build Coastguard Worker                    torch.tensor(k, dtype=torch.int32), requires_grad=False
7899*da0073e9SAndroid Build Coastguard Worker                )
7900*da0073e9SAndroid Build Coastguard Worker
7901*da0073e9SAndroid Build Coastguard Worker            def sample_discrete(self, logits):
7902*da0073e9SAndroid Build Coastguard Worker                threshold = torch.topk(logits, self.k, sorted=True)[0][..., -1]
7903*da0073e9SAndroid Build Coastguard Worker                samples = torch.ge(logits.squeeze(1), threshold).float()
7904*da0073e9SAndroid Build Coastguard Worker                return samples
7905*da0073e9SAndroid Build Coastguard Worker
7906*da0073e9SAndroid Build Coastguard Worker            def forward(self, logits):
7907*da0073e9SAndroid Build Coastguard Worker                dsamples = self.sample_discrete(logits)
7908*da0073e9SAndroid Build Coastguard Worker                return dsamples
7909*da0073e9SAndroid Build Coastguard Worker
7910*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([4, 4, 4, 4])
7911*da0073e9SAndroid Build Coastguard Worker        m = GumbelTopKSampler(T=4, k=4)
7912*da0073e9SAndroid Build Coastguard Worker        orig_out = m(x)
7913*da0073e9SAndroid Build Coastguard Worker        opt_m = torch.compile(backend="eager")(m)
7914*da0073e9SAndroid Build Coastguard Worker        opt_out = opt_m(x)
7915*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(orig_out, opt_out))
7916*da0073e9SAndroid Build Coastguard Worker
7917*da0073e9SAndroid Build Coastguard Worker    def test_scalar_tensor_is_equivalent_to_symint_list_argument(self):
7918*da0073e9SAndroid Build Coastguard Worker        class Jitter(torch.nn.Module):
7919*da0073e9SAndroid Build Coastguard Worker            def __init__(self, jitter_val):
7920*da0073e9SAndroid Build Coastguard Worker                super().__init__()
7921*da0073e9SAndroid Build Coastguard Worker                self.jitter_val = jitter_val
7922*da0073e9SAndroid Build Coastguard Worker
7923*da0073e9SAndroid Build Coastguard Worker            def roll_tensor(self, input):
7924*da0073e9SAndroid Build Coastguard Worker                h_shift = self.jitter_val - 1
7925*da0073e9SAndroid Build Coastguard Worker                w_shift = self.jitter_val + 1
7926*da0073e9SAndroid Build Coastguard Worker                return torch.roll(
7927*da0073e9SAndroid Build Coastguard Worker                    torch.roll(input, shifts=h_shift, dims=2), shifts=w_shift, dims=3
7928*da0073e9SAndroid Build Coastguard Worker                )
7929*da0073e9SAndroid Build Coastguard Worker
7930*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
7931*da0073e9SAndroid Build Coastguard Worker                return self.roll_tensor(input)
7932*da0073e9SAndroid Build Coastguard Worker
7933*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([4, 4, 4, 4])
7934*da0073e9SAndroid Build Coastguard Worker        m = Jitter(jitter_val=4)
7935*da0073e9SAndroid Build Coastguard Worker        orig_out = m(x)
7936*da0073e9SAndroid Build Coastguard Worker        opt_m = torch.compile(backend="eager")(m)
7937*da0073e9SAndroid Build Coastguard Worker        opt_out = opt_m(x)
7938*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(orig_out, opt_out))
7939*da0073e9SAndroid Build Coastguard Worker
7940*da0073e9SAndroid Build Coastguard Worker    def test_scalar_tensor_is_equivalent_to_int_list_argument(self):
7941*da0073e9SAndroid Build Coastguard Worker        class MyModel(torch.nn.Module):
7942*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
7943*da0073e9SAndroid Build Coastguard Worker                permute = torch.tensor([0, 2, 1])
7944*da0073e9SAndroid Build Coastguard Worker                x = input.permute(*permute)
7945*da0073e9SAndroid Build Coastguard Worker                return x
7946*da0073e9SAndroid Build Coastguard Worker
7947*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, 4)
7948*da0073e9SAndroid Build Coastguard Worker        m = MyModel()
7949*da0073e9SAndroid Build Coastguard Worker        orig_out = m(x)
7950*da0073e9SAndroid Build Coastguard Worker        opt_m = torch.compile(backend="eager")(m)
7951*da0073e9SAndroid Build Coastguard Worker        opt_out = opt_m(x)
7952*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(orig_out, opt_out))
7953*da0073e9SAndroid Build Coastguard Worker
7954*da0073e9SAndroid Build Coastguard Worker    def test_torch_variable_hasattr(self):
7955*da0073e9SAndroid Build Coastguard Worker        def fn(x):
7956*da0073e9SAndroid Build Coastguard Worker            if hasattr(torch.nn, "Module"):
7957*da0073e9SAndroid Build Coastguard Worker                return x * x
7958*da0073e9SAndroid Build Coastguard Worker            return x + 1
7959*da0073e9SAndroid Build Coastguard Worker
7960*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn)
7961*da0073e9SAndroid Build Coastguard Worker
7962*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([4, 4])
7963*da0073e9SAndroid Build Coastguard Worker        fn_out = fn(x)
7964*da0073e9SAndroid Build Coastguard Worker        compiled_out = compiled_fn(x)
7965*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(fn_out, compiled_out))
7966*da0073e9SAndroid Build Coastguard Worker
7967*da0073e9SAndroid Build Coastguard Worker    def test_list_hasattr1(self):
7968*da0073e9SAndroid Build Coastguard Worker        def fn(x):
7969*da0073e9SAndroid Build Coastguard Worker            if hasattr(x, "foo"):
7970*da0073e9SAndroid Build Coastguard Worker                return x[0] + 1
7971*da0073e9SAndroid Build Coastguard Worker            return x[0] - 1
7972*da0073e9SAndroid Build Coastguard Worker
7973*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn)
7974*da0073e9SAndroid Build Coastguard Worker
7975*da0073e9SAndroid Build Coastguard Worker        x = [torch.randn(3)]
7976*da0073e9SAndroid Build Coastguard Worker        fn_out = fn(x)
7977*da0073e9SAndroid Build Coastguard Worker        compiled_out = compiled_fn(x)
7978*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(fn_out, compiled_out))
7979*da0073e9SAndroid Build Coastguard Worker
7980*da0073e9SAndroid Build Coastguard Worker    def test_list_hasattr2(self):
7981*da0073e9SAndroid Build Coastguard Worker        def fn():
7982*da0073e9SAndroid Build Coastguard Worker            x = [torch.zeros(3)]
7983*da0073e9SAndroid Build Coastguard Worker            if hasattr(x, "__len__"):
7984*da0073e9SAndroid Build Coastguard Worker                return x[0] + 1
7985*da0073e9SAndroid Build Coastguard Worker            return x[0] - 1
7986*da0073e9SAndroid Build Coastguard Worker
7987*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn)
7988*da0073e9SAndroid Build Coastguard Worker
7989*da0073e9SAndroid Build Coastguard Worker        fn_out = fn()
7990*da0073e9SAndroid Build Coastguard Worker        compiled_out = compiled_fn()
7991*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(fn_out, compiled_out))
7992*da0073e9SAndroid Build Coastguard Worker
7993*da0073e9SAndroid Build Coastguard Worker    def test_tuple_hasattr(self):
7994*da0073e9SAndroid Build Coastguard Worker        def fn(x):
7995*da0073e9SAndroid Build Coastguard Worker            if hasattr(x, "foo"):
7996*da0073e9SAndroid Build Coastguard Worker                return x[0] + 1
7997*da0073e9SAndroid Build Coastguard Worker            return x[1] - 1
7998*da0073e9SAndroid Build Coastguard Worker
7999*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn)
8000*da0073e9SAndroid Build Coastguard Worker
8001*da0073e9SAndroid Build Coastguard Worker        x = (torch.randn(3), torch.randn(3))
8002*da0073e9SAndroid Build Coastguard Worker        fn_out = fn(x)
8003*da0073e9SAndroid Build Coastguard Worker        compiled_out = compiled_fn(x)
8004*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(fn_out, compiled_out))
8005*da0073e9SAndroid Build Coastguard Worker
8006*da0073e9SAndroid Build Coastguard Worker    def test_fn_hasattr__name__1(self):
8007*da0073e9SAndroid Build Coastguard Worker        def fn():
8008*da0073e9SAndroid Build Coastguard Worker            foo = lambda x: x + 1
8009*da0073e9SAndroid Build Coastguard Worker            return hasattr(foo, "__name__")
8010*da0073e9SAndroid Build Coastguard Worker
8011*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn)
8012*da0073e9SAndroid Build Coastguard Worker
8013*da0073e9SAndroid Build Coastguard Worker        fn_out = fn()
8014*da0073e9SAndroid Build Coastguard Worker        compiled_out = compiled_fn()
8015*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn_out, compiled_out)
8016*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(fn_out)
8017*da0073e9SAndroid Build Coastguard Worker
8018*da0073e9SAndroid Build Coastguard Worker    def test_fn_hasattr__name__2(self):
8019*da0073e9SAndroid Build Coastguard Worker        def bar(x):
8020*da0073e9SAndroid Build Coastguard Worker            return torch.sin(x)
8021*da0073e9SAndroid Build Coastguard Worker
8022*da0073e9SAndroid Build Coastguard Worker        def fn():
8023*da0073e9SAndroid Build Coastguard Worker            return hasattr(bar, "__name__")
8024*da0073e9SAndroid Build Coastguard Worker
8025*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn)
8026*da0073e9SAndroid Build Coastguard Worker
8027*da0073e9SAndroid Build Coastguard Worker        fn_out = fn()
8028*da0073e9SAndroid Build Coastguard Worker        compiled_out = compiled_fn()
8029*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn_out, compiled_out)
8030*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(fn_out)
8031*da0073e9SAndroid Build Coastguard Worker
8032*da0073e9SAndroid Build Coastguard Worker    def test_fn_hasattr__name__3(self):
8033*da0073e9SAndroid Build Coastguard Worker        def bar(x, y):
8034*da0073e9SAndroid Build Coastguard Worker            return torch.sin(x) + torch.cos(y)
8035*da0073e9SAndroid Build Coastguard Worker
8036*da0073e9SAndroid Build Coastguard Worker        baz = functools.partial(bar, y=4)
8037*da0073e9SAndroid Build Coastguard Worker
8038*da0073e9SAndroid Build Coastguard Worker        def fn():
8039*da0073e9SAndroid Build Coastguard Worker            return hasattr(baz, "__name__")
8040*da0073e9SAndroid Build Coastguard Worker
8041*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn)
8042*da0073e9SAndroid Build Coastguard Worker
8043*da0073e9SAndroid Build Coastguard Worker        fn_out = fn()
8044*da0073e9SAndroid Build Coastguard Worker        compiled_out = compiled_fn()
8045*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn_out, compiled_out)
8046*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(fn_out)
8047*da0073e9SAndroid Build Coastguard Worker
8048*da0073e9SAndroid Build Coastguard Worker    def test_torch_objects_as_keys(self):
8049*da0073e9SAndroid Build Coastguard Worker        remap = {torch.float16: torch.float32}
8050*da0073e9SAndroid Build Coastguard Worker
8051*da0073e9SAndroid Build Coastguard Worker        def fn():
8052*da0073e9SAndroid Build Coastguard Worker            return torch.randn(3, dtype=remap[torch.float16])
8053*da0073e9SAndroid Build Coastguard Worker
8054*da0073e9SAndroid Build Coastguard Worker        opt = torch._dynamo.optimize("eager")(fn)
8055*da0073e9SAndroid Build Coastguard Worker        opt()
8056*da0073e9SAndroid Build Coastguard Worker
8057*da0073e9SAndroid Build Coastguard Worker    def test_tracing_py_tree(self):
8058*da0073e9SAndroid Build Coastguard Worker        def fn(xs):
8059*da0073e9SAndroid Build Coastguard Worker            flat_xs, spec = pytree.tree_flatten(xs)
8060*da0073e9SAndroid Build Coastguard Worker            res = [x.clone() for x in flat_xs]
8061*da0073e9SAndroid Build Coastguard Worker            return pytree.tree_unflatten(res, spec)
8062*da0073e9SAndroid Build Coastguard Worker
8063*da0073e9SAndroid Build Coastguard Worker        xs = [torch.tensor(i) for i in range(3)]
8064*da0073e9SAndroid Build Coastguard Worker
8065*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
8066*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.optimize(counter, nopython=True)(fn)(xs)
8067*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 1)
8068*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.op_count, 3)
8069*da0073e9SAndroid Build Coastguard Worker
8070*da0073e9SAndroid Build Coastguard Worker    def test_tracing_nested_py_tree(self):
8071*da0073e9SAndroid Build Coastguard Worker        import torch.utils._pytree as pytree
8072*da0073e9SAndroid Build Coastguard Worker
8073*da0073e9SAndroid Build Coastguard Worker        def fn(xs):
8074*da0073e9SAndroid Build Coastguard Worker            flat_xs, spec = pytree.tree_flatten(xs)
8075*da0073e9SAndroid Build Coastguard Worker            res = [x.clone() for x in flat_xs]
8076*da0073e9SAndroid Build Coastguard Worker            return pytree.tree_unflatten(res, spec)
8077*da0073e9SAndroid Build Coastguard Worker
8078*da0073e9SAndroid Build Coastguard Worker        xs = [torch.tensor(i) for i in range(3)]
8079*da0073e9SAndroid Build Coastguard Worker        xsl = [xs, xs, xs, xs]
8080*da0073e9SAndroid Build Coastguard Worker
8081*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
8082*da0073e9SAndroid Build Coastguard Worker        comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsl)
8083*da0073e9SAndroid Build Coastguard Worker        real_out = fn(xsl)
8084*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(comp_out, real_out)
8085*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 1)
8086*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.op_count, 12)
8087*da0073e9SAndroid Build Coastguard Worker
8088*da0073e9SAndroid Build Coastguard Worker    def test_tracing_nested_py_tree_tuples(self):
8089*da0073e9SAndroid Build Coastguard Worker        import torch.utils._pytree as pytree
8090*da0073e9SAndroid Build Coastguard Worker
8091*da0073e9SAndroid Build Coastguard Worker        def fn(xs):
8092*da0073e9SAndroid Build Coastguard Worker            flat_xs, spec = pytree.tree_flatten(xs)
8093*da0073e9SAndroid Build Coastguard Worker            res = [x.clone() for x in flat_xs]
8094*da0073e9SAndroid Build Coastguard Worker            return pytree.tree_unflatten(res, spec)
8095*da0073e9SAndroid Build Coastguard Worker
8096*da0073e9SAndroid Build Coastguard Worker        xs = [torch.tensor(i) for i in range(3)]
8097*da0073e9SAndroid Build Coastguard Worker        xsl = (xs, xs, xs, xs)
8098*da0073e9SAndroid Build Coastguard Worker
8099*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
8100*da0073e9SAndroid Build Coastguard Worker        comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsl)
8101*da0073e9SAndroid Build Coastguard Worker        real_out = fn(xsl)
8102*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(comp_out, real_out)
8103*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 1)
8104*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.op_count, 12)
8105*da0073e9SAndroid Build Coastguard Worker
8106*da0073e9SAndroid Build Coastguard Worker    def test_tracing_nested_py_tree_dicts(self):
8107*da0073e9SAndroid Build Coastguard Worker        import torch.utils._pytree as pytree
8108*da0073e9SAndroid Build Coastguard Worker
8109*da0073e9SAndroid Build Coastguard Worker        def fn(xs):
8110*da0073e9SAndroid Build Coastguard Worker            flat_xs, spec = pytree.tree_flatten(xs)
8111*da0073e9SAndroid Build Coastguard Worker            res = [x.clone() for x in flat_xs]
8112*da0073e9SAndroid Build Coastguard Worker            return pytree.tree_unflatten(res, spec)
8113*da0073e9SAndroid Build Coastguard Worker
8114*da0073e9SAndroid Build Coastguard Worker        xs = [torch.tensor(i) for i in range(3)]
8115*da0073e9SAndroid Build Coastguard Worker        xsl = {
8116*da0073e9SAndroid Build Coastguard Worker            "a": xs,
8117*da0073e9SAndroid Build Coastguard Worker            "b": xs,
8118*da0073e9SAndroid Build Coastguard Worker            "c": xs,
8119*da0073e9SAndroid Build Coastguard Worker        }
8120*da0073e9SAndroid Build Coastguard Worker
8121*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
8122*da0073e9SAndroid Build Coastguard Worker        comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsl)
8123*da0073e9SAndroid Build Coastguard Worker        real_out = fn(xsl)
8124*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(comp_out, real_out)
8125*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 1)
8126*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.op_count, 9)
8127*da0073e9SAndroid Build Coastguard Worker
8128*da0073e9SAndroid Build Coastguard Worker    def test_dynamic_one_hot(self):
8129*da0073e9SAndroid Build Coastguard Worker        def fn(x):
8130*da0073e9SAndroid Build Coastguard Worker            x = x + 1
8131*da0073e9SAndroid Build Coastguard Worker            # graph break from data-dependent output shape
8132*da0073e9SAndroid Build Coastguard Worker            x = torch.nn.functional.one_hot(x)
8133*da0073e9SAndroid Build Coastguard Worker            x = x + 1
8134*da0073e9SAndroid Build Coastguard Worker            return x
8135*da0073e9SAndroid Build Coastguard Worker
8136*da0073e9SAndroid Build Coastguard Worker        inp = torch.arange(20) % 4
8137*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
8138*da0073e9SAndroid Build Coastguard Worker        real_out = fn(inp)
8139*da0073e9SAndroid Build Coastguard Worker        comp_out = torch.compile(fn, backend=counter)(inp)
8140*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(comp_out, real_out)
8141*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 2)
8142*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.op_count, 2)
8143*da0073e9SAndroid Build Coastguard Worker
8144*da0073e9SAndroid Build Coastguard Worker    def test_tracing_nested_py_tree_mixed_all(self):
8145*da0073e9SAndroid Build Coastguard Worker        import torch.utils._pytree as pytree
8146*da0073e9SAndroid Build Coastguard Worker
8147*da0073e9SAndroid Build Coastguard Worker        def fn(xs):
8148*da0073e9SAndroid Build Coastguard Worker            flat_xs, spec = pytree.tree_flatten(xs)
8149*da0073e9SAndroid Build Coastguard Worker            res = [x.clone() for x in flat_xs]
8150*da0073e9SAndroid Build Coastguard Worker            return pytree.tree_unflatten(res, spec)
8151*da0073e9SAndroid Build Coastguard Worker
8152*da0073e9SAndroid Build Coastguard Worker        xs = [torch.tensor(i) for i in range(3)]
8153*da0073e9SAndroid Build Coastguard Worker        xsa = (xs, xs)
8154*da0073e9SAndroid Build Coastguard Worker        xsb = {"aa": xsa, "ab": xs}
8155*da0073e9SAndroid Build Coastguard Worker        xsl = {
8156*da0073e9SAndroid Build Coastguard Worker            "a": xs,
8157*da0073e9SAndroid Build Coastguard Worker            "b": xsa,
8158*da0073e9SAndroid Build Coastguard Worker            "c": xsb,
8159*da0073e9SAndroid Build Coastguard Worker        }
8160*da0073e9SAndroid Build Coastguard Worker
8161*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
8162*da0073e9SAndroid Build Coastguard Worker        comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsl)
8163*da0073e9SAndroid Build Coastguard Worker        real_out = fn(xsl)
8164*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(comp_out, real_out)
8165*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 1)
8166*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.op_count, 18)
8167*da0073e9SAndroid Build Coastguard Worker
8168*da0073e9SAndroid Build Coastguard Worker    def test_any_all_symnode(self):
8169*da0073e9SAndroid Build Coastguard Worker        cnt = CompileCounter()
8170*da0073e9SAndroid Build Coastguard Worker
8171*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
8172*da0073e9SAndroid Build Coastguard Worker        def fn(x):
8173*da0073e9SAndroid Build Coastguard Worker            t = x.size(0) >= 10
8174*da0073e9SAndroid Build Coastguard Worker            f = x.size(0) >= 100
8175*da0073e9SAndroid Build Coastguard Worker            if any([]) or any([f]) or any([f, f]):
8176*da0073e9SAndroid Build Coastguard Worker                return x - 1
8177*da0073e9SAndroid Build Coastguard Worker            if all([f]) or all([t, f]) or all([f, t]) or all([f, f]):
8178*da0073e9SAndroid Build Coastguard Worker                return x - 2
8179*da0073e9SAndroid Build Coastguard Worker            if not (all([]) and all([t]) and all([t, t])):
8180*da0073e9SAndroid Build Coastguard Worker                return x - 3
8181*da0073e9SAndroid Build Coastguard Worker            if not (any([t]) and any([t, f]) and any([f, t])):
8182*da0073e9SAndroid Build Coastguard Worker                return x - 4
8183*da0073e9SAndroid Build Coastguard Worker            return x + 1
8184*da0073e9SAndroid Build Coastguard Worker
8185*da0073e9SAndroid Build Coastguard Worker        y1 = torch.randn(16)
8186*da0073e9SAndroid Build Coastguard Worker        y2 = torch.randn(18)
8187*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(y1), y1 + 1)
8188*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(y2), y2 + 1)
8189*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
8190*da0073e9SAndroid Build Coastguard Worker        y3 = torch.randn(5)
8191*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(y3), y3 - 3)
8192*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 2)
8193*da0073e9SAndroid Build Coastguard Worker
8194*da0073e9SAndroid Build Coastguard Worker    def test_tracing_py_tree_tensor_subclass(self):
8195*da0073e9SAndroid Build Coastguard Worker        import torch.utils._pytree as pytree
8196*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.two_tensor import TwoTensor
8197*da0073e9SAndroid Build Coastguard Worker        from torch.utils.checkpoint import checkpoint
8198*da0073e9SAndroid Build Coastguard Worker
8199*da0073e9SAndroid Build Coastguard Worker        def fn(xs):
8200*da0073e9SAndroid Build Coastguard Worker            nested_xs = [[xs]]
8201*da0073e9SAndroid Build Coastguard Worker            flat_xs, spec = pytree.tree_flatten(xs)
8202*da0073e9SAndroid Build Coastguard Worker            return flat_xs[0].clone()
8203*da0073e9SAndroid Build Coastguard Worker
8204*da0073e9SAndroid Build Coastguard Worker        # use checkpoint to trigger a "sourceless" tensor subclass
8205*da0073e9SAndroid Build Coastguard Worker        def checkpoint_fn(xs):
8206*da0073e9SAndroid Build Coastguard Worker            return checkpoint(fn, xs, use_reentrant=True)
8207*da0073e9SAndroid Build Coastguard Worker
8208*da0073e9SAndroid Build Coastguard Worker        xs = TwoTensor(torch.ones(2, 2), torch.ones(2, 2))
8209*da0073e9SAndroid Build Coastguard Worker
8210*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
8211*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.optimize(counter, nopython=True)(checkpoint_fn)(xs)
8212*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 1)
8213*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.op_count, 2)
8214*da0073e9SAndroid Build Coastguard Worker
8215*da0073e9SAndroid Build Coastguard Worker    def test_tracing_tree_map_only(self):
8216*da0073e9SAndroid Build Coastguard Worker        import torch.utils._pytree as pytree
8217*da0073e9SAndroid Build Coastguard Worker
8218*da0073e9SAndroid Build Coastguard Worker        def fn(xs):
8219*da0073e9SAndroid Build Coastguard Worker            def mapper(x):
8220*da0073e9SAndroid Build Coastguard Worker                return x.clone()
8221*da0073e9SAndroid Build Coastguard Worker
8222*da0073e9SAndroid Build Coastguard Worker            y = pytree.tree_map_only(torch.Tensor, mapper, xs)
8223*da0073e9SAndroid Build Coastguard Worker            return y
8224*da0073e9SAndroid Build Coastguard Worker
8225*da0073e9SAndroid Build Coastguard Worker        xs = [torch.tensor(i) for i in range(3)] + ["hi"]
8226*da0073e9SAndroid Build Coastguard Worker        xsa = (xs, xs)
8227*da0073e9SAndroid Build Coastguard Worker        xsb = {"aa": xsa, "ab": xs}
8228*da0073e9SAndroid Build Coastguard Worker
8229*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
8230*da0073e9SAndroid Build Coastguard Worker        comp_out = torch._dynamo.optimize(counter, nopython=True)(fn)(xsb)
8231*da0073e9SAndroid Build Coastguard Worker        real_out = fn(xsb)
8232*da0073e9SAndroid Build Coastguard Worker
8233*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(comp_out, real_out)
8234*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 1)
8235*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.op_count, 9)
8236*da0073e9SAndroid Build Coastguard Worker
8237*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(
8238*da0073e9SAndroid Build Coastguard Worker        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
8239*da0073e9SAndroid Build Coastguard Worker    )
8240*da0073e9SAndroid Build Coastguard Worker    def test_unbacked_symint(self):
8241*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager")
8242*da0073e9SAndroid Build Coastguard Worker        def f(lengths, values):
8243*da0073e9SAndroid Build Coastguard Worker            sizes = lengths.tolist()
8244*da0073e9SAndroid Build Coastguard Worker            for s in sizes:
8245*da0073e9SAndroid Build Coastguard Worker                torch._check_is_size(s)
8246*da0073e9SAndroid Build Coastguard Worker                torch._check(s >= 2)
8247*da0073e9SAndroid Build Coastguard Worker                torch._check(s <= 100)
8248*da0073e9SAndroid Build Coastguard Worker            return torch.split(values, sizes)
8249*da0073e9SAndroid Build Coastguard Worker
8250*da0073e9SAndroid Build Coastguard Worker        f(torch.tensor([2, 3, 4]), torch.randn(9))
8251*da0073e9SAndroid Build Coastguard Worker
8252*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(
8253*da0073e9SAndroid Build Coastguard Worker        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
8254*da0073e9SAndroid Build Coastguard Worker    )
8255*da0073e9SAndroid Build Coastguard Worker    def test_unbacked_auto_functionalize_op(self):
8256*da0073e9SAndroid Build Coastguard Worker        @torch.library.custom_op(
8257*da0073e9SAndroid Build Coastguard Worker            "mylib::mk_image", mutates_args=("decoder",), device_types=["cpu"]
8258*da0073e9SAndroid Build Coastguard Worker        )
8259*da0073e9SAndroid Build Coastguard Worker        def mk_image(decoder: Tensor) -> Tensor:
8260*da0073e9SAndroid Build Coastguard Worker            return torch.randn(2, 3, 4, 5)
8261*da0073e9SAndroid Build Coastguard Worker
8262*da0073e9SAndroid Build Coastguard Worker        @torch.library.register_fake("mylib::mk_image")
8263*da0073e9SAndroid Build Coastguard Worker        def _(decoder: Tensor) -> Tensor:
8264*da0073e9SAndroid Build Coastguard Worker            image_size = [torch.library.get_ctx().new_dynamic_size() for _ in range(4)]
8265*da0073e9SAndroid Build Coastguard Worker            return torch.empty(image_size)
8266*da0073e9SAndroid Build Coastguard Worker
8267*da0073e9SAndroid Build Coastguard Worker        @torch.compile(fullgraph=True)
8268*da0073e9SAndroid Build Coastguard Worker        def f(x):
8269*da0073e9SAndroid Build Coastguard Worker            return torch.ops.mylib.mk_image.default(x)
8270*da0073e9SAndroid Build Coastguard Worker
8271*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(100, dtype=torch.int64)
8272*da0073e9SAndroid Build Coastguard Worker        f(x)
8273*da0073e9SAndroid Build Coastguard Worker
8274*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(capture_scalar_outputs=True)
8275*da0073e9SAndroid Build Coastguard Worker    def test_runtime_assert_replacement(self):
8276*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="aot_eager")
8277*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
8278*da0073e9SAndroid Build Coastguard Worker            z = y.item()
8279*da0073e9SAndroid Build Coastguard Worker            torch._check(z == 3)
8280*da0073e9SAndroid Build Coastguard Worker            return x + z
8281*da0073e9SAndroid Build Coastguard Worker
8282*da0073e9SAndroid Build Coastguard Worker        fn(torch.randn(4), torch.tensor([3]))
8283*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: fn(torch.randn(4), torch.tensor([4])))
8284*da0073e9SAndroid Build Coastguard Worker
8285*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(capture_scalar_outputs=True)
8286*da0073e9SAndroid Build Coastguard Worker    def test_cat_unbacked(self):
8287*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager")
8288*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
8289*da0073e9SAndroid Build Coastguard Worker            z = y.item()
8290*da0073e9SAndroid Build Coastguard Worker            return torch.cat([x, torch.ones(z)])
8291*da0073e9SAndroid Build Coastguard Worker
8292*da0073e9SAndroid Build Coastguard Worker        fn(torch.randn(2, 3), torch.tensor([0]))
8293*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(
8294*da0073e9SAndroid Build Coastguard Worker            RuntimeError, lambda: fn(torch.randn(2, 3), torch.tensor([1]))
8295*da0073e9SAndroid Build Coastguard Worker        )
8296*da0073e9SAndroid Build Coastguard Worker
8297*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(
8298*da0073e9SAndroid Build Coastguard Worker        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
8299*da0073e9SAndroid Build Coastguard Worker    )
8300*da0073e9SAndroid Build Coastguard Worker    def test_aot_autograd_propagate_unbacked_symints_shape(self):
8301*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="aot_eager")
8302*da0073e9SAndroid Build Coastguard Worker        def f(x):
8303*da0073e9SAndroid Build Coastguard Worker            return torch.nonzero(x)
8304*da0073e9SAndroid Build Coastguard Worker
8305*da0073e9SAndroid Build Coastguard Worker        f(torch.tensor([1, 0, 3, 2, 0]))
8306*da0073e9SAndroid Build Coastguard Worker
8307*da0073e9SAndroid Build Coastguard Worker    def test_simple_set_usage(self):
8308*da0073e9SAndroid Build Coastguard Worker        def foo(x, y):
8309*da0073e9SAndroid Build Coastguard Worker            setty = {x, y}
8310*da0073e9SAndroid Build Coastguard Worker            return setty.pop() * setty.pop()
8311*da0073e9SAndroid Build Coastguard Worker
8312*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
8313*da0073e9SAndroid Build Coastguard Worker        foo = torch._dynamo.optimize(counter, nopython=True)(foo)
8314*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10)
8315*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(10, 10)
8316*da0073e9SAndroid Build Coastguard Worker        foo(x, y)
8317*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 1)
8318*da0073e9SAndroid Build Coastguard Worker
8319*da0073e9SAndroid Build Coastguard Worker    def test_add_to_set(self):
8320*da0073e9SAndroid Build Coastguard Worker        def foo(x, y):
8321*da0073e9SAndroid Build Coastguard Worker            setty = set()
8322*da0073e9SAndroid Build Coastguard Worker            setty.add(x[0])
8323*da0073e9SAndroid Build Coastguard Worker            setty.add(x[1])
8324*da0073e9SAndroid Build Coastguard Worker            setty.add(x[2])
8325*da0073e9SAndroid Build Coastguard Worker            setty.add(y)
8326*da0073e9SAndroid Build Coastguard Worker            return y * len(setty)
8327*da0073e9SAndroid Build Coastguard Worker
8328*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10)
8329*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(2, 2)
8330*da0073e9SAndroid Build Coastguard Worker        eager_result = foo([x, x, x, x, y], y)
8331*da0073e9SAndroid Build Coastguard Worker
8332*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
8333*da0073e9SAndroid Build Coastguard Worker        foo = torch._dynamo.optimize(counter, nopython=True)(foo)
8334*da0073e9SAndroid Build Coastguard Worker        result = foo([x, x, x, x, y], y)
8335*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 1)
8336*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, eager_result)
8337*da0073e9SAndroid Build Coastguard Worker
    def test_iter_set(self):
        # Same scenario as adding members one by one, but the set is filled by
        # iterating over the input list inside the optimized function.
        def foo(x, y):
            setty = set()
            for t in x:
                setty.add(t)
            return y * len(setty)

        x = torch.randn(10, 10)
        y = torch.randn(2, 2)
        # Eager reference, taken before `foo` is rebound to its optimized wrapper.
        eager_result = foo([x, x, x, x, y], y)

        counter = CompileCounter()
        foo = torch._dynamo.optimize(counter, nopython=True)(foo)
        result = foo([x, x, x, x, y], y)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(result, eager_result)
8354*da0073e9SAndroid Build Coastguard Worker
    def test_input_set_graph_break(self):
        # Passes a Python set directly as the function input. With nopython=True the
        # call must raise Unsupported; with nopython=False the call runs and the
        # counter reports one compiled frame.
        def foo(x):
            return x.pop() * x.pop()

        x = torch.randn(10, 10)
        y = torch.randn(10, 10)

        counter = CompileCounter()

        inp = {x, x, x, x, y, y}
        foo = torch._dynamo.optimize(counter, nopython=True)(foo)

        # Many things about sets cannot work without a good deal of effort on our part.
        # Specifically, taking a set as an input can't work with how GetItemSource works
        # (set contents cannot be accessed by arbitrary index), so the guard story for
        # objects passed in this way just isn't there at the moment.
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            "^call_method UserDefinedObjectVariable\\(set\\).*",
        ):
            foo(inp)

        # Re-wrap without nopython so the unsupported call no longer has to be an error.
        foo = torch._dynamo.optimize(counter, nopython=False)(foo)
        foo(inp)
        self.assertEqual(counter.frame_count, 1)
8379*da0073e9SAndroid Build Coastguard Worker
8380*da0073e9SAndroid Build Coastguard Worker    def test_reconstruct_set_across_graph_break(self):
8381*da0073e9SAndroid Build Coastguard Worker        def foo(x, y):
8382*da0073e9SAndroid Build Coastguard Worker            setty = set()
8383*da0073e9SAndroid Build Coastguard Worker            for t in x:
8384*da0073e9SAndroid Build Coastguard Worker                setty.add(t)
8385*da0073e9SAndroid Build Coastguard Worker            print("Break!")
8386*da0073e9SAndroid Build Coastguard Worker            return y * len(setty)
8387*da0073e9SAndroid Build Coastguard Worker
8388*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10)
8389*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(2, 2)
8390*da0073e9SAndroid Build Coastguard Worker
8391*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
8392*da0073e9SAndroid Build Coastguard Worker        foo = torch._dynamo.optimize(counter)(foo)
8393*da0073e9SAndroid Build Coastguard Worker        result = foo([x, x, x, x, y], y)
8394*da0073e9SAndroid Build Coastguard Worker
    def test_set_aliasing_recompiles(self):
        # Tracks counter.frame_count as the aliasing pattern among the arguments
        # a, b, c and the captured tensor g1 changes from call to call. `foo`
        # builds a set from all four, so its length depends on which of them are
        # the same object.
        g1 = torch.randn(10)
        g2 = torch.randn(10)
        g3 = torch.randn(10)
        g4 = torch.randn(10)

        def foo(a, b, c):
            myset = {g1, a, b, c}
            return a + len(myset)

        counter = CompileCounter()
        foo = torch._dynamo.optimize(counter)(foo)
        # first call with no aliasing
        foo(g2, g3, g4)
        self.assertEqual(counter.frame_count, 1)

        # no aliasing again
        foo(g3, g2, g4)
        # assert no recompile
        self.assertEqual(counter.frame_count, 1)

        # aliasing changes, we should recompile
        foo(g2, g2, g2)
        self.assertEqual(counter.frame_count, 2)

        # same aliasing, different tensor
        foo(g3, g3, g3)
        self.assertEqual(counter.frame_count, 2)

        # aliasing between global and arg, should recompile again
        foo(g1, g1, g1)
        self.assertEqual(counter.frame_count, 3)

        # Reset
        # (the counter object is reused, so frame_count keeps accumulating below)
        torch._dynamo.reset()

        # aliasing between global and arg, first call
        foo(g1, g1, g1)
        self.assertEqual(counter.frame_count, 4)

        # same aliasing, different tensor, all local, recompile
        foo(g3, g3, g3)
        self.assertEqual(counter.frame_count, 5)

        # aliasing same tensor, we shouldn't recompile
        foo(g2, g2, g2)
        self.assertEqual(counter.frame_count, 5)

        # No aliasing
        foo(g2, g3, g4)
        self.assertEqual(counter.frame_count, 6)

        # No aliasing again
        foo(g3, g2, g4)
        # assert no recompile
        self.assertEqual(counter.frame_count, 6)
8451*da0073e9SAndroid Build Coastguard Worker
    def test_str_format_return1(self):
        # An f-string built from pieces of img.shape is returned, together with a
        # tensor, from a function compiled with fullgraph=True.
        @torch.compile(backend="eager", fullgraph=True)
        def fn(img):
            x = torch.sin(img)
            y = f"shape {img.shape[-2:]} batch size {img.shape[0]}"
            return img + x, y

        img1 = torch.randn(1, 1, 8, 8)
        res, msg = fn(img1)
        # Both the formatted string and the tensor result must match plain Python/eager.
        self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1")
        self.assertEqual(res, img1 + torch.sin(img1))
8463*da0073e9SAndroid Build Coastguard Worker
    def test_str_format_return2(self):
        # Like the f-string variant, but uses str.format with one positional
        # argument and one keyword argument carrying a ".2f" format spec.
        @torch.compile(backend="eager", fullgraph=True)
        def fn(img):
            x = torch.sin(img)
            y = "shape {} batch size {y:.2f}".format(img.shape[-2:], y=img.shape[0])
            return img + x, y

        img1 = torch.randn(1, 1, 8, 8)
        res, msg = fn(img1)
        # The batch size 1 is rendered through the float spec, hence "1.00".
        self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1.00")
        self.assertEqual(res, img1 + torch.sin(img1))
8475*da0073e9SAndroid Build Coastguard Worker
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_validate_outputs_unbacked(self):
        # Smoke test: compiles (aot_eager, fullgraph) a function that applies a
        # custom autograd.Function to pieces whose split sizes come from
        # tensor.tolist(). There is no explicit assertion; the test passes if the
        # call completes without raising.
        class SillyCat(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x0, x1, i):
                # Keep the sizes tensor so backward can split the gradient the same way.
                ctx.save_for_backward(i)
                return torch.cat([x0, x1])

            @staticmethod
            def backward(ctx, grad_out):
                (i,) = ctx.saved_tensors
                i0, i1 = i.tolist()
                g_x0, g_x1 = grad_out.split([i0, i1])
                # No gradient for the sizes tensor `i`.
                return g_x0, g_x1, None

        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(x, i):
            i0, i1 = i.tolist()
            x0, x1 = x.split([i0, i1])
            return SillyCat.apply(x0, x1, i)

        # Split sizes 3 + 6 cover the 9-element input.
        f(torch.randn(9, requires_grad=True), torch.tensor([3, 6]))
8498*da0073e9SAndroid Build Coastguard Worker
    def test_str_format_assert1(self):
        # torch._assert with an f-string message built from img.shape, inside a
        # fullgraph-compiled function; the asserted condition holds here.
        @torch.compile(backend="eager", fullgraph=True)
        def fn(img):
            x = torch.sin(img)
            val = x.shape[-2:]
            torch._assert(len(val) == 2, f"shape {img.shape}")
            return img + x

        img1 = torch.randn(1, 1, 8, 8)
        res = fn(img1)
        self.assertEqual(res, img1 + torch.sin(img1))
8510*da0073e9SAndroid Build Coastguard Worker
    def test_str_format_assert2(self):
        # torch._assert on concrete shape values with an f-string message: the
        # first input satisfies the condition, the second must raise AssertionError.
        cnt = CompileCounter()

        @torch.compile(backend=cnt)
        def fn(img):
            x = torch.sin(img)
            torch._assert(
                img.shape[-2] == 8 and img.shape[-1] == 16, f"shape {img.shape}"
            )
            return img + x

        img1 = torch.randn(1, 3, 8, 16)
        res = fn(img1)
        self.assertEqual(res, img1 + torch.sin(img1))
        self.assertEqual(cnt.frame_count, 1)

        # trigger a recompile and graph break
        # (last dimension is 15 instead of 16, so the asserted condition is false)
        img2 = torch.randn(1, 3, 8, 15)
        self.assertRaises(AssertionError, lambda: fn(img2))
8530*da0073e9SAndroid Build Coastguard Worker
    def test_tolist_scalar(self):
        # Iterates over tolist() of a single-element 1-D tensor under
        # nopython=True and expects the eager result with one compiled frame.
        def fn(x):
            new_list = []
            for i in x.tolist():
                new_list.append(i * 4)
            return new_list

        x = torch.tensor([3])
        eager = fn(x)
        counter = CompileCounter()
        # `compiled` holds the returned list, not the optimized callable.
        compiled = torch._dynamo.optimize(counter, nopython=True)(fn)(x)
        self.assertEqual(eager, compiled)
        self.assertEqual(counter.frame_count, 1)
8544*da0073e9SAndroid Build Coastguard Worker
    def test_tolist_1d(self):
        # Iterates over tolist() of a two-element 1-D integer tensor under
        # nopython=True and expects the eager result with one compiled frame.
        def fn(x):
            new_list = []
            for i in x.tolist():
                new_list.append(i * 4)
            return new_list

        x = torch.tensor([2, 1])
        eager = fn(x)
        counter = CompileCounter()
        # `compiled` holds the returned list, not the optimized callable.
        compiled = torch._dynamo.optimize(counter, nopython=True)(fn)(x)
        self.assertEqual(eager, compiled)
        self.assertEqual(counter.frame_count, 1)
8558*da0073e9SAndroid Build Coastguard Worker
    def test_tolist_kd(self):
        # tolist() on a 3-D integer tensor under nopython=True, compared to eager.
        def fn(x):
            new_list = []
            # Each `i` is itself a nested list here, so `i * 4` is list
            # repetition rather than numeric multiplication.
            for i in x.tolist():
                new_list.append(i * 4)
            return new_list

        x = torch.tensor([[[2, 1], [2, 1], [2, 1]], [[2, 1], [2, 1], [2, 1]]])
        eager = fn(x)
        counter = CompileCounter()
        # `compiled` holds the returned list, not the optimized callable.
        compiled = torch._dynamo.optimize(counter, nopython=True)(fn)(x)
        self.assertEqual(eager, compiled)
        self.assertEqual(counter.frame_count, 1)
8572*da0073e9SAndroid Build Coastguard Worker
    @patch.object(torch._dynamo.config, "specialize_int", True)
    def test_tolist_0d(self):
        # tolist() on a 0-d tensor, with the specialize_int config patched to
        # True, under nopython=True; compared to eager.
        def fn(x):
            new_list = []
            # No loop here: tolist() on a 0-d tensor is used as a single value.
            i = x.tolist()
            new_list.append(i * 4)
            return new_list

        x = torch.tensor(42)
        eager = fn(x)
        counter = CompileCounter()
        # `compiled` holds the returned list, not the optimized callable.
        compiled = torch._dynamo.optimize(counter, nopython=True)(fn)(x)
        self.assertEqual(eager, compiled)
        self.assertEqual(counter.frame_count, 1)
8587*da0073e9SAndroid Build Coastguard Worker
    @patch.object(torch._dynamo.config, "assume_static_by_default", False)
    @patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
    def test_tolist_kd_dynamic(self):
        # tolist() on a 2-D tensor with both shape-related config flags patched to
        # False. Checks the eager match, then how frame_count reacts to a change
        # in values versus a change in size.
        def fn(x):
            new_list = []
            # `i` is the whole nested list, so `i * 4` repeats the outer list.
            i = x.tolist()
            new_list.append(i * 4)
            return new_list

        x = torch.randint(3, 5, [5, 5])
        eager = fn(x)
        counter = CompileCounter()
        compiled_fn = torch._dynamo.optimize(counter, nopython=True)(fn)
        compiled = compiled_fn(x)
        self.assertEqual(eager, compiled)
        self.assertEqual(counter.frame_count, 1)

        # Value change, no recompiles
        x = torch.randint(7, 9, [5, 5])
        compiled_fn(x)
        self.assertEqual(counter.frame_count, 1)

        # Size change, forced recompiles
        x = torch.randint(3, 5, [3, 3])
        compiled_fn(x)
        self.assertEqual(counter.frame_count, 2)
8614*da0073e9SAndroid Build Coastguard Worker
    def test_tolist_float(self):
        # tolist() on a 3-D float tensor, optimized without nopython. The result
        # must match eager and the counter must report zero compiled frames.
        def fn(x):
            new_list = []
            for i in x.tolist():
                new_list.append(i * 4)
            return new_list

        x = torch.tensor(
            [[[2.0, 1.0], [2.0, 1.0], [2.0, 1.0]], [[2.0, 1.0], [2.0, 1.0], [2.0, 1.0]]]
        )
        eager = fn(x)
        counter = CompileCounter()
        # `compiled` holds the returned list, not the optimized callable.
        compiled = torch._dynamo.optimize(counter)(fn)(x)
        self.assertEqual(eager, compiled)
        # Nothing to compile here
        self.assertEqual(counter.frame_count, 0)
8631*da0073e9SAndroid Build Coastguard Worker
    def test_inline_closure_not_loaded_by_parent(self):
        # `outer` is called only by the innermost nested function `deep2`; the
        # functions enclosing it (`direct`, `deep`, `indirect`) never call it
        # themselves. The optimized call chain must match eager in one frame.
        def outer(a):
            return a + 1

        def indirect(x):
            return direct(x)

        def direct(x):
            def deep2(c):
                return outer(c)

            def deep(c):
                return deep2(c)

            return deep(x)

        x = torch.randn(3)
        eager = indirect(x)
        counter = CompileCounter()
        # `compiled` holds the returned tensor, not the optimized callable.
        compiled = torch._dynamo.optimize(counter)(indirect)(x)
        self.assertEqual(eager, compiled)
        self.assertEqual(counter.frame_count, 1)
8654*da0073e9SAndroid Build Coastguard Worker
    def test_deque_input(self):
        # Passes a collections.deque as the function input and pops two tensors
        # from its right end inside the optimized function.
        a = torch.randn([2, 3])
        b = torch.randn([2, 3])
        # Two identical deques are built because fn mutates its argument: one is
        # consumed by the eager run, the other by the optimized run.
        d1 = collections.deque([a, b])
        d1.insert(0, "foo")

        d2 = collections.deque([a, b])
        d2.insert(0, "foo")

        def fn(q):
            a = q.pop()
            b = q.pop()
            return a * b

        eager = fn(d1)
        counter = CompileCounter()
        compiled = torch._dynamo.optimize(counter)(fn)(d2)
        self.assertEqual(eager, compiled)
        self.assertEqual(counter.frame_count, 1)
8674*da0073e9SAndroid Build Coastguard Worker
8675*da0073e9SAndroid Build Coastguard Worker    def test_deque_append_left(self):
8676*da0073e9SAndroid Build Coastguard Worker        d1 = collections.deque([10, 10])
8677*da0073e9SAndroid Build Coastguard Worker        d1.insert(0, "foo")
8678*da0073e9SAndroid Build Coastguard Worker
8679*da0073e9SAndroid Build Coastguard Worker        d2 = collections.deque([10, 10])
8680*da0073e9SAndroid Build Coastguard Worker        d2.insert(0, "foo")
8681*da0073e9SAndroid Build Coastguard Worker
8682*da0073e9SAndroid Build Coastguard Worker        def fn(q, a, b):
8683*da0073e9SAndroid Build Coastguard Worker            q.appendleft(a)
8684*da0073e9SAndroid Build Coastguard Worker            q.appendleft(b)
8685*da0073e9SAndroid Build Coastguard Worker            return q.popleft() * q.popleft()
8686*da0073e9SAndroid Build Coastguard Worker
8687*da0073e9SAndroid Build Coastguard Worker        a = torch.randn([3, 3])
8688*da0073e9SAndroid Build Coastguard Worker        b = torch.randn([3, 3])
8689*da0073e9SAndroid Build Coastguard Worker        eager = fn(d1, a, b)
8690*da0073e9SAndroid Build Coastguard Worker        counter = CompileCounter()
8691*da0073e9SAndroid Build Coastguard Worker        compiled = torch._dynamo.optimize(counter)(fn)(d2, a, b)
8692*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager, compiled)
8693*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 1)
8694*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(compiled, torch.Tensor))
8695*da0073e9SAndroid Build Coastguard Worker
    def test_yield_from(self):
        # A nested generator that uses `yield from` over a list of tensors is
        # consumed by a list comprehension inside the optimized function.
        def yield_from_fn(t_list, k):
            def yield_from_gen(l):
                l2 = [t * k for t in l]
                yield from l2

            # Each tensor is scaled by k twice: once in the generator, once here.
            return [t * k for t in yield_from_gen(t_list)]

        t_list = [torch.randn([2, 3]) for _ in range(3)]
        eager = yield_from_fn(t_list, 2)
        counter = CompileCounter()
        compiled = torch._dynamo.optimize(counter)(yield_from_fn)(t_list, 2)
        self.assertEqual(eager, compiled)
        self.assertEqual(counter.frame_count, 1)
8710*da0073e9SAndroid Build Coastguard Worker
    def test_yield_from_in_a_loop(self):
        # `yield from` executed repeatedly inside a for loop: gen1 delegates to a
        # fresh gen2() on each of its five iterations.
        def gen2():
            yield 1

        def gen1():
            for value in range(5):
                yield from gen2()

        def fn(x):
            c = 0
            for i in gen1():
                c = c + i
            return x + c

        opt_fn = torch.compile(fn, backend="eager")
        x = torch.zeros(4)
        # Only output equality is checked; no frame count is asserted here.
        self.assertEqual(fn(x), opt_fn(x))
8728*da0073e9SAndroid Build Coastguard Worker
    def test_yield_gen_and_from(self):
        # Combines a plain `yield` generator that creates tensors with a second
        # generator that re-emits the resulting list via `yield from`.
        def populate_and_multiply_sequence(n, multiplier):
            # Inline generator
            def tensor_generator():
                for i in range(n):
                    yield torch.tensor([i])

            # Consume the generator, multiplying each yielded tensor
            t_list = [tensor * multiplier for tensor in tensor_generator()]

            # This is where 'yield from' is used: re-emit the multiplied tensors
            def yield_from_gen():
                yield from t_list

            return [t for t in yield_from_gen()]

        multiplier = torch.tensor([10])
        eager = populate_and_multiply_sequence(5, multiplier)
        counter = CompileCounter()
        compiled = torch._dynamo.optimize(counter)(populate_and_multiply_sequence)(
            5, multiplier
        )
        self.assertEqual(eager, compiled)
        self.assertEqual(counter.frame_count, 1)
8752*da0073e9SAndroid Build Coastguard Worker
    def test_yield_from_user_stop_iteration(self):
        # `yield from` over a user-defined iterator class whose __next__ raises
        # StopIteration with a value. The output must match eager, and the
        # counter is expected to report zero compiled frames.
        class MyIter:
            def __init__(self, seq):
                self.seq = seq
                self.index = 0

            def __iter__(self):
                return self

            def __next__(self):
                # The index is advanced first, so the element returned is index - 1.
                self.index += 1
                if self.index <= len(self.seq):
                    return self.seq[self.index - 1]
                # Exhausted: StopIteration carries the final index as its value.
                raise StopIteration(self.index)

        def yield_from_iter_fn(seq):
            def gen(seq):
                yield from MyIter(seq)

            return [i for i in gen(seq)]

        seq = [torch.randn([2, 3]) for _ in range(3)]
        eager = yield_from_iter_fn(seq)
        counter = CompileCounter()
        compiled = torch._dynamo.optimize(counter)(yield_from_iter_fn)(seq)
        self.assertEqual(eager, compiled)
        self.assertEqual(counter.frame_count, 0)
8780*da0073e9SAndroid Build Coastguard Worker
    def test_yield_send_to_subgenerator_graph_break(self):
        """Compare eager and compiled results of a generator that `send`s into a subgenerator.

        The test expects equal results and zero compiled frames
        (``counter.frame_count == 0``).
        """
        def subgenerator(tensor):
            # First yield hands control back and receives the multiplier via send().
            multiplier = yield
            yield tensor * multiplier

        def main_generator(t_list):
            for tensor in t_list:
                subgen = subgenerator(tensor)
                # Prime the subgenerator so it is paused at its first `yield`.
                next(subgen)
                # send() returns the product tensor; `yield from` then iterates
                # over that returned tensor, not over the subgenerator itself.
                yield from subgen.send(torch.tensor([10]))

        t_list = [torch.tensor([i]) for i in range(5)]
        eager = list(main_generator(t_list))

        counter = CompileCounter()
        compiled_fn = torch._dynamo.optimize(counter)(main_generator)
        compiled = list(compiled_fn(t_list))

        self.assertEqual(eager, compiled)
        # No frame is expected to be compiled for this generator.
        self.assertEqual(counter.frame_count, 0)
8801*da0073e9SAndroid Build Coastguard Worker
    def test_derpy_nn_module_usage(self):
        """Compile a bare ``torch.nn.Module`` whose ``forward`` is a plain closure.

        Both modules are raw ``torch.nn.Module()`` instances that get their
        submodule, parameters and ``forward`` attached after construction.
        """
        def ff1(x):
            # Local `self` shadows the test's `self`; it refers to mod1 via closure.
            self = mod1
            return torch.sigmoid(self.mod2(x) + self.param1)

        def ff2(x):
            # Same trick for the nested module.
            self = mod2
            return torch.cos(torch.sin(x) * self.param2 + 10)

        mod1 = torch.nn.Module()
        mod2 = torch.nn.Module()
        mod1.register_module("mod2", mod2)
        mod1.register_parameter("param1", torch.nn.Parameter(torch.randn(10)))
        # forward is set as an instance attribute, so it is called with `x` only.
        mod1.forward = ff1
        mod2.register_parameter("param2", torch.nn.Parameter(torch.randn(10)))
        mod2.forward = ff2
        mod1.eval()

        x = torch.randn(10)
        expected = mod1(x)
        counter = CompileCounter()
        actual = torch.compile(mod1, backend=counter, fullgraph=True)(x)
        self.assertEqual(actual, expected)
        # NOTE(review): presumably sin, mul, add, cos, add, sigmoid — confirm
        # against how CompileCounter counts ops.
        self.assertEqual(counter.op_count, 6)
8826*da0073e9SAndroid Build Coastguard Worker
8827*da0073e9SAndroid Build Coastguard Worker    def test_default_args_device_dtype(self):
8828*da0073e9SAndroid Build Coastguard Worker        class Foo:
8829*da0073e9SAndroid Build Coastguard Worker            def __init__(
8830*da0073e9SAndroid Build Coastguard Worker                self,
8831*da0073e9SAndroid Build Coastguard Worker                dtype: torch.dtype = torch.float16,
8832*da0073e9SAndroid Build Coastguard Worker                device: torch.device = torch.device("cpu"),
8833*da0073e9SAndroid Build Coastguard Worker            ) -> None:
8834*da0073e9SAndroid Build Coastguard Worker                self.value = torch.tensor(10, dtype=dtype, device=device)
8835*da0073e9SAndroid Build Coastguard Worker
8836*da0073e9SAndroid Build Coastguard Worker        def fn():
8837*da0073e9SAndroid Build Coastguard Worker            return Foo().value + 1
8838*da0073e9SAndroid Build Coastguard Worker
8839*da0073e9SAndroid Build Coastguard Worker        opt_func = torch._dynamo.optimize("eager", nopython=True)(fn)
8840*da0073e9SAndroid Build Coastguard Worker        ref = fn()
8841*da0073e9SAndroid Build Coastguard Worker        res = opt_func()
8842*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
8843*da0073e9SAndroid Build Coastguard Worker
    def test_torch_device_python_type(self):
        """Check that a tensor's ``.device`` behaves as a ``torch.device`` while tracing.

        Runs once per (device string, expected type, expected index) triple; the
        CUDA case is skipped when ``TEST_CUDA`` is false.
        """
        for device, device_type, index in [
            ("cpu", "cpu", None),
            ("cuda:0", "cuda", 0),
        ]:
            if device == "cuda:0" and not TEST_CUDA:
                continue

            def fn(target):
                # `device_type` and `index` are read from the enclosing loop iteration.
                target_device = target.device
                a = torch.zeros(2, 3, device=target_device)
                # Constant assert at trace time
                assert isinstance(target_device, torch.device)
                assert target_device.type == device_type
                assert target_device.index == index
                b = torch.zeros(2, 3, device=target_device)
                c = torch.zeros(2, 3, device=target_device)
                return a + b + c

            from torch._dynamo.variables import ConstantVariable

            # Rebinds the loop variable from the device string to a torch.device.
            device = torch.device(device)
            expected_variable = ConstantVariable(device)
            self.assertEqual(expected_variable.python_type(), type(device))

            opt_func = torch._dynamo.optimize("eager", nopython=True)(fn)
            a = torch.tensor([2, 3], device=device)
            res = opt_func(a)
            self.assertIsInstance(res, torch.Tensor)
8873*da0073e9SAndroid Build Coastguard Worker
8874*da0073e9SAndroid Build Coastguard Worker    def test_torch_dtype_python_type(self):
8875*da0073e9SAndroid Build Coastguard Worker        def fn(target):
8876*da0073e9SAndroid Build Coastguard Worker            target_dtype = target.dtype
8877*da0073e9SAndroid Build Coastguard Worker            a = torch.zeros(2, 3, dtype=target_dtype)
8878*da0073e9SAndroid Build Coastguard Worker            # Constant assert at trace time
8879*da0073e9SAndroid Build Coastguard Worker            assert isinstance(target_dtype, torch.dtype)
8880*da0073e9SAndroid Build Coastguard Worker            b = torch.zeros(2, 3, dtype=target_dtype)
8881*da0073e9SAndroid Build Coastguard Worker            c = torch.zeros(2, 3, dtype=target_dtype)
8882*da0073e9SAndroid Build Coastguard Worker            return a + b + c
8883*da0073e9SAndroid Build Coastguard Worker
8884*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.variables import ConstantVariable
8885*da0073e9SAndroid Build Coastguard Worker
8886*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float16
8887*da0073e9SAndroid Build Coastguard Worker        expected_variable = ConstantVariable(dtype)
8888*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_variable.python_type(), type(dtype))
8889*da0073e9SAndroid Build Coastguard Worker
8890*da0073e9SAndroid Build Coastguard Worker        opt_func = torch._dynamo.optimize("eager", nopython=True)(fn)
8891*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([2, 3], dtype=dtype)
8892*da0073e9SAndroid Build Coastguard Worker        res = opt_func(a)
8893*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(res, torch.Tensor)
8894*da0073e9SAndroid Build Coastguard Worker
8895*da0073e9SAndroid Build Coastguard Worker    def test_itertools_repeat(self):
8896*da0073e9SAndroid Build Coastguard Worker        counters.clear()
8897*da0073e9SAndroid Build Coastguard Worker
8898*da0073e9SAndroid Build Coastguard Worker        def fn(x):
8899*da0073e9SAndroid Build Coastguard Worker            r = itertools.repeat(100.0, 5)
8900*da0073e9SAndroid Build Coastguard Worker            for i in r:
8901*da0073e9SAndroid Build Coastguard Worker                x += i
8902*da0073e9SAndroid Build Coastguard Worker            return x
8903*da0073e9SAndroid Build Coastguard Worker
8904*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([2, 5])
8905*da0073e9SAndroid Build Coastguard Worker        eager = fn(x)
8906*da0073e9SAndroid Build Coastguard Worker
8907*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
8908*da0073e9SAndroid Build Coastguard Worker        compiled = compiled_fn(x)
8909*da0073e9SAndroid Build Coastguard Worker
8910*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(eager), list(compiled))
8911*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(counters["graph_break"]), 0)
8912*da0073e9SAndroid Build Coastguard Worker
8913*da0073e9SAndroid Build Coastguard Worker    def test_itertools_infinite_repeat(self):
8914*da0073e9SAndroid Build Coastguard Worker        counters.clear()
8915*da0073e9SAndroid Build Coastguard Worker
8916*da0073e9SAndroid Build Coastguard Worker        def fn(x):
8917*da0073e9SAndroid Build Coastguard Worker            r = itertools.repeat(100.0)
8918*da0073e9SAndroid Build Coastguard Worker            idx = 0
8919*da0073e9SAndroid Build Coastguard Worker            for i in r:
8920*da0073e9SAndroid Build Coastguard Worker                x += i
8921*da0073e9SAndroid Build Coastguard Worker                idx += 1
8922*da0073e9SAndroid Build Coastguard Worker                if idx > 10:
8923*da0073e9SAndroid Build Coastguard Worker                    break
8924*da0073e9SAndroid Build Coastguard Worker            return x
8925*da0073e9SAndroid Build Coastguard Worker
8926*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([2, 5])
8927*da0073e9SAndroid Build Coastguard Worker        eager = fn(x)
8928*da0073e9SAndroid Build Coastguard Worker
8929*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
8930*da0073e9SAndroid Build Coastguard Worker        compiled = compiled_fn(x)
8931*da0073e9SAndroid Build Coastguard Worker
8932*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(eager), list(compiled))
8933*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(counters["graph_break"]), 0)
8934*da0073e9SAndroid Build Coastguard Worker
8935*da0073e9SAndroid Build Coastguard Worker    def test_itertools_infinite_repeat_mutation(self):
8936*da0073e9SAndroid Build Coastguard Worker        counters.clear()
8937*da0073e9SAndroid Build Coastguard Worker
8938*da0073e9SAndroid Build Coastguard Worker        def fn(x):
8939*da0073e9SAndroid Build Coastguard Worker            r = itertools.repeat(x)
8940*da0073e9SAndroid Build Coastguard Worker            idx = 0
8941*da0073e9SAndroid Build Coastguard Worker            for i in r:
8942*da0073e9SAndroid Build Coastguard Worker                x += i
8943*da0073e9SAndroid Build Coastguard Worker                i += 1
8944*da0073e9SAndroid Build Coastguard Worker                idx += 1
8945*da0073e9SAndroid Build Coastguard Worker                if idx > 10:
8946*da0073e9SAndroid Build Coastguard Worker                    break
8947*da0073e9SAndroid Build Coastguard Worker            return x
8948*da0073e9SAndroid Build Coastguard Worker
8949*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([2, 5])
8950*da0073e9SAndroid Build Coastguard Worker        eager = fn(x)
8951*da0073e9SAndroid Build Coastguard Worker
8952*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
8953*da0073e9SAndroid Build Coastguard Worker        compiled = compiled_fn(x)
8954*da0073e9SAndroid Build Coastguard Worker
8955*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(eager), list(compiled))
8956*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(counters["graph_break"]), 0)
8957*da0073e9SAndroid Build Coastguard Worker
8958*da0073e9SAndroid Build Coastguard Worker    def test_itertools_infinite_count(self):
8959*da0073e9SAndroid Build Coastguard Worker        for args in ([], [10], [5, -1]):
8960*da0073e9SAndroid Build Coastguard Worker            counters.clear()
8961*da0073e9SAndroid Build Coastguard Worker
8962*da0073e9SAndroid Build Coastguard Worker            def fn(x):
8963*da0073e9SAndroid Build Coastguard Worker                r = itertools.count(*args)
8964*da0073e9SAndroid Build Coastguard Worker                idx = 0
8965*da0073e9SAndroid Build Coastguard Worker                for i in r:
8966*da0073e9SAndroid Build Coastguard Worker                    x += i
8967*da0073e9SAndroid Build Coastguard Worker                    idx += 1
8968*da0073e9SAndroid Build Coastguard Worker                    if idx > 10:
8969*da0073e9SAndroid Build Coastguard Worker                        break
8970*da0073e9SAndroid Build Coastguard Worker                return x
8971*da0073e9SAndroid Build Coastguard Worker
8972*da0073e9SAndroid Build Coastguard Worker            x = torch.randn([2, 5])
8973*da0073e9SAndroid Build Coastguard Worker            eager = fn(x)
8974*da0073e9SAndroid Build Coastguard Worker
8975*da0073e9SAndroid Build Coastguard Worker            compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
8976*da0073e9SAndroid Build Coastguard Worker            compiled = compiled_fn(x)
8977*da0073e9SAndroid Build Coastguard Worker
8978*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(list(eager), list(compiled))
8979*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(counters["graph_break"]), 0)
8980*da0073e9SAndroid Build Coastguard Worker
8981*da0073e9SAndroid Build Coastguard Worker    def test_itertools_infinite_cycle(self):
8982*da0073e9SAndroid Build Coastguard Worker        counters.clear()
8983*da0073e9SAndroid Build Coastguard Worker
8984*da0073e9SAndroid Build Coastguard Worker        def fn(x):
8985*da0073e9SAndroid Build Coastguard Worker            for iterator in (
8986*da0073e9SAndroid Build Coastguard Worker                iter([]),
8987*da0073e9SAndroid Build Coastguard Worker                iter([10, 11.0]),
8988*da0073e9SAndroid Build Coastguard Worker                itertools.repeat(-1, 3),
8989*da0073e9SAndroid Build Coastguard Worker                itertools.count(10),
8990*da0073e9SAndroid Build Coastguard Worker            ):
8991*da0073e9SAndroid Build Coastguard Worker                r = itertools.cycle(iterator)
8992*da0073e9SAndroid Build Coastguard Worker                idx = 0
8993*da0073e9SAndroid Build Coastguard Worker                x += 1
8994*da0073e9SAndroid Build Coastguard Worker                for i in r:
8995*da0073e9SAndroid Build Coastguard Worker                    x += i
8996*da0073e9SAndroid Build Coastguard Worker                    idx += 1
8997*da0073e9SAndroid Build Coastguard Worker                    if idx > 10:
8998*da0073e9SAndroid Build Coastguard Worker                        break
8999*da0073e9SAndroid Build Coastguard Worker            return x
9000*da0073e9SAndroid Build Coastguard Worker
9001*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([2, 5])
9002*da0073e9SAndroid Build Coastguard Worker        eager = fn(x)
9003*da0073e9SAndroid Build Coastguard Worker
9004*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
9005*da0073e9SAndroid Build Coastguard Worker        compiled = compiled_fn(x)
9006*da0073e9SAndroid Build Coastguard Worker
9007*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(eager), list(compiled))
9008*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(counters["graph_break"]), 0)
9009*da0073e9SAndroid Build Coastguard Worker
9010*da0073e9SAndroid Build Coastguard Worker    def test_itertools_accumulate_symint_default_sum(self):
9011*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/110287
9012*da0073e9SAndroid Build Coastguard Worker        counters.clear()
9013*da0073e9SAndroid Build Coastguard Worker
9014*da0073e9SAndroid Build Coastguard Worker        def fn(x):
9015*da0073e9SAndroid Build Coastguard Worker            r = itertools.accumulate([x.size(0), x.size(1)])
9016*da0073e9SAndroid Build Coastguard Worker            for i in r:
9017*da0073e9SAndroid Build Coastguard Worker                x *= i
9018*da0073e9SAndroid Build Coastguard Worker            return x
9019*da0073e9SAndroid Build Coastguard Worker
9020*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3)
9021*da0073e9SAndroid Build Coastguard Worker        eager = fn(x)
9022*da0073e9SAndroid Build Coastguard Worker
9023*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
9024*da0073e9SAndroid Build Coastguard Worker        compiled = compiled_fn(x)
9025*da0073e9SAndroid Build Coastguard Worker
9026*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(eager), list(compiled))
9027*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(counters["graph_break"]), 0)
9028*da0073e9SAndroid Build Coastguard Worker
9029*da0073e9SAndroid Build Coastguard Worker    def test_itertools_accumulate_tensors_default_sum(self):
9030*da0073e9SAndroid Build Coastguard Worker        counters.clear()
9031*da0073e9SAndroid Build Coastguard Worker
9032*da0073e9SAndroid Build Coastguard Worker        def fn(a, b, c, d, x):
9033*da0073e9SAndroid Build Coastguard Worker            l = [a, b, c, d, x]
9034*da0073e9SAndroid Build Coastguard Worker            for i, t in enumerate(l):
9035*da0073e9SAndroid Build Coastguard Worker                l[i] = t * x
9036*da0073e9SAndroid Build Coastguard Worker            return itertools.accumulate(l)
9037*da0073e9SAndroid Build Coastguard Worker
9038*da0073e9SAndroid Build Coastguard Worker        t_list = [torch.tensor([i + 1]) for i in range(4)]
9039*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[1, 2], [3, 4]])
9040*da0073e9SAndroid Build Coastguard Worker        eager = fn(*t_list, x)
9041*da0073e9SAndroid Build Coastguard Worker
9042*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
9043*da0073e9SAndroid Build Coastguard Worker        compiled = compiled_fn(*t_list, x)
9044*da0073e9SAndroid Build Coastguard Worker
9045*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(eager), list(compiled))
9046*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(counters["graph_break"]), 0)
9047*da0073e9SAndroid Build Coastguard Worker
9048*da0073e9SAndroid Build Coastguard Worker    def test_itertools_accumulate_tensors_builtins(self):
9049*da0073e9SAndroid Build Coastguard Worker        for builtin_op in [operator.mul, operator.sub, operator.pow]:
9050*da0073e9SAndroid Build Coastguard Worker            counters.clear()
9051*da0073e9SAndroid Build Coastguard Worker
9052*da0073e9SAndroid Build Coastguard Worker            def fn(a, b, c, d, x):
9053*da0073e9SAndroid Build Coastguard Worker                l = [a, b, c, d, x]
9054*da0073e9SAndroid Build Coastguard Worker                for i, t in enumerate(l):
9055*da0073e9SAndroid Build Coastguard Worker                    l[i] = t * x
9056*da0073e9SAndroid Build Coastguard Worker                return itertools.accumulate(l, builtin_op)
9057*da0073e9SAndroid Build Coastguard Worker
9058*da0073e9SAndroid Build Coastguard Worker            t_list = [torch.tensor([i + 1]) for i in range(4)]
9059*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([[1, 2], [3, 4]])
9060*da0073e9SAndroid Build Coastguard Worker            eager = fn(*t_list, x)
9061*da0073e9SAndroid Build Coastguard Worker
9062*da0073e9SAndroid Build Coastguard Worker            compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
9063*da0073e9SAndroid Build Coastguard Worker            compiled = compiled_fn(*t_list, x)
9064*da0073e9SAndroid Build Coastguard Worker
9065*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(list(eager), list(compiled))
9066*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(counters["graph_break"]), 0)
9067*da0073e9SAndroid Build Coastguard Worker
9068*da0073e9SAndroid Build Coastguard Worker    def test_itertools_accumulate_tensors_kwargs(self):
9069*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.utils import counters
9070*da0073e9SAndroid Build Coastguard Worker
9071*da0073e9SAndroid Build Coastguard Worker        for kwargs in [
9072*da0073e9SAndroid Build Coastguard Worker            {"func": operator.mul},
9073*da0073e9SAndroid Build Coastguard Worker            {"initial": 100},
9074*da0073e9SAndroid Build Coastguard Worker            {"func": operator.sub, "initial": -1},
9075*da0073e9SAndroid Build Coastguard Worker        ]:
9076*da0073e9SAndroid Build Coastguard Worker            counters.clear()
9077*da0073e9SAndroid Build Coastguard Worker
9078*da0073e9SAndroid Build Coastguard Worker            def fn(a, b, c, d, x):
9079*da0073e9SAndroid Build Coastguard Worker                l = [a, b, c, d, x]
9080*da0073e9SAndroid Build Coastguard Worker                for i, t in enumerate(l):
9081*da0073e9SAndroid Build Coastguard Worker                    l[i] = t * x
9082*da0073e9SAndroid Build Coastguard Worker                return itertools.accumulate(l, **kwargs)
9083*da0073e9SAndroid Build Coastguard Worker
9084*da0073e9SAndroid Build Coastguard Worker            t_list = [torch.tensor([i + 1]) for i in range(4)]
9085*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([[1, 2], [3, 4]])
9086*da0073e9SAndroid Build Coastguard Worker
9087*da0073e9SAndroid Build Coastguard Worker            compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
9088*da0073e9SAndroid Build Coastguard Worker            compiled = compiled_fn(*t_list, x)
9089*da0073e9SAndroid Build Coastguard Worker            eager = fn(*t_list, x)
9090*da0073e9SAndroid Build Coastguard Worker
9091*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(list(eager), list(compiled))
9092*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(counters["graph_break"]), 0)
9093*da0073e9SAndroid Build Coastguard Worker
9094*da0073e9SAndroid Build Coastguard Worker    def test_packaging_version_parse(self):
9095*da0073e9SAndroid Build Coastguard Worker        from packaging import version
9096*da0073e9SAndroid Build Coastguard Worker
9097*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager", fullgraph=True)
9098*da0073e9SAndroid Build Coastguard Worker        def fn():
9099*da0073e9SAndroid Build Coastguard Worker            x = torch.zeros(1)
9100*da0073e9SAndroid Build Coastguard Worker            if version.parse(torch.__version__) >= version.parse("2.0.0"):
9101*da0073e9SAndroid Build Coastguard Worker                return x + 1
9102*da0073e9SAndroid Build Coastguard Worker            return x
9103*da0073e9SAndroid Build Coastguard Worker
9104*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn().item(), 1)
9105*da0073e9SAndroid Build Coastguard Worker
    def test_itertools_accumulate_tensors_user_defined(self):
        """``itertools.accumulate`` with user-defined binary functions: eager vs. compiled.

        Three functions are tried: a constant one, one closing over a random
        int, and one with a side effect on a closed-over list.
        """
        def udo_fn_0(a, b):
            # Ignores both inputs.
            return -1

        # Either 0 or 1, picked once for the whole test.
        rando = random.randint(0, 1)

        def udo_fn_1(a, b):
            return a * rando + b * rando

        # Shared by every call to udo_fn_2, eager and compiled; never cleared.
        seen = []

        def udo_fn_2(a, b):
            # Side effect: records both operands; the result depends on len(seen).
            seen.append(a)
            seen.append(b)
            return a * len(seen)

        for udo_fn in [udo_fn_0, udo_fn_1, udo_fn_2]:
            counters.clear()
            torch._dynamo.reset()

            def fn(a, b, c, d, x):
                l = [a, b, c, d, x]
                for i, t in enumerate(l):
                    l[i] = t * x
                return itertools.accumulate(l, udo_fn)

            # NOTE(review): t_list[0] is tensor([0]), so l[0] is all zeros and
            # udo_fn_2 (accumulator * len(seen)) yields zeros whatever len(seen)
            # is. That keeps the shared `seen` state from changing the result.
            t_list = [torch.tensor([i]) for i in range(4)]
            x = torch.tensor([[1, 2], [3, 4]])
            # accumulate() is lazy: the eager result is only consumed by list()
            # in the assertion below.
            eager = fn(*t_list, x)

            compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
            compiled = compiled_fn(*t_list, x)

            self.assertEqual(list(eager), list(compiled))
            self.assertEqual(len(counters["graph_break"]), 0)
9141*da0073e9SAndroid Build Coastguard Worker
    def test_pure_python_accumulate(self):
        """Compare eager and compiled results of a hand-written generator version of accumulate.

        The test expects equal results and exactly one compiled frame
        (``counter.frame_count == 1``).
        """
        def accumulate(iterable, func=lambda x, y: x + y):
            # Generator: yields the first item, then each running func() result.
            it = iter(iterable)
            try:
                # Initialize the accumulator with the first value from the iterable
                accumulator = next(it)
            except StopIteration:
                # If the iterable is empty, return an empty generator
                return
            yield accumulator

            for element in it:
                accumulator = func(accumulator, element)
                yield accumulator

        def fn(it):
            return accumulate(it)

        t_list = [torch.tensor([i]) for i in range(4)]
        # fn returns a generator; it is only consumed by list() in the assertion.
        eager = fn(t_list)

        counter = CompileCounter()
        compiled_fn = torch._dynamo.optimize(counter)(fn)
        compiled = compiled_fn(t_list)

        self.assertEqual(list(eager), list(compiled))
        self.assertEqual(counter.frame_count, 1)
9169*da0073e9SAndroid Build Coastguard Worker
9170*da0073e9SAndroid Build Coastguard Worker    def test_itertools_groupby_pure_python_default_identify_func(self):
9171*da0073e9SAndroid Build Coastguard Worker        counters.clear()
9172*da0073e9SAndroid Build Coastguard Worker
9173*da0073e9SAndroid Build Coastguard Worker        def fn(l):
9174*da0073e9SAndroid Build Coastguard Worker            return [(k, list(g)) for k, g in itertools.groupby(l)]
9175*da0073e9SAndroid Build Coastguard Worker
9176*da0073e9SAndroid Build Coastguard Worker        l = [1, 2, 2, 3, 4, 4, 4, 1, 2]
9177*da0073e9SAndroid Build Coastguard Worker        eager = fn(l)
9178*da0073e9SAndroid Build Coastguard Worker
9179*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
9180*da0073e9SAndroid Build Coastguard Worker        compiled = compiled_fn(l)
9181*da0073e9SAndroid Build Coastguard Worker
9182*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager, compiled)
9183*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(counters["graph_break"]), 0)
9184*da0073e9SAndroid Build Coastguard Worker
9185*da0073e9SAndroid Build Coastguard Worker    def test_itertools_groupby_pure_python_key_func(self):
9186*da0073e9SAndroid Build Coastguard Worker        counters.clear()
9187*da0073e9SAndroid Build Coastguard Worker
9188*da0073e9SAndroid Build Coastguard Worker        def fn(l):
9189*da0073e9SAndroid Build Coastguard Worker            return [(k, list(g)) for k, g in itertools.groupby(l, key=operator.neg)]
9190*da0073e9SAndroid Build Coastguard Worker
9191*da0073e9SAndroid Build Coastguard Worker        l = [1, 2, -2, 3, 4, 4, -4, 0, -2]
9192*da0073e9SAndroid Build Coastguard Worker        eager = fn(l)
9193*da0073e9SAndroid Build Coastguard Worker
9194*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
9195*da0073e9SAndroid Build Coastguard Worker        compiled = compiled_fn(l)
9196*da0073e9SAndroid Build Coastguard Worker
9197*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager, compiled)
9198*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(counters["graph_break"]), 0)
9199*da0073e9SAndroid Build Coastguard Worker
9200*da0073e9SAndroid Build Coastguard Worker    def test_list_iterator_contains(self):
9201*da0073e9SAndroid Build Coastguard Worker        def fn(x):
9202*da0073e9SAndroid Build Coastguard Worker            it = iter(["my_weight", "not_my_weight"])
9203*da0073e9SAndroid Build Coastguard Worker            next(it)
9204*da0073e9SAndroid Build Coastguard Worker            if "my_weight" in it:
9205*da0073e9SAndroid Build Coastguard Worker                return x + 2
9206*da0073e9SAndroid Build Coastguard Worker            return x + 1
9207*da0073e9SAndroid Build Coastguard Worker
9208*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(3)
9209*da0073e9SAndroid Build Coastguard Worker        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
9210*da0073e9SAndroid Build Coastguard Worker
9211*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(x), compiled_fn(x))
9212*da0073e9SAndroid Build Coastguard Worker
9213*da0073e9SAndroid Build Coastguard Worker    def test_storage_return(self):
9214*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager", fullgraph=True)
9215*da0073e9SAndroid Build Coastguard Worker        def fn(x):
9216*da0073e9SAndroid Build Coastguard Worker            y = torch.sin(x + 1)
9217*da0073e9SAndroid Build Coastguard Worker            storage = x.untyped_storage()
9218*da0073e9SAndroid Build Coastguard Worker            storage.resize_(0)
9219*da0073e9SAndroid Build Coastguard Worker            y = torch.cos(y)
9220*da0073e9SAndroid Build Coastguard Worker            return y, storage
9221*da0073e9SAndroid Build Coastguard Worker
9222*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10)
9223*da0073e9SAndroid Build Coastguard Worker        expected = torch.cos(torch.sin(x + 1))
9224*da0073e9SAndroid Build Coastguard Worker        y, s = fn(x)
9225*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, expected)
9226*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.untyped_storage().size(), 0)
9227*da0073e9SAndroid Build Coastguard Worker        self.assertIs(s, x.untyped_storage())
9228*da0073e9SAndroid Build Coastguard Worker
9229*da0073e9SAndroid Build Coastguard Worker    def test_flat_name_to_original_fqn(self):
9230*da0073e9SAndroid Build Coastguard Worker        class FooBarModule(torch.nn.Module):
9231*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
9232*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9233*da0073e9SAndroid Build Coastguard Worker                self.register_parameter("0", torch.nn.Parameter(torch.randn(3, 4)))
9234*da0073e9SAndroid Build Coastguard Worker                self.register_buffer("test_buf", torch.randn(3, 4))
9235*da0073e9SAndroid Build Coastguard Worker                self.register_parameter(
9236*da0073e9SAndroid Build Coastguard Worker                    "test_param", torch.nn.Parameter(torch.randn(3, 4))
9237*da0073e9SAndroid Build Coastguard Worker                )
9238*da0073e9SAndroid Build Coastguard Worker
9239*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
9240*da0073e9SAndroid Build Coastguard Worker                return ((x + self.test_buf) * getattr(self, "0")) / self.test_param
9241*da0073e9SAndroid Build Coastguard Worker
9242*da0073e9SAndroid Build Coastguard Worker        class TestModule(torch.nn.Module):
9243*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
9244*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9245*da0073e9SAndroid Build Coastguard Worker                self.foo_bar = FooBarModule()
9246*da0073e9SAndroid Build Coastguard Worker                self.register_parameter(
9247*da0073e9SAndroid Build Coastguard Worker                    "test_param", torch.nn.Parameter(torch.randn(3, 4))
9248*da0073e9SAndroid Build Coastguard Worker                )
9249*da0073e9SAndroid Build Coastguard Worker                self.register_buffer("test_buf", torch.randn(3, 4))
9250*da0073e9SAndroid Build Coastguard Worker
9251*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
9252*da0073e9SAndroid Build Coastguard Worker                return (self.foo_bar(x) + self.test_param) * self.test_buf
9253*da0073e9SAndroid Build Coastguard Worker
9254*da0073e9SAndroid Build Coastguard Worker        gm, _ = torch._dynamo.export(TestModule(), torch.randn(3, 4))
9255*da0073e9SAndroid Build Coastguard Worker        self.assertIn("dynamo_flat_name_to_original_fqn", gm.meta)
9256*da0073e9SAndroid Build Coastguard Worker        expected_fqn = {
9257*da0073e9SAndroid Build Coastguard Worker            "L__self___test_param": "test_param",
9258*da0073e9SAndroid Build Coastguard Worker            "L__self___test_buf": "test_buf",
9259*da0073e9SAndroid Build Coastguard Worker            "getattr_L__self___foo_bar___0__": "foo_bar.0",
9260*da0073e9SAndroid Build Coastguard Worker            "L__self___foo_bar_test_param": "foo_bar.test_param",
9261*da0073e9SAndroid Build Coastguard Worker            "L__self___foo_bar_test_buf": "foo_bar.test_buf",
9262*da0073e9SAndroid Build Coastguard Worker        }
9263*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_fqn, gm.meta["dynamo_flat_name_to_original_fqn"])
9264*da0073e9SAndroid Build Coastguard Worker
9265*da0073e9SAndroid Build Coastguard Worker    def test_shape_env_no_recording(self):
9266*da0073e9SAndroid Build Coastguard Worker        main = ShapeEnv(should_record_events=False)
9267*da0073e9SAndroid Build Coastguard Worker
9268*da0073e9SAndroid Build Coastguard Worker        # The main ShapeEnv should have no event recorded.
9269*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(main.events), 0)
9270*da0073e9SAndroid Build Coastguard Worker
9271*da0073e9SAndroid Build Coastguard Worker        # Call create_symbolic_sizes_strides_storage_offset on both of them.
9272*da0073e9SAndroid Build Coastguard Worker        r = main.create_symbolic_sizes_strides_storage_offset(
9273*da0073e9SAndroid Build Coastguard Worker            torch.randn(3, 2), ConstantSource("x")
9274*da0073e9SAndroid Build Coastguard Worker        )
9275*da0073e9SAndroid Build Coastguard Worker
9276*da0073e9SAndroid Build Coastguard Worker        # Create a guard: size[0] == 3 (call evaluate_expr)
9277*da0073e9SAndroid Build Coastguard Worker        #   - +1 guard entry
9278*da0073e9SAndroid Build Coastguard Worker        #   - +1 replacement entry
9279*da0073e9SAndroid Build Coastguard Worker        size = r[0]
9280*da0073e9SAndroid Build Coastguard Worker        bool(size[0] == 3)
9281*da0073e9SAndroid Build Coastguard Worker
9282*da0073e9SAndroid Build Coastguard Worker        # The main ShapeEnv should remain with no event recorded.
9283*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(main.events), 0)
9284*da0073e9SAndroid Build Coastguard Worker
9285*da0073e9SAndroid Build Coastguard Worker        if torch.fx.experimental.validator.translation_validation_enabled():
9286*da0073e9SAndroid Build Coastguard Worker            from torch.fx.experimental.symbolic_shapes import (
9287*da0073e9SAndroid Build Coastguard Worker                CURRENT_NODE_KEY,
9288*da0073e9SAndroid Build Coastguard Worker                SHAPEENV_EVENT_KEY,
9289*da0073e9SAndroid Build Coastguard Worker            )
9290*da0073e9SAndroid Build Coastguard Worker
9291*da0073e9SAndroid Build Coastguard Worker            # Check that we don't store any recording metadata on nodes
9292*da0073e9SAndroid Build Coastguard Worker            # from the symbolic shape FX graph.
9293*da0073e9SAndroid Build Coastguard Worker            for n in main.graph.nodes:
9294*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(SHAPEENV_EVENT_KEY in n.meta)
9295*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(CURRENT_NODE_KEY in n.meta)
9296*da0073e9SAndroid Build Coastguard Worker
9297*da0073e9SAndroid Build Coastguard Worker    def _replay_and_check(self, shape_env: ShapeEnv):
9298*da0073e9SAndroid Build Coastguard Worker        if shape_env.should_record_events:
9299*da0073e9SAndroid Build Coastguard Worker            replayed = replay_shape_env_events(shape_env.events)
9300*da0073e9SAndroid Build Coastguard Worker            shape_env.check_equal(replayed)
9301*da0073e9SAndroid Build Coastguard Worker
9302*da0073e9SAndroid Build Coastguard Worker    def test_shape_env_equal_empty(self):
9303*da0073e9SAndroid Build Coastguard Worker        main, other = ShapeEnv(), ShapeEnv()
9304*da0073e9SAndroid Build Coastguard Worker        main.check_equal(other)
9305*da0073e9SAndroid Build Coastguard Worker        self._replay_and_check(main)
9306*da0073e9SAndroid Build Coastguard Worker
9307*da0073e9SAndroid Build Coastguard Worker    @onlyIfTranslationValidation
9308*da0073e9SAndroid Build Coastguard Worker    def test_shape_env_equal_constructor(self):
9309*da0073e9SAndroid Build Coastguard Worker        main, other = ShapeEnv(allow_scalar_outputs=False), ShapeEnv()
9310*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedRaisesInline(
9311*da0073e9SAndroid Build Coastguard Worker            NotEqualError,
9312*da0073e9SAndroid Build Coastguard Worker            lambda: main.check_equal(other),
9313*da0073e9SAndroid Build Coastguard Worker            """\
9314*da0073e9SAndroid Build Coastguard WorkerShapeEnv not equal: field values don't match:
9315*da0073e9SAndroid Build Coastguard Worker
9316*da0073e9SAndroid Build Coastguard Worker==> settings: values don't match.
9317*da0073e9SAndroid Build Coastguard Worker  >  Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, _allow_complex_guards_as_runtime_asserts=False)
9318*da0073e9SAndroid Build Coastguard Worker  > Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, _allow_complex_guards_as_runtime_asserts=False)
9319*da0073e9SAndroid Build Coastguard Worker""",
9320*da0073e9SAndroid Build Coastguard Worker        )
9321*da0073e9SAndroid Build Coastguard Worker        self._replay_and_check(main)
9322*da0073e9SAndroid Build Coastguard Worker
9323*da0073e9SAndroid Build Coastguard Worker    @onlyIfTranslationValidation
9324*da0073e9SAndroid Build Coastguard Worker    def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self):
9325*da0073e9SAndroid Build Coastguard Worker        main, other = ShapeEnv(), ShapeEnv()
9326*da0073e9SAndroid Build Coastguard Worker        main.create_symbolic_sizes_strides_storage_offset(
9327*da0073e9SAndroid Build Coastguard Worker            torch.randn(3, 2), ConstantSource("x")
9328*da0073e9SAndroid Build Coastguard Worker        )
9329*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedRaisesInline(
9330*da0073e9SAndroid Build Coastguard Worker            NotEqualError,
9331*da0073e9SAndroid Build Coastguard Worker            lambda: main.check_equal(other),
9332*da0073e9SAndroid Build Coastguard Worker            """\
9333*da0073e9SAndroid Build Coastguard WorkerShapeEnv not equal: field values don't match:
9334*da0073e9SAndroid Build Coastguard Worker
9335*da0073e9SAndroid Build Coastguard Worker==> name_to_node: values don't match.
9336*da0073e9SAndroid Build Coastguard Worker  >  Left: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
9337*da0073e9SAndroid Build Coastguard Worker  > Right: {}
9338*da0073e9SAndroid Build Coastguard Worker==> source_to_symbol: values don't match.
9339*da0073e9SAndroid Build Coastguard Worker  >  Left: {x.size()[0]: x.size()[0], x.size()[1]: x.size()[1], x.storage_offset(): x.storage_offset(), x.stride()[0]: x.stride()[0], x.stride()[1]: x.stride()[1]}
9340*da0073e9SAndroid Build Coastguard Worker  > Right: {}
9341*da0073e9SAndroid Build Coastguard Worker==> val_to_var: values don't match.
9342*da0073e9SAndroid Build Coastguard Worker  >  Left: {0: 0, 1: 1, 2: s1, 3: s0}
9343*da0073e9SAndroid Build Coastguard Worker  > Right: {0: 0, 1: 1}
9344*da0073e9SAndroid Build Coastguard Worker==> var_to_range: values don't match.
9345*da0073e9SAndroid Build Coastguard Worker  >  Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
9346*da0073e9SAndroid Build Coastguard Worker  > Right: {}
9347*da0073e9SAndroid Build Coastguard Worker==> var_to_sources: values don't match.
9348*da0073e9SAndroid Build Coastguard Worker  >  Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]}
9349*da0073e9SAndroid Build Coastguard Worker  > Right: {}
9350*da0073e9SAndroid Build Coastguard Worker==> var_to_val: values don't match.
9351*da0073e9SAndroid Build Coastguard Worker  >  Left: {s0: 3, s1: 2}
9352*da0073e9SAndroid Build Coastguard Worker  > Right: {}
9353*da0073e9SAndroid Build Coastguard Worker""",
9354*da0073e9SAndroid Build Coastguard Worker        )
9355*da0073e9SAndroid Build Coastguard Worker        self._replay_and_check(main)
9356*da0073e9SAndroid Build Coastguard Worker
9357*da0073e9SAndroid Build Coastguard Worker    @onlyIfTranslationValidation
9358*da0073e9SAndroid Build Coastguard Worker    def test_shape_env_equal_unbacked(self):
9359*da0073e9SAndroid Build Coastguard Worker        main, other = ShapeEnv(), ShapeEnv()
9360*da0073e9SAndroid Build Coastguard Worker        main.create_unbacked_symint()
9361*da0073e9SAndroid Build Coastguard Worker        main.create_unbacked_symfloat()
9362*da0073e9SAndroid Build Coastguard Worker        main.create_unbacked_symbool()
9363*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedRaisesInline(
9364*da0073e9SAndroid Build Coastguard Worker            NotEqualError,
9365*da0073e9SAndroid Build Coastguard Worker            lambda: main.check_equal(other),
9366*da0073e9SAndroid Build Coastguard Worker            """\
9367*da0073e9SAndroid Build Coastguard WorkerShapeEnv not equal: field values don't match:
9368*da0073e9SAndroid Build Coastguard Worker
9369*da0073e9SAndroid Build Coastguard Worker==> name_to_node: values don't match.
9370*da0073e9SAndroid Build Coastguard Worker  >  Left: {u0, u1, zuf0}
9371*da0073e9SAndroid Build Coastguard Worker  > Right: {}
9372*da0073e9SAndroid Build Coastguard Worker==> unbacked_symfloat_counter: values don't match.
9373*da0073e9SAndroid Build Coastguard Worker  >  Left: 1
9374*da0073e9SAndroid Build Coastguard Worker  > Right: 0
9375*da0073e9SAndroid Build Coastguard Worker==> unbacked_symint_counter: values don't match.
9376*da0073e9SAndroid Build Coastguard Worker  >  Left: 2
9377*da0073e9SAndroid Build Coastguard Worker  > Right: 0
9378*da0073e9SAndroid Build Coastguard Worker==> var_to_range: values don't match.
9379*da0073e9SAndroid Build Coastguard Worker  >  Left: {u0: VR[-int_oo, int_oo], u1: VR[0, 1], zuf0: VR[-oo, oo]}
9380*da0073e9SAndroid Build Coastguard Worker  > Right: {}
9381*da0073e9SAndroid Build Coastguard Worker""",
9382*da0073e9SAndroid Build Coastguard Worker        )
9383*da0073e9SAndroid Build Coastguard Worker        self._replay_and_check(main)
9384*da0073e9SAndroid Build Coastguard Worker
9385*da0073e9SAndroid Build Coastguard Worker    @onlyIfTranslationValidation
9386*da0073e9SAndroid Build Coastguard Worker    def test_shape_env_equal_evaluate_expr_divisible(self):
9387*da0073e9SAndroid Build Coastguard Worker        main, other = ShapeEnv(), ShapeEnv()
9388*da0073e9SAndroid Build Coastguard Worker
9389*da0073e9SAndroid Build Coastguard Worker        # Call create_symbolic_sizes_strides_storage_offset on both of them.
9390*da0073e9SAndroid Build Coastguard Worker        r = main.create_symbolic_sizes_strides_storage_offset(
9391*da0073e9SAndroid Build Coastguard Worker            torch.randn(3, 2), ConstantSource("x")
9392*da0073e9SAndroid Build Coastguard Worker        )
9393*da0073e9SAndroid Build Coastguard Worker        other.create_symbolic_sizes_strides_storage_offset(
9394*da0073e9SAndroid Build Coastguard Worker            torch.randn(3, 2), ConstantSource("x")
9395*da0073e9SAndroid Build Coastguard Worker        )
9396*da0073e9SAndroid Build Coastguard Worker
9397*da0073e9SAndroid Build Coastguard Worker        # Create a guard: size[0] % 3 == 0 (only in the main ShapeEnv)
9398*da0073e9SAndroid Build Coastguard Worker        #   - +1 guard entry
9399*da0073e9SAndroid Build Coastguard Worker        #   - +1 divisible entry
9400*da0073e9SAndroid Build Coastguard Worker        size = r[0]
9401*da0073e9SAndroid Build Coastguard Worker        bool(size[0] % 3 == 0)
9402*da0073e9SAndroid Build Coastguard Worker
9403*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedRaisesInline(
9404*da0073e9SAndroid Build Coastguard Worker            NotEqualError,
9405*da0073e9SAndroid Build Coastguard Worker            lambda: main.check_equal(other),
9406*da0073e9SAndroid Build Coastguard Worker            """\
9407*da0073e9SAndroid Build Coastguard WorkerShapeEnv not equal: field values don't match:
9408*da0073e9SAndroid Build Coastguard Worker
9409*da0073e9SAndroid Build Coastguard Worker==> divisible: values don't match.
9410*da0073e9SAndroid Build Coastguard Worker  >  Left: {Mod(s0, 3)}
9411*da0073e9SAndroid Build Coastguard Worker  > Right: {}
9412*da0073e9SAndroid Build Coastguard Worker==> guards: values don't match.
9413*da0073e9SAndroid Build Coastguard Worker  >  Left: [Eq(Mod(s0, 3), 0)]
9414*da0073e9SAndroid Build Coastguard Worker  > Right: []
9415*da0073e9SAndroid Build Coastguard Worker==> name_to_node: values don't match.
9416*da0073e9SAndroid Build Coastguard Worker  >  Left: {_assert, eq, mod, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
9417*da0073e9SAndroid Build Coastguard Worker  > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
9418*da0073e9SAndroid Build Coastguard Worker""",
9419*da0073e9SAndroid Build Coastguard Worker        )
9420*da0073e9SAndroid Build Coastguard Worker        self._replay_and_check(main)
9421*da0073e9SAndroid Build Coastguard Worker
9422*da0073e9SAndroid Build Coastguard Worker    @onlyIfTranslationValidation
9423*da0073e9SAndroid Build Coastguard Worker    def test_shape_env_equal_evaluate_expr_replacement(self):
9424*da0073e9SAndroid Build Coastguard Worker        main, other = ShapeEnv(), ShapeEnv()
9425*da0073e9SAndroid Build Coastguard Worker
9426*da0073e9SAndroid Build Coastguard Worker        # Call create_symbolic_sizes_strides_storage_offset on both of them.
9427*da0073e9SAndroid Build Coastguard Worker        r = main.create_symbolic_sizes_strides_storage_offset(
9428*da0073e9SAndroid Build Coastguard Worker            torch.randn(3, 2), ConstantSource("x")
9429*da0073e9SAndroid Build Coastguard Worker        )
9430*da0073e9SAndroid Build Coastguard Worker        other.create_symbolic_sizes_strides_storage_offset(
9431*da0073e9SAndroid Build Coastguard Worker            torch.randn(3, 2), ConstantSource("x")
9432*da0073e9SAndroid Build Coastguard Worker        )
9433*da0073e9SAndroid Build Coastguard Worker
9434*da0073e9SAndroid Build Coastguard Worker        # Create a guard: size[0] == 3 (only in the main ShapeEnv)
9435*da0073e9SAndroid Build Coastguard Worker        #   - +1 guard entry
9436*da0073e9SAndroid Build Coastguard Worker        #   - +1 replacement entry
9437*da0073e9SAndroid Build Coastguard Worker        size = r[0]
9438*da0073e9SAndroid Build Coastguard Worker        bool(size[0] == 3)
9439*da0073e9SAndroid Build Coastguard Worker
9440*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedRaisesInline(
9441*da0073e9SAndroid Build Coastguard Worker            NotEqualError,
9442*da0073e9SAndroid Build Coastguard Worker            lambda: main.check_equal(other),
9443*da0073e9SAndroid Build Coastguard Worker            """\
9444*da0073e9SAndroid Build Coastguard WorkerShapeEnv not equal: field values don't match:
9445*da0073e9SAndroid Build Coastguard Worker
9446*da0073e9SAndroid Build Coastguard Worker==> guards: values don't match.
9447*da0073e9SAndroid Build Coastguard Worker  >  Left: [Eq(s0, 3)]
9448*da0073e9SAndroid Build Coastguard Worker  > Right: []
9449*da0073e9SAndroid Build Coastguard Worker==> name_to_node: values don't match.
9450*da0073e9SAndroid Build Coastguard Worker  >  Left: {_assert, eq, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
9451*da0073e9SAndroid Build Coastguard Worker  > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
9452*da0073e9SAndroid Build Coastguard Worker==> replacements: values don't match.
9453*da0073e9SAndroid Build Coastguard Worker  >  Left: {s0: 3}
9454*da0073e9SAndroid Build Coastguard Worker  > Right: {}
9455*da0073e9SAndroid Build Coastguard Worker==> var_to_range: values don't match.
9456*da0073e9SAndroid Build Coastguard Worker  >  Left: {s0: VR[3, 3], s1: VR[2, int_oo]}
9457*da0073e9SAndroid Build Coastguard Worker  > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
9458*da0073e9SAndroid Build Coastguard Worker""",
9459*da0073e9SAndroid Build Coastguard Worker        )
9460*da0073e9SAndroid Build Coastguard Worker        self._replay_and_check(main)
9461*da0073e9SAndroid Build Coastguard Worker
9462*da0073e9SAndroid Build Coastguard Worker    @onlyIfTranslationValidation
9463*da0073e9SAndroid Build Coastguard Worker    def test_shape_env_equal_evaluate_expr_refinement(self):
9464*da0073e9SAndroid Build Coastguard Worker        main, other = ShapeEnv(), ShapeEnv()
9465*da0073e9SAndroid Build Coastguard Worker
9466*da0073e9SAndroid Build Coastguard Worker        # Call create_symbolic_sizes_strides_storage_offset on both of them.
9467*da0073e9SAndroid Build Coastguard Worker        r = main.create_symbolic_sizes_strides_storage_offset(
9468*da0073e9SAndroid Build Coastguard Worker            torch.randn(3, 2), ConstantSource("x")
9469*da0073e9SAndroid Build Coastguard Worker        )
9470*da0073e9SAndroid Build Coastguard Worker        other.create_symbolic_sizes_strides_storage_offset(
9471*da0073e9SAndroid Build Coastguard Worker            torch.randn(3, 2), ConstantSource("x")
9472*da0073e9SAndroid Build Coastguard Worker        )
9473*da0073e9SAndroid Build Coastguard Worker
9474*da0073e9SAndroid Build Coastguard Worker        # Create a guard: size[0] >= 3 (only in the main ShapeEnv)
9475*da0073e9SAndroid Build Coastguard Worker        #   - +1 guard entry
9476*da0073e9SAndroid Build Coastguard Worker        #   - +1 var_to_guard entry
9477*da0073e9SAndroid Build Coastguard Worker        #   - Change: var_to_range
9478*da0073e9SAndroid Build Coastguard Worker        size = r[0]
9479*da0073e9SAndroid Build Coastguard Worker        bool(size[0] >= 3)
9480*da0073e9SAndroid Build Coastguard Worker
9481*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedRaisesInline(
9482*da0073e9SAndroid Build Coastguard Worker            NotEqualError,
9483*da0073e9SAndroid Build Coastguard Worker            lambda: main.check_equal(other),
9484*da0073e9SAndroid Build Coastguard Worker            """\
9485*da0073e9SAndroid Build Coastguard WorkerShapeEnv not equal: field values don't match:
9486*da0073e9SAndroid Build Coastguard Worker
9487*da0073e9SAndroid Build Coastguard Worker==> guards: values don't match.
9488*da0073e9SAndroid Build Coastguard Worker  >  Left: [s0 >= 3]
9489*da0073e9SAndroid Build Coastguard Worker  > Right: []
9490*da0073e9SAndroid Build Coastguard Worker==> name_to_node: values don't match.
9491*da0073e9SAndroid Build Coastguard Worker  >  Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
9492*da0073e9SAndroid Build Coastguard Worker  > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
9493*da0073e9SAndroid Build Coastguard Worker==> var_to_range: values don't match.
9494*da0073e9SAndroid Build Coastguard Worker  >  Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]}
9495*da0073e9SAndroid Build Coastguard Worker  > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
9496*da0073e9SAndroid Build Coastguard Worker""",
9497*da0073e9SAndroid Build Coastguard Worker        )
9498*da0073e9SAndroid Build Coastguard Worker        self._replay_and_check(main)
9499*da0073e9SAndroid Build Coastguard Worker
9500*da0073e9SAndroid Build Coastguard Worker    @onlyIfTranslationValidation
9501*da0073e9SAndroid Build Coastguard Worker    def test_shape_env_equal_runtime_assert(self):
9502*da0073e9SAndroid Build Coastguard Worker        main, other = ShapeEnv(), ShapeEnv()
9503*da0073e9SAndroid Build Coastguard Worker
9504*da0073e9SAndroid Build Coastguard Worker        # Call create_unbacked_symint on both of them.
9505*da0073e9SAndroid Build Coastguard Worker        r = main.create_unbacked_symint()
9506*da0073e9SAndroid Build Coastguard Worker        other.create_unbacked_symint()
9507*da0073e9SAndroid Build Coastguard Worker
9508*da0073e9SAndroid Build Coastguard Worker        # Create a runtime assert: r % 3 == 0 (only in the main ShapeEnv)
9509*da0073e9SAndroid Build Coastguard Worker        #   - +1 deferred_runtime_asserts entry
9510*da0073e9SAndroid Build Coastguard Worker        #   - Change: num_deferred_runtime_asserts
9511*da0073e9SAndroid Build Coastguard Worker        expect_true(r % 3 == 0)
9512*da0073e9SAndroid Build Coastguard Worker
9513*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedRaisesInline(
9514*da0073e9SAndroid Build Coastguard Worker            NotEqualError,
9515*da0073e9SAndroid Build Coastguard Worker            lambda: main.check_equal(other),
9516*da0073e9SAndroid Build Coastguard Worker            """\
9517*da0073e9SAndroid Build Coastguard WorkerShapeEnv not equal: field values don't match:
9518*da0073e9SAndroid Build Coastguard Worker
9519*da0073e9SAndroid Build Coastguard Worker==> deferred_runtime_asserts: values don't match.
9520*da0073e9SAndroid Build Coastguard Worker  >  Left: {u0: [Eq(PythonMod(u0, 3), 0)]}
9521*da0073e9SAndroid Build Coastguard Worker  > Right: {}
9522*da0073e9SAndroid Build Coastguard Worker==> name_to_node: values don't match.
9523*da0073e9SAndroid Build Coastguard Worker  >  Left: {_assert, eq, mod, u0}
9524*da0073e9SAndroid Build Coastguard Worker  > Right: {u0}
9525*da0073e9SAndroid Build Coastguard Worker==> num_deferred_runtime_asserts: values don't match.
9526*da0073e9SAndroid Build Coastguard Worker  >  Left: 1
9527*da0073e9SAndroid Build Coastguard Worker  > Right: 0
9528*da0073e9SAndroid Build Coastguard Worker""",
9529*da0073e9SAndroid Build Coastguard Worker        )
9530*da0073e9SAndroid Build Coastguard Worker        self._replay_and_check(main)
9531*da0073e9SAndroid Build Coastguard Worker
9532*da0073e9SAndroid Build Coastguard Worker    def test_shape_env_recorded_function_fallback(self):
9533*da0073e9SAndroid Build Coastguard Worker        # Make sure the record/replay mechanism for ShapeEnv will fallback
9534*da0073e9SAndroid Build Coastguard Worker        # if no ShapeEnv instance is found.
9535*da0073e9SAndroid Build Coastguard Worker        constrain_range(5, min=2, max=10)
9536*da0073e9SAndroid Build Coastguard Worker        constrain_unify(5, 5)
9537*da0073e9SAndroid Build Coastguard Worker
9538*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedRaisesInline(
9539*da0073e9SAndroid Build Coastguard Worker            AssertionError,
9540*da0073e9SAndroid Build Coastguard Worker            lambda: _constrain_range_for_size(5, min=2, max=10),
9541*da0073e9SAndroid Build Coastguard Worker            """can only constrain range for SymInt""",
9542*da0073e9SAndroid Build Coastguard Worker        )
9543*da0073e9SAndroid Build Coastguard Worker
9544*da0073e9SAndroid Build Coastguard Worker    def test_default_dtype_change(self):
9545*da0073e9SAndroid Build Coastguard Worker        @torch.compile
9546*da0073e9SAndroid Build Coastguard Worker        def foo():
9547*da0073e9SAndroid Build Coastguard Worker            def inner(a, b, res_dtype):
9548*da0073e9SAndroid Build Coastguard Worker                print(a, b, res_dtype)
9549*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.result_type(a, b), res_dtype)
9550*da0073e9SAndroid Build Coastguard Worker
9551*da0073e9SAndroid Build Coastguard Worker            inner(torch.tensor(1, device="cpu"), 1.0, torch.get_default_dtype())
9552*da0073e9SAndroid Build Coastguard Worker
9553*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float):
9554*da0073e9SAndroid Build Coastguard Worker            foo()
9555*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.double):
9556*da0073e9SAndroid Build Coastguard Worker            foo()
9557*da0073e9SAndroid Build Coastguard Worker
9558*da0073e9SAndroid Build Coastguard Worker    def test_numpy_ufunc_out(self):
9559*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager")
9560*da0073e9SAndroid Build Coastguard Worker        def foo():
9561*da0073e9SAndroid Build Coastguard Worker            x = np.arange(5)
9562*da0073e9SAndroid Build Coastguard Worker            out = np.empty((x.shape[0], x.shape[0]))
9563*da0073e9SAndroid Build Coastguard Worker            res_out = np.sin(x, out=out)
9564*da0073e9SAndroid Build Coastguard Worker            assert res_out is out
9565*da0073e9SAndroid Build Coastguard Worker
9566*da0073e9SAndroid Build Coastguard Worker        foo()
9567*da0073e9SAndroid Build Coastguard Worker
    # Unfortunately, we don't currently preserve the ids of
    # res_out and out correctly across the graph break,
    # so this test is marked as an expected failure.
    @unittest.expectedFailure
    def test_numpy_ufunc_out_graph_break(self):
        """Check ``out=`` identity for a NumPy ufunc when a graph break follows the call."""

        @torch.compile(backend="eager")
        def foo():
            x = np.arange(5)
            out = np.empty((x.shape[0], x.shape[0]))
            res_out = np.sin(x, out=out)
            # Force a graph break between producing `res_out` and comparing it.
            torch._dynamo.graph_break()
            assert res_out is out

        foo()
9581*da0073e9SAndroid Build Coastguard Worker
9582*da0073e9SAndroid Build Coastguard Worker    def test_dict_subclass_cannot_be_initialized_in_graph(self):
9583*da0073e9SAndroid Build Coastguard Worker        for super_class in (
9584*da0073e9SAndroid Build Coastguard Worker            collections.OrderedDict,
9585*da0073e9SAndroid Build Coastguard Worker            dict,
9586*da0073e9SAndroid Build Coastguard Worker        ):
9587*da0073e9SAndroid Build Coastguard Worker
9588*da0073e9SAndroid Build Coastguard Worker            class CustomDict(super_class):
9589*da0073e9SAndroid Build Coastguard Worker                def __init__(self, *args, **kwargs):
9590*da0073e9SAndroid Build Coastguard Worker                    super().__init__(*args, **kwargs)
9591*da0073e9SAndroid Build Coastguard Worker
9592*da0073e9SAndroid Build Coastguard Worker            def fn(x):
9593*da0073e9SAndroid Build Coastguard Worker                c = CustomDict()
9594*da0073e9SAndroid Build Coastguard Worker                c["key"] = x
9595*da0073e9SAndroid Build Coastguard Worker                assert "key" in c
9596*da0073e9SAndroid Build Coastguard Worker                return c["key"] + 1
9597*da0073e9SAndroid Build Coastguard Worker
9598*da0073e9SAndroid Build Coastguard Worker            fn_opt = torch.compile(fn, backend="eager", fullgraph=True)
9599*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
9600*da0073e9SAndroid Build Coastguard Worker                torch._dynamo.exc.Unsupported, "call_function UserDefinedClassVariable"
9601*da0073e9SAndroid Build Coastguard Worker            ):
9602*da0073e9SAndroid Build Coastguard Worker                print(fn_opt(torch.zeros(1)))
9603*da0073e9SAndroid Build Coastguard Worker
9604*da0073e9SAndroid Build Coastguard Worker    @wrapDeterministicFlagAPITest
9605*da0073e9SAndroid Build Coastguard Worker    def test_backward_deterministic_mode_mismatch_warning(self):
9606*da0073e9SAndroid Build Coastguard Worker        @torch.compile
9607*da0073e9SAndroid Build Coastguard Worker        def func(a, b):
9608*da0073e9SAndroid Build Coastguard Worker            return a + b
9609*da0073e9SAndroid Build Coastguard Worker
9610*da0073e9SAndroid Build Coastguard Worker        for forward_deterministic, backward_deterministic in itertools.product(
9611*da0073e9SAndroid Build Coastguard Worker            [True, False], [True, False]
9612*da0073e9SAndroid Build Coastguard Worker        ):
9613*da0073e9SAndroid Build Coastguard Worker            torch.use_deterministic_algorithms(forward_deterministic)
9614*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(10, requires_grad=True)
9615*da0073e9SAndroid Build Coastguard Worker            res = func(a, 1)
9616*da0073e9SAndroid Build Coastguard Worker            grad = torch.ones_like(res)
9617*da0073e9SAndroid Build Coastguard Worker            torch.use_deterministic_algorithms(backward_deterministic)
9618*da0073e9SAndroid Build Coastguard Worker
9619*da0073e9SAndroid Build Coastguard Worker            if not forward_deterministic and backward_deterministic:
9620*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
9621*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
9622*da0073e9SAndroid Build Coastguard Worker                    "^This compiled backward function is being run with torch\.use_deterministic_algorithms",
9623*da0073e9SAndroid Build Coastguard Worker                ):
9624*da0073e9SAndroid Build Coastguard Worker                    res.backward(grad)
9625*da0073e9SAndroid Build Coastguard Worker
9626*da0073e9SAndroid Build Coastguard Worker            else:
9627*da0073e9SAndroid Build Coastguard Worker                res.backward(grad)
9628*da0073e9SAndroid Build Coastguard Worker
9629*da0073e9SAndroid Build Coastguard Worker    def test_torch_dynamo_codegen_pow(self):
9630*da0073e9SAndroid Build Coastguard Worker        def pow(x):
9631*da0073e9SAndroid Build Coastguard Worker            return x**2
9632*da0073e9SAndroid Build Coastguard Worker
9633*da0073e9SAndroid Build Coastguard Worker        x = np.arange(8)
9634*da0073e9SAndroid Build Coastguard Worker        pow_opt = torch.compile(pow)
9635*da0073e9SAndroid Build Coastguard Worker
9636*da0073e9SAndroid Build Coastguard Worker        actual, source_code = run_and_get_code(pow_opt, x)
9637*da0073e9SAndroid Build Coastguard Worker        expect = pow(x)
9638*da0073e9SAndroid Build Coastguard Worker
9639*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expect, actual)
9640*da0073e9SAndroid Build Coastguard Worker
9641*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
9642*da0073e9SAndroid Build Coastguard Worker            all("aten.pow" not in code for code in source_code),
9643*da0073e9SAndroid Build Coastguard Worker            msg="Encountered an unexpected fallback to 'aten pow' in dynamo compiled code",
9644*da0073e9SAndroid Build Coastguard Worker        )
9645*da0073e9SAndroid Build Coastguard Worker
9646*da0073e9SAndroid Build Coastguard Worker    def test_graph_break_compilation_metrics(self):
9647*da0073e9SAndroid Build Coastguard Worker        def fn(x):
9648*da0073e9SAndroid Build Coastguard Worker            x.cos()
9649*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
9650*da0073e9SAndroid Build Coastguard Worker            x.sin()
9651*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
9652*da0073e9SAndroid Build Coastguard Worker            return x.cos()
9653*da0073e9SAndroid Build Coastguard Worker
9654*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.utils.clear_compilation_metrics()
9655*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((4, 4))
9656*da0073e9SAndroid Build Coastguard Worker        f = torch.compile(fn, backend="eager")
9657*da0073e9SAndroid Build Coastguard Worker        f(x)
9658*da0073e9SAndroid Build Coastguard Worker        metrics = torch._dynamo.utils.get_compilation_metrics()
9659*da0073e9SAndroid Build Coastguard Worker        # Should only be one restart per event
9660*da0073e9SAndroid Build Coastguard Worker        (restart_reason,) = metrics[0].restart_reasons
9661*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
9662*da0073e9SAndroid Build Coastguard Worker            "skip function graph_break" in restart_reason,
9663*da0073e9SAndroid Build Coastguard Worker            "Should have logged graph break reason",
9664*da0073e9SAndroid Build Coastguard Worker        )
9665*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
9666*da0073e9SAndroid Build Coastguard Worker            metrics[0].dynamo_time_before_restart_s
9667*da0073e9SAndroid Build Coastguard Worker            <= metrics[0].entire_frame_compile_time_s
9668*da0073e9SAndroid Build Coastguard Worker        )
9669*da0073e9SAndroid Build Coastguard Worker
9670*da0073e9SAndroid Build Coastguard Worker        (restart_reason,) = metrics[1].restart_reasons
9671*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
9672*da0073e9SAndroid Build Coastguard Worker            "skip function graph_break" in restart_reason,
9673*da0073e9SAndroid Build Coastguard Worker            "Should have logged graph break reason",
9674*da0073e9SAndroid Build Coastguard Worker        )
9675*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
9676*da0073e9SAndroid Build Coastguard Worker            metrics[1].dynamo_time_before_restart_s
9677*da0073e9SAndroid Build Coastguard Worker            <= metrics[1].entire_frame_compile_time_s
9678*da0073e9SAndroid Build Coastguard Worker        )
9679*da0073e9SAndroid Build Coastguard Worker
9680*da0073e9SAndroid Build Coastguard Worker        # No restarts
9681*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
9682*da0073e9SAndroid Build Coastguard Worker            len(metrics[2].restart_reasons) == 0, "Last compile has no graph break"
9683*da0073e9SAndroid Build Coastguard Worker        )
9684*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(metrics[2].dynamo_time_before_restart_s == 0)
9685*da0073e9SAndroid Build Coastguard Worker
9686*da0073e9SAndroid Build Coastguard Worker    def test_graph_break_compilation_metrics_on_failure(self):
9687*da0073e9SAndroid Build Coastguard Worker        def fn(x):
9688*da0073e9SAndroid Build Coastguard Worker            return x.sin()
9689*da0073e9SAndroid Build Coastguard Worker
9690*da0073e9SAndroid Build Coastguard Worker        def broken_backend(gm, example_inputs):
9691*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("broken backend")
9692*da0073e9SAndroid Build Coastguard Worker
9693*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((4, 4))
9694*da0073e9SAndroid Build Coastguard Worker        f = torch.compile(fn, backend=broken_backend)
9695*da0073e9SAndroid Build Coastguard Worker        with unittest.mock.patch("torch._dynamo.config.suppress_errors", True):
9696*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.utils.clear_compilation_metrics()
9697*da0073e9SAndroid Build Coastguard Worker            f(x)
9698*da0073e9SAndroid Build Coastguard Worker            metrics = torch._dynamo.utils.get_compilation_metrics()
9699*da0073e9SAndroid Build Coastguard Worker            for metric in metrics:
9700*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(metric.dynamo_time_before_restart_s > 0)
9701*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
9702*da0073e9SAndroid Build Coastguard Worker                    "RuntimeError: broken backend" in metric.fail_reason,
9703*da0073e9SAndroid Build Coastguard Worker                    "Should have logged fail reason",
9704*da0073e9SAndroid Build Coastguard Worker                )
9705*da0073e9SAndroid Build Coastguard Worker
9706*da0073e9SAndroid Build Coastguard Worker    def test_compilation_metrics_size_limit(self):
9707*da0073e9SAndroid Build Coastguard Worker        def fn1(x):
9708*da0073e9SAndroid Build Coastguard Worker            return x.relu()
9709*da0073e9SAndroid Build Coastguard Worker
9710*da0073e9SAndroid Build Coastguard Worker        def fn2(x):
9711*da0073e9SAndroid Build Coastguard Worker            return x.cos()
9712*da0073e9SAndroid Build Coastguard Worker
9713*da0073e9SAndroid Build Coastguard Worker        def fn3(x):
9714*da0073e9SAndroid Build Coastguard Worker            return x.sin()
9715*da0073e9SAndroid Build Coastguard Worker
9716*da0073e9SAndroid Build Coastguard Worker        def fn4(x):
9717*da0073e9SAndroid Build Coastguard Worker            return x.exp()
9718*da0073e9SAndroid Build Coastguard Worker
9719*da0073e9SAndroid Build Coastguard Worker        import contextlib
9720*da0073e9SAndroid Build Coastguard Worker
9721*da0073e9SAndroid Build Coastguard Worker        @contextlib.contextmanager
9722*da0073e9SAndroid Build Coastguard Worker        def metrics_limit_ctx():
9723*da0073e9SAndroid Build Coastguard Worker            try:
9724*da0073e9SAndroid Build Coastguard Worker                torch._dynamo.utils.set_compilation_metrics_limit(3)
9725*da0073e9SAndroid Build Coastguard Worker                yield
9726*da0073e9SAndroid Build Coastguard Worker            finally:
9727*da0073e9SAndroid Build Coastguard Worker                torch._dynamo.utils.set_compilation_metrics_limit(
9728*da0073e9SAndroid Build Coastguard Worker                    torch._dynamo.utils.DEFAULT_COMPILATION_METRICS_LIMIT
9729*da0073e9SAndroid Build Coastguard Worker                )
9730*da0073e9SAndroid Build Coastguard Worker
9731*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((4, 4))
9732*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
9733*da0073e9SAndroid Build Coastguard Worker        torch.compile(fn1, backend="eager")(x)
9734*da0073e9SAndroid Build Coastguard Worker        torch.compile(fn2, backend="eager")(x)
9735*da0073e9SAndroid Build Coastguard Worker        torch.compile(fn3, backend="eager")(x)
9736*da0073e9SAndroid Build Coastguard Worker        torch.compile(fn4, backend="eager")(x)
9737*da0073e9SAndroid Build Coastguard Worker
9738*da0073e9SAndroid Build Coastguard Worker        with metrics_limit_ctx():
9739*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.utils.clear_compilation_metrics()
9740*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.reset()
9741*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(0, len(torch._dynamo.utils.get_compilation_metrics()))
9742*da0073e9SAndroid Build Coastguard Worker            torch.compile(fn1, backend="eager")(x)
9743*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(1, len(torch._dynamo.utils.get_compilation_metrics()))
9744*da0073e9SAndroid Build Coastguard Worker            torch.compile(fn2, backend="eager")(x)
9745*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(2, len(torch._dynamo.utils.get_compilation_metrics()))
9746*da0073e9SAndroid Build Coastguard Worker            torch.compile(fn3, backend="eager")(x)
9747*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(3, len(torch._dynamo.utils.get_compilation_metrics()))
9748*da0073e9SAndroid Build Coastguard Worker            torch.compile(fn4, backend="eager")(x)
9749*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(3, len(torch._dynamo.utils.get_compilation_metrics()))
9750*da0073e9SAndroid Build Coastguard Worker
    def test_funcname_cache(self):
        """Check ``get_funcname`` for every line of a small nested source file.

        A source snippet with nested classes and functions is written to a
        temporary file, and ``get_funcname(filename, lineno)`` is queried for
        each 1-based line number (including one past the last line). The
        results, one per line, are compared with the expected dotted names;
        lines outside any class or function are expected to yield an empty
        string.
        """
        # Fixture source: the layout (line positions and nesting) is what is
        # being tested, so it must stay exactly as written.
        src = """\
import torch
if True:
    test = 3

class AAA:
    class DUMMY:
        class DUMMY2:
            pass

    def dummy(self):
        def dummy2():
            pass
    class BBB:
        @staticmethod
        def CCC():
            class DDD:
                if True:
                    @staticmethod
                    def EEE():
                        x = [torch.ones(3, 3) for _ in range(5)]
                        return x
            return DDD
def fn():
    return 3
"""
        with tempfile.NamedTemporaryFile(mode="w") as f:
            f.write(src)
            # Flush so the content is on disk before it is looked up by file name.
            f.flush()
            from torch._dynamo.funcname_cache import get_funcname

            # Line numbers are 1-based; the `+ 1` also queries one line past the end.
            names = [get_funcname(f.name, i + 1) for i in range(src.count("\n") + 1)]

        # One expected name per queried line; the first four lines are module level.
        self.assertExpectedInline(
            "\n".join(names),
            """\




AAA
AAA.DUMMY
AAA.DUMMY.DUMMY2
AAA.DUMMY.DUMMY2
AAA.DUMMY.DUMMY2
AAA.dummy
AAA.dummy.dummy2
AAA.dummy.dummy2
AAA.BBB
AAA.BBB
AAA.BBB.CCC
AAA.BBB.CCC.DDD
AAA.BBB.CCC.DDD
AAA.BBB.CCC.DDD
AAA.BBB.CCC.DDD.EEE
AAA.BBB.CCC.DDD.EEE
AAA.BBB.CCC.DDD.EEE
AAA.BBB.CCC
fn
fn
""",
        )
9814*da0073e9SAndroid Build Coastguard Worker
9815*da0073e9SAndroid Build Coastguard Worker    def test_return_dict_with_graph_break_and_update(self):
9816*da0073e9SAndroid Build Coastguard Worker        def create():
9817*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
9818*da0073e9SAndroid Build Coastguard Worker            return {0: torch.tensor(3)}
9819*da0073e9SAndroid Build Coastguard Worker
9820*da0073e9SAndroid Build Coastguard Worker        def fn():
9821*da0073e9SAndroid Build Coastguard Worker            return {**create()}
9822*da0073e9SAndroid Build Coastguard Worker
9823*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(backend="eager")(fn)
9824*da0073e9SAndroid Build Coastguard Worker        result = opt_fn()
9825*da0073e9SAndroid Build Coastguard Worker        self.assertIn(0, result)
9826*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(result[0], torch.tensor(3)))
9827*da0073e9SAndroid Build Coastguard Worker
9828*da0073e9SAndroid Build Coastguard Worker    def test_dynamo_reset_clears_cache(self):
9829*da0073e9SAndroid Build Coastguard Worker        """Test that dynamo bytecode cache is freed
9830*da0073e9SAndroid Build Coastguard Worker        when dynamo reset is called
9831*da0073e9SAndroid Build Coastguard Worker        """
9832*da0073e9SAndroid Build Coastguard Worker
9833*da0073e9SAndroid Build Coastguard Worker        def fn(x):
9834*da0073e9SAndroid Build Coastguard Worker            return torch.sin(x)
9835*da0073e9SAndroid Build Coastguard Worker
9836*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(backend="eager")(fn)
9837*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.randn(3, 3))
9838*da0073e9SAndroid Build Coastguard Worker
9839*da0073e9SAndroid Build Coastguard Worker        c1 = _debug_get_cache_entry_list(fn.__code__)
9840*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(c1), 1)
9841*da0073e9SAndroid Build Coastguard Worker
9842*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
9843*da0073e9SAndroid Build Coastguard Worker        c2 = _debug_get_cache_entry_list(fn.__code__)
9844*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(c2), 0)
9845*da0073e9SAndroid Build Coastguard Worker
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_guard_size_oblivious(self):
        """Check that a size-oblivious ``== 0`` guard on a data-dependent size is not taken.

        The size comes from ``x.item()`` and is actually 0 at runtime, yet the
        compiled function must return the empty tensor instead of reaching the
        ``assert False`` branch.
        """

        # This code, in fact, does NOT work in eager
        # (presumably because the check is then True for a size of 0 - confirm).
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            y = torch.zeros(x.item())
            if guard_size_oblivious(y.size(0) == 0):
                assert False
            return y

        self.assertEqual(fn(torch.tensor([0])), torch.zeros(0))
9857*da0073e9SAndroid Build Coastguard Worker
    def test_guard_size_oblivious_backed(self):
        """Check ``guard_size_oblivious`` on a size taken directly from an input tensor.

        Here the size comes from ``x.size(0)`` of a real input, and the branch
        must follow the true value: a length-0 input yields shape ``(1,)`` and
        a length-2 input yields shape ``(2,)``.
        """

        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            y = x.size(0)
            # This doesn't actually do anything
            if guard_size_oblivious(y == 0):
                return torch.randn(1)
            else:
                return torch.randn(2)

        # Should not fail in either case
        self.assertEqual(f(torch.randn(0)).shape, (1,))
        self.assertEqual(f(torch.randn(2)).shape, (2,))
9871*da0073e9SAndroid Build Coastguard Worker
9872*da0073e9SAndroid Build Coastguard Worker    def _test_compile_model_free(self, model_inp_ctr, weakref_watch):
9873*da0073e9SAndroid Build Coastguard Worker        """
9874*da0073e9SAndroid Build Coastguard Worker        Args:
9875*da0073e9SAndroid Build Coastguard Worker        model_inp_ctr
9876*da0073e9SAndroid Build Coastguard Worker            - constructor that returns a new model and inputs to that model
9877*da0073e9SAndroid Build Coastguard Worker        weakref_watch
9878*da0073e9SAndroid Build Coastguard Worker            - function that returns a layer of the model for weakref to
9879*da0073e9SAndroid Build Coastguard Worker              finalize on, so we can check that the layer is freed after
9880*da0073e9SAndroid Build Coastguard Worker              the model goes out of scope
9881*da0073e9SAndroid Build Coastguard Worker        """
9882*da0073e9SAndroid Build Coastguard Worker        cleared = False
9883*da0073e9SAndroid Build Coastguard Worker
9884*da0073e9SAndroid Build Coastguard Worker        def finalize():
9885*da0073e9SAndroid Build Coastguard Worker            nonlocal cleared
9886*da0073e9SAndroid Build Coastguard Worker            cleared = True
9887*da0073e9SAndroid Build Coastguard Worker
9888*da0073e9SAndroid Build Coastguard Worker        def run():
9889*da0073e9SAndroid Build Coastguard Worker            mod, inp = model_inp_ctr()
9890*da0073e9SAndroid Build Coastguard Worker            weakref.finalize(weakref_watch(mod), finalize)
9891*da0073e9SAndroid Build Coastguard Worker            torch.compile(mod, backend="eager")(inp)
9892*da0073e9SAndroid Build Coastguard Worker
9893*da0073e9SAndroid Build Coastguard Worker        run()
9894*da0073e9SAndroid Build Coastguard Worker        gc.collect()
9895*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(cleared)
9896*da0073e9SAndroid Build Coastguard Worker
9897*da0073e9SAndroid Build Coastguard Worker    def test_custom_module_free(self):
9898*da0073e9SAndroid Build Coastguard Worker        """Test that a model is freed when it goes out of scope"""
9899*da0073e9SAndroid Build Coastguard Worker
9900*da0073e9SAndroid Build Coastguard Worker        class Mod(torch.nn.Module):
9901*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
9902*da0073e9SAndroid Build Coastguard Worker                super(Mod, self).__init__()
9903*da0073e9SAndroid Build Coastguard Worker                self.fc = torch.nn.Linear(100, 100)
9904*da0073e9SAndroid Build Coastguard Worker
9905*da0073e9SAndroid Build Coastguard Worker            def forward(self, out):
9906*da0073e9SAndroid Build Coastguard Worker                return self.fc(out)
9907*da0073e9SAndroid Build Coastguard Worker
9908*da0073e9SAndroid Build Coastguard Worker        self._test_compile_model_free(
9909*da0073e9SAndroid Build Coastguard Worker            lambda: (Mod(), torch.randn(100, 100)),
9910*da0073e9SAndroid Build Coastguard Worker            lambda mod: mod.fc,
9911*da0073e9SAndroid Build Coastguard Worker        )
9912*da0073e9SAndroid Build Coastguard Worker
9913*da0073e9SAndroid Build Coastguard Worker    def test_sequential_module_free(self):
9914*da0073e9SAndroid Build Coastguard Worker        self._test_compile_model_free(
9915*da0073e9SAndroid Build Coastguard Worker            lambda: (
9916*da0073e9SAndroid Build Coastguard Worker                torch.nn.Sequential(
9917*da0073e9SAndroid Build Coastguard Worker                    torch.nn.Linear(100, 100),
9918*da0073e9SAndroid Build Coastguard Worker                    torch.nn.ReLU(),
9919*da0073e9SAndroid Build Coastguard Worker                ),
9920*da0073e9SAndroid Build Coastguard Worker                torch.randn(100, 100),
9921*da0073e9SAndroid Build Coastguard Worker            ),
9922*da0073e9SAndroid Build Coastguard Worker            lambda mod: mod[0],
9923*da0073e9SAndroid Build Coastguard Worker        )
9924*da0073e9SAndroid Build Coastguard Worker
9925*da0073e9SAndroid Build Coastguard Worker    def test_linear_module_free(self):
9926*da0073e9SAndroid Build Coastguard Worker        self._test_compile_model_free(
9927*da0073e9SAndroid Build Coastguard Worker            lambda: (torch.nn.Linear(100, 100), torch.randn(100, 100)),
9928*da0073e9SAndroid Build Coastguard Worker            lambda mod: mod,
9929*da0073e9SAndroid Build Coastguard Worker        )
9930*da0073e9SAndroid Build Coastguard Worker
9931*da0073e9SAndroid Build Coastguard Worker    # The following 2 tests fail due to https://github.com/python/cpython/issues/118013.
9932*da0073e9SAndroid Build Coastguard Worker    # Tracked by https://github.com/pytorch/pytorch/issues/124302.
9933*da0073e9SAndroid Build Coastguard Worker    # The xfails can be removed once Python 3.12 is updated on CI.
    @xfailIfPy312
    @unittest.skipIf(True, "Skipping this test for release/2.4")
    def test_outside_linear_module_free(self):
        """Test that a linear layer created outside the compiled module is freed.

        The layer is built first and then stored on a wrapper module; only the
        wrapper is compiled. The layer must be finalized once everything is
        out of scope and a garbage collection has run.
        """
        # Compared to test_linear_module_free, the linear
        # layer is not the code object that is directly compiled.

        # This test does not use _test_compile_model_free because of difficulty
        # in handling variable fc.

        cleared = False

        def finalize():
            nonlocal cleared
            cleared = True

        def run():
            fc = torch.nn.Linear(100, 100)

            class Mod(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    # Captured from the enclosing function, not created here.
                    self.fc_ref = fc

                def forward(self, x):
                    return self.fc_ref(x)

            mod = Mod()
            inp = torch.randn(100, 100)
            weakref.finalize(fc, finalize)
            torch.compile(mod, backend="eager")(inp)

        run()
        # `fc` was local to run(), so no explicit `del` is needed here; the
        # collection below is expected to release the remaining references.
        gc.collect()
        self.assertTrue(cleared)
9969*da0073e9SAndroid Build Coastguard Worker
    @xfailIfPy312
    def test_parameter_free(self):
        """Test that a parameter created outside the compiled module is freed.

        The finalizer is attached to ``mod.param``, the parameter object that
        was created before the module and stored on it.
        """

        def model_inp_ctr():
            param = torch.nn.Parameter(torch.randn(100, 100))

            class Mod(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.param = param

                def forward(self, x):
                    # `x` is the (tensor, param) tuple built below; only the tensor is used.
                    return self.param * x[0]

            # return param to keep it alive in _test_compile_model_free
            return Mod(), (torch.randn(100, 100), param)

        self._test_compile_model_free(model_inp_ctr, lambda mod: mod.param)
9987*da0073e9SAndroid Build Coastguard Worker
    def test_conditional_list_comp_in_context(self):
        """Smoke test: compile a filtered list comprehension that sits inside a ``try`` block.

        Nothing is asserted about the result; the test passes when the
        compiled call completes without raising.
        """

        def fn(inp):
            # The try/except wrapper is part of what is being compiled; keep it.
            try:
                return [torch.sin(x) for x in inp if x is not None]
            except Exception:
                pass

        # Three tensors followed by a None that the comprehension must filter out.
        inp = [torch.randn(3, 3) for _ in range(3)] + [None]
        opt_fn = torch.compile(fn, backend="eager")
        opt_fn(inp)
9998*da0073e9SAndroid Build Coastguard Worker
9999*da0073e9SAndroid Build Coastguard Worker    def test_312_binary_slice_with_graph_break1(self):
10000*da0073e9SAndroid Build Coastguard Worker        l1 = torch.nn.Linear(5, 5)
10001*da0073e9SAndroid Build Coastguard Worker        l2 = torch.nn.Linear(5, 5)
10002*da0073e9SAndroid Build Coastguard Worker
10003*da0073e9SAndroid Build Coastguard Worker        def fn(x):
10004*da0073e9SAndroid Build Coastguard Worker            # causes a graph break with items in the stack
10005*da0073e9SAndroid Build Coastguard Worker            n = torch.nn.Sequential(l1, l2)
10006*da0073e9SAndroid Build Coastguard Worker            out = n[1:](x)
10007*da0073e9SAndroid Build Coastguard Worker            return out
10008*da0073e9SAndroid Build Coastguard Worker
10009*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager")
10010*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.randn(5, 5))
10011*da0073e9SAndroid Build Coastguard Worker
10012*da0073e9SAndroid Build Coastguard Worker    def test_312_binary_slice_with_graph_break2(self):
10013*da0073e9SAndroid Build Coastguard Worker        class Foo:
10014*da0073e9SAndroid Build Coastguard Worker            def __setitem__(self, key, val):
10015*da0073e9SAndroid Build Coastguard Worker                pass
10016*da0073e9SAndroid Build Coastguard Worker
10017*da0073e9SAndroid Build Coastguard Worker            def __getitem__(self, key):
10018*da0073e9SAndroid Build Coastguard Worker                torch._dynamo.graph_break()
10019*da0073e9SAndroid Build Coastguard Worker                return 1
10020*da0073e9SAndroid Build Coastguard Worker
10021*da0073e9SAndroid Build Coastguard Worker        foo = Foo()
10022*da0073e9SAndroid Build Coastguard Worker
10023*da0073e9SAndroid Build Coastguard Worker        def fn(x):
10024*da0073e9SAndroid Build Coastguard Worker            # graph break in a STORE_SLICE instruction
10025*da0073e9SAndroid Build Coastguard Worker            foo[:] = x
10026*da0073e9SAndroid Build Coastguard Worker            # graph break in BINARY_SLICE with has_backedge check
10027*da0073e9SAndroid Build Coastguard Worker            x = x + foo[:]
10028*da0073e9SAndroid Build Coastguard Worker            if x is None:
10029*da0073e9SAndroid Build Coastguard Worker                x = x + 1
10030*da0073e9SAndroid Build Coastguard Worker            else:
10031*da0073e9SAndroid Build Coastguard Worker                x = x + 1
10032*da0073e9SAndroid Build Coastguard Worker            return x
10033*da0073e9SAndroid Build Coastguard Worker
10034*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager")
10035*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.randn(5, 5))
10036*da0073e9SAndroid Build Coastguard Worker
10037*da0073e9SAndroid Build Coastguard Worker    def test_super_after_graph_break(self):
10038*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Sequential):
10039*da0073e9SAndroid Build Coastguard Worker            def __init__(self, layers):
10040*da0073e9SAndroid Build Coastguard Worker                torch._dynamo.graph_break()
10041*da0073e9SAndroid Build Coastguard Worker                super().__init__(*layers)
10042*da0073e9SAndroid Build Coastguard Worker
10043*da0073e9SAndroid Build Coastguard Worker        def fn(x):
10044*da0073e9SAndroid Build Coastguard Worker            layers = [torch.nn.Linear(3, 3) for _ in range(3)]
10045*da0073e9SAndroid Build Coastguard Worker            mod = Foo(layers)
10046*da0073e9SAndroid Build Coastguard Worker            return mod(x)
10047*da0073e9SAndroid Build Coastguard Worker
10048*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager")
10049*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.randn(3, 3))
10050*da0073e9SAndroid Build Coastguard Worker
10051*da0073e9SAndroid Build Coastguard Worker    def test_load_fast_and_clear_graph_break(self):
10052*da0073e9SAndroid Build Coastguard Worker        # Can result in a segfault in 3.12+ if LOAD_FAST_AND_CLEAR
10053*da0073e9SAndroid Build Coastguard Worker        # is not handled properly in a graph break
10054*da0073e9SAndroid Build Coastguard Worker        def fn():
10055*da0073e9SAndroid Build Coastguard Worker            out = torch.cat([torch.randn(r, 5) for r in range(3)])
10056*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
10057*da0073e9SAndroid Build Coastguard Worker            out = torch.cat([torch.randn(r, 5) for r in range(3)])
10058*da0073e9SAndroid Build Coastguard Worker            return out
10059*da0073e9SAndroid Build Coastguard Worker
10060*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch._dynamo.optimize("eager")(fn)().shape, (3, 5))
10061*da0073e9SAndroid Build Coastguard Worker
10062*da0073e9SAndroid Build Coastguard Worker    def test_raises_importerror1(self):
10063*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager")
10064*da0073e9SAndroid Build Coastguard Worker        def fn(x):
10065*da0073e9SAndroid Build Coastguard Worker            try:
10066*da0073e9SAndroid Build Coastguard Worker                import some_module_that_surely_does_not_exist
10067*da0073e9SAndroid Build Coastguard Worker
10068*da0073e9SAndroid Build Coastguard Worker                return
10069*da0073e9SAndroid Build Coastguard Worker            except ImportError:
10070*da0073e9SAndroid Build Coastguard Worker                pass
10071*da0073e9SAndroid Build Coastguard Worker            return x.sin()
10072*da0073e9SAndroid Build Coastguard Worker
10073*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(8)
10074*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(x), x.sin())
10075*da0073e9SAndroid Build Coastguard Worker
10076*da0073e9SAndroid Build Coastguard Worker    def test_raises_importerror2(self):
10077*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager")
10078*da0073e9SAndroid Build Coastguard Worker        def fn(x):
10079*da0073e9SAndroid Build Coastguard Worker            import some_module_that_surely_does_not_exist
10080*da0073e9SAndroid Build Coastguard Worker
10081*da0073e9SAndroid Build Coastguard Worker            return x + 1
10082*da0073e9SAndroid Build Coastguard Worker
10083*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(8)
10084*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ImportError):
10085*da0073e9SAndroid Build Coastguard Worker            fn(x)
10086*da0073e9SAndroid Build Coastguard Worker
    def test_dynamo_cache_move_to_front(self):
        """Check that a cache hit moves the matching entry to the front.

        Three calls with different ``const`` values are expected to leave three
        cache entries for ``fn``. Repeating the second call must then make its
        entry the first element of the list returned by
        ``_debug_get_cache_entry_list``.
        """

        def fn(x, const):
            return x + const

        # dynamic=False forces Dynamo to recompile
        opt_fn = torch.compile(fn, backend="eager", dynamic=False)

        inp = torch.randn(3, 3)

        # NOTE: assumes that each cache entry is guarded
        # on a distinct value of `const`
        opt_fn(inp, 1)
        opt_fn(inp, 2)
        opt_fn(inp, 3)

        c1 = _debug_get_cache_entry_list(fn.__code__)
        self.assertEqual(len(c1), 3)

        # move cache entry to front
        opt_fn(inp, 2)
        c2 = _debug_get_cache_entry_list(fn.__code__)
        # The entry that sat at index 1 before the hit must now be first.
        self.assertIs(c1[1], c2[0])
10109*da0073e9SAndroid Build Coastguard Worker
    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
    def test_dynamo_cache_invalidate(self):
        """Check that cache entries disappear when their module is deleted.

        Three ``Mod`` instances produce three cache entries for ``fn``. The
        modules are then deleted one by one, and after each ``del`` the entry
        list is expected to shrink while the remaining entries keep their
        identity and relative order.
        """

        class Mod(torch.nn.Module):
            def __init__(self):
                super(Mod, self).__init__()
                self.fc = torch.nn.Linear(3, 3)

            def forward(self, out):
                return self.fc(out)

        def fn(x, mod):
            return mod(x)

        opt_fn = torch.compile(fn, backend="eager")

        # Keep each module in its own local so that `del` below drops the
        # only reference this test holds to it.
        m1 = Mod()
        m2 = Mod()
        m3 = Mod()
        inp = torch.randn(3, 3)

        # NOTE: assumes that each cache entry is guarded
        # on unique Mod instance
        opt_fn(inp, m1)
        opt_fn(inp, m2)
        opt_fn(inp, m3)

        c1 = _debug_get_cache_entry_list(fn.__code__)
        self.assertEqual(len(c1), 3)

        # move cache entry to front
        opt_fn(inp, m2)
        c2 = _debug_get_cache_entry_list(fn.__code__)
        self.assertIs(c1[1], c2[0])

        # delete center of cache
        del m3
        c3 = _debug_get_cache_entry_list(fn.__code__)
        self.assertEqual(len(c3), 2)
        # The former first and last entries survive, in that order.
        self.assertIs(c3[0], c2[0])
        self.assertIs(c3[1], c2[2])

        # delete end of cache
        del m1
        c4 = _debug_get_cache_entry_list(fn.__code__)
        self.assertEqual(len(c4), 1)
        self.assertIs(c4[0], c3[0])

        # Deleting the last module is expected to leave the cache empty.
        del m2
        c5 = _debug_get_cache_entry_list(fn.__code__)
        self.assertEqual(len(c5), 0)
10160*da0073e9SAndroid Build Coastguard Worker
10161*da0073e9SAndroid Build Coastguard Worker    def test_grad_none(self):
10162*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
10163*da0073e9SAndroid Build Coastguard Worker            x.grad = torch.abs(y)
10164*da0073e9SAndroid Build Coastguard Worker            x.grad.add_(y)
10165*da0073e9SAndroid Build Coastguard Worker            return torch.abs(y)
10166*da0073e9SAndroid Build Coastguard Worker
10167*da0073e9SAndroid Build Coastguard Worker        y = torch.arange(4).reshape(2, 2).to(torch.float)
10168*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2)
10169*da0073e9SAndroid Build Coastguard Worker        x.grad = None
10170*da0073e9SAndroid Build Coastguard Worker
10171*da0073e9SAndroid Build Coastguard Worker        z = fn(x, y)
10172*da0073e9SAndroid Build Coastguard Worker        ref_y = torch.clone(z).detach()
10173*da0073e9SAndroid Build Coastguard Worker        ref_x_grad = torch.clone(x.grad).detach()
10174*da0073e9SAndroid Build Coastguard Worker
10175*da0073e9SAndroid Build Coastguard Worker        y = torch.arange(4).reshape(2, 2).to(torch.float)
10176*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2)
10177*da0073e9SAndroid Build Coastguard Worker        x.grad = None
10178*da0073e9SAndroid Build Coastguard Worker
10179*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager")
10180*da0073e9SAndroid Build Coastguard Worker        z = opt_fn(x, y)
10181*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z, ref_y)
10182*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, ref_x_grad)
10183*da0073e9SAndroid Build Coastguard Worker
10184*da0073e9SAndroid Build Coastguard Worker    def test_grad_non_none(self):
10185*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
10186*da0073e9SAndroid Build Coastguard Worker            x.grad.add_(y)
10187*da0073e9SAndroid Build Coastguard Worker            return torch.abs(y)
10188*da0073e9SAndroid Build Coastguard Worker
10189*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(2, 2)
10190*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2)
10191*da0073e9SAndroid Build Coastguard Worker        x.grad = torch.arange(4).reshape(2, 2).to(torch.float)
10192*da0073e9SAndroid Build Coastguard Worker
10193*da0073e9SAndroid Build Coastguard Worker        z = fn(x, y)
10194*da0073e9SAndroid Build Coastguard Worker        ref_y = torch.clone(z).detach()
10195*da0073e9SAndroid Build Coastguard Worker        ref_x_grad = torch.clone(x.grad).detach()
10196*da0073e9SAndroid Build Coastguard Worker
10197*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(2, 2)
10198*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2)
10199*da0073e9SAndroid Build Coastguard Worker        x.grad = torch.arange(4).reshape(2, 2).to(torch.float)
10200*da0073e9SAndroid Build Coastguard Worker
10201*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounterWithBackend("eager")
10202*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend=cnt)
10203*da0073e9SAndroid Build Coastguard Worker        z = opt_fn(x, y)
10204*da0073e9SAndroid Build Coastguard Worker
10205*da0073e9SAndroid Build Coastguard Worker        # Ensure that the generated graph returns only one output. We want the
10206*da0073e9SAndroid Build Coastguard Worker        # add_ on the grad to be part of the graph itself, so that inductor can
10207*da0073e9SAndroid Build Coastguard Worker        # theoretically move the add_ and resutling copy_ nodes at the right
10208*da0073e9SAndroid Build Coastguard Worker        # place to free memory.
10209*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(list(cnt.graphs[0].graph.nodes)[-1].all_input_nodes), 1)
10210*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z, ref_y)
10211*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, ref_x_grad)
10212*da0073e9SAndroid Build Coastguard Worker
10213*da0073e9SAndroid Build Coastguard Worker    def test_new_with_int_list(self):
10214*da0073e9SAndroid Build Coastguard Worker        # Make sure torch.Tensor.new(int argument list) behaves the same on dynamo.
10215*da0073e9SAndroid Build Coastguard Worker        def fn(x):
10216*da0073e9SAndroid Build Coastguard Worker            return x.new(*x.size()) + 5
10217*da0073e9SAndroid Build Coastguard Worker
10218*da0073e9SAndroid Build Coastguard Worker        optfn = torch.compile(backend="eager")(fn)
10219*da0073e9SAndroid Build Coastguard Worker
10220*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(10).view(2, 5)
10221*da0073e9SAndroid Build Coastguard Worker
10222*da0073e9SAndroid Build Coastguard Worker        expected = fn(x)
10223*da0073e9SAndroid Build Coastguard Worker        actual = optfn(x)
10224*da0073e9SAndroid Build Coastguard Worker
10225*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected.dtype, actual.dtype)
10226*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected.shape, actual.shape)
10227*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected.stride(), actual.stride())
10228*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected.storage_offset(), actual.storage_offset())
10229*da0073e9SAndroid Build Coastguard Worker
10230*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(guard_nn_modules=True)
10231*da0073e9SAndroid Build Coastguard Worker    def test_hasattr_nn_module_guard(self):
10232*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
10233*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
10234*da0073e9SAndroid Build Coastguard Worker                super().__init__()
10235*da0073e9SAndroid Build Coastguard Worker                self.a = torch.nn.Linear(3, 3)
10236*da0073e9SAndroid Build Coastguard Worker
10237*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
10238*da0073e9SAndroid Build Coastguard Worker                if hasattr(self, "a"):
10239*da0073e9SAndroid Build Coastguard Worker                    return self.a(x)
10240*da0073e9SAndroid Build Coastguard Worker                else:
10241*da0073e9SAndroid Build Coastguard Worker                    return x
10242*da0073e9SAndroid Build Coastguard Worker
10243*da0073e9SAndroid Build Coastguard Worker        m = M()
10244*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3)
10245*da0073e9SAndroid Build Coastguard Worker        ref = m(x)
10246*da0073e9SAndroid Build Coastguard Worker
10247*da0073e9SAndroid Build Coastguard Worker        opt_m = torch.compile(backend="eager")(m)
10248*da0073e9SAndroid Build Coastguard Worker        res = opt_m(x)
10249*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
10250*da0073e9SAndroid Build Coastguard Worker
10251*da0073e9SAndroid Build Coastguard Worker    def test_ordered_dict_move_to_end(self):
10252*da0073e9SAndroid Build Coastguard Worker        d = {
10253*da0073e9SAndroid Build Coastguard Worker            "foo": 1,
10254*da0073e9SAndroid Build Coastguard Worker            "bar": 2,
10255*da0073e9SAndroid Build Coastguard Worker        }
10256*da0073e9SAndroid Build Coastguard Worker
10257*da0073e9SAndroid Build Coastguard Worker        d = collections.OrderedDict(d)
10258*da0073e9SAndroid Build Coastguard Worker        d.move_to_end("foo")
10259*da0073e9SAndroid Build Coastguard Worker
10260*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager")
10261*da0073e9SAndroid Build Coastguard Worker        def fn(x, d):
10262*da0073e9SAndroid Build Coastguard Worker            return x * d["foo"] * d["bar"]
10263*da0073e9SAndroid Build Coastguard Worker
10264*da0073e9SAndroid Build Coastguard Worker        fn(torch.randn(4), d)
10265*da0073e9SAndroid Build Coastguard Worker        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
10266*da0073e9SAndroid Build Coastguard Worker            fn(torch.randn(4), d)
10267*da0073e9SAndroid Build Coastguard Worker
10268*da0073e9SAndroid Build Coastguard Worker    def test_defaultdict(self):
10269*da0073e9SAndroid Build Coastguard Worker        d = collections.defaultdict()
10270*da0073e9SAndroid Build Coastguard Worker        d["foo"] = 1
10271*da0073e9SAndroid Build Coastguard Worker        d["bar"] = 2
10272*da0073e9SAndroid Build Coastguard Worker
10273*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager")
10274*da0073e9SAndroid Build Coastguard Worker        def fn(x, d):
10275*da0073e9SAndroid Build Coastguard Worker            return x * d["foo"] * d["bar"]
10276*da0073e9SAndroid Build Coastguard Worker
10277*da0073e9SAndroid Build Coastguard Worker        fn(torch.randn(4), d)
10278*da0073e9SAndroid Build Coastguard Worker        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
10279*da0073e9SAndroid Build Coastguard Worker            fn(torch.randn(4), d)
10280*da0073e9SAndroid Build Coastguard Worker
10281*da0073e9SAndroid Build Coastguard Worker    def test_custom_dict(self):
10282*da0073e9SAndroid Build Coastguard Worker        class MyDict(dict):
10283*da0073e9SAndroid Build Coastguard Worker            pass
10284*da0073e9SAndroid Build Coastguard Worker
10285*da0073e9SAndroid Build Coastguard Worker        d = {
10286*da0073e9SAndroid Build Coastguard Worker            "foo": 1,
10287*da0073e9SAndroid Build Coastguard Worker            "bar": 2,
10288*da0073e9SAndroid Build Coastguard Worker        }
10289*da0073e9SAndroid Build Coastguard Worker
10290*da0073e9SAndroid Build Coastguard Worker        d = MyDict(d)
10291*da0073e9SAndroid Build Coastguard Worker
10292*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager")
10293*da0073e9SAndroid Build Coastguard Worker        def fn(x, d):
10294*da0073e9SAndroid Build Coastguard Worker            return x * d["foo"] * d["bar"]
10295*da0073e9SAndroid Build Coastguard Worker
10296*da0073e9SAndroid Build Coastguard Worker        fn(torch.randn(4), d)
10297*da0073e9SAndroid Build Coastguard Worker        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
10298*da0073e9SAndroid Build Coastguard Worker            fn(torch.randn(4), d)
10299*da0073e9SAndroid Build Coastguard Worker
10300*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "requires cuda")
10301*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(
10302*da0073e9SAndroid Build Coastguard Worker        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
10303*da0073e9SAndroid Build Coastguard Worker    )
10304*da0073e9SAndroid Build Coastguard Worker    @torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True)
10305*da0073e9SAndroid Build Coastguard Worker    def test_interpolate_propagate_real_tensors(self):
10306*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager", fullgraph=True)
10307*da0073e9SAndroid Build Coastguard Worker        def f(mask, box):
10308*da0073e9SAndroid Build Coastguard Worker            # u0, u1 = mask.tolist()
10309*da0073e9SAndroid Build Coastguard Worker            mask = torch.randn(1, 1, 30, 30, device="cuda")
10310*da0073e9SAndroid Build Coastguard Worker            h, w = box.tolist()
10311*da0073e9SAndroid Build Coastguard Worker            return torch.nn.functional.interpolate(
10312*da0073e9SAndroid Build Coastguard Worker                mask, (h, w), mode="bilinear", align_corners=False
10313*da0073e9SAndroid Build Coastguard Worker            )
10314*da0073e9SAndroid Build Coastguard Worker
10315*da0073e9SAndroid Build Coastguard Worker        f(torch.tensor([30, 30], device="cuda"), torch.tensor([68, 32], device="cuda"))
10316*da0073e9SAndroid Build Coastguard Worker
10317*da0073e9SAndroid Build Coastguard Worker    def test_custom_iter_dict(self):
10318*da0073e9SAndroid Build Coastguard Worker        class ReversedDict(dict):
10319*da0073e9SAndroid Build Coastguard Worker            def __iter__(self):
10320*da0073e9SAndroid Build Coastguard Worker                return reversed(list(self.keys()))
10321*da0073e9SAndroid Build Coastguard Worker
10322*da0073e9SAndroid Build Coastguard Worker        d = {
10323*da0073e9SAndroid Build Coastguard Worker            "foo": 1,
10324*da0073e9SAndroid Build Coastguard Worker            "bar": 2,
10325*da0073e9SAndroid Build Coastguard Worker        }
10326*da0073e9SAndroid Build Coastguard Worker
10327*da0073e9SAndroid Build Coastguard Worker        d = ReversedDict(d)
10328*da0073e9SAndroid Build Coastguard Worker
10329*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager")
10330*da0073e9SAndroid Build Coastguard Worker        def fn(x, d):
10331*da0073e9SAndroid Build Coastguard Worker            return x * d["foo"] * d["bar"]
10332*da0073e9SAndroid Build Coastguard Worker
10333*da0073e9SAndroid Build Coastguard Worker        fn(torch.randn(4), d)
10334*da0073e9SAndroid Build Coastguard Worker        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
10335*da0073e9SAndroid Build Coastguard Worker            fn(torch.randn(4), d)
10336*da0073e9SAndroid Build Coastguard Worker
10337*da0073e9SAndroid Build Coastguard Worker    def test_custom_keys_iter_dict(self):
10338*da0073e9SAndroid Build Coastguard Worker        class ReversedDict(dict):
10339*da0073e9SAndroid Build Coastguard Worker            def keys(self):
10340*da0073e9SAndroid Build Coastguard Worker                return ["bar", "foo"]
10341*da0073e9SAndroid Build Coastguard Worker
10342*da0073e9SAndroid Build Coastguard Worker        d = {
10343*da0073e9SAndroid Build Coastguard Worker            "foo": 1,
10344*da0073e9SAndroid Build Coastguard Worker            "bar": 2,
10345*da0073e9SAndroid Build Coastguard Worker        }
10346*da0073e9SAndroid Build Coastguard Worker
10347*da0073e9SAndroid Build Coastguard Worker        d = ReversedDict(d)
10348*da0073e9SAndroid Build Coastguard Worker
10349*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager")
10350*da0073e9SAndroid Build Coastguard Worker        def fn(x, d):
10351*da0073e9SAndroid Build Coastguard Worker            return x * d["foo"] * d["bar"]
10352*da0073e9SAndroid Build Coastguard Worker
10353*da0073e9SAndroid Build Coastguard Worker        fn(torch.randn(4), d)
10354*da0073e9SAndroid Build Coastguard Worker        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
10355*da0073e9SAndroid Build Coastguard Worker            fn(torch.randn(4), d)
10356*da0073e9SAndroid Build Coastguard Worker
10357*da0073e9SAndroid Build Coastguard Worker    def test_dict_guard_on_keys_order(self):
10358*da0073e9SAndroid Build Coastguard Worker        d = {
10359*da0073e9SAndroid Build Coastguard Worker            2: 4,
10360*da0073e9SAndroid Build Coastguard Worker            3: 5,
10361*da0073e9SAndroid Build Coastguard Worker        }
10362*da0073e9SAndroid Build Coastguard Worker
10363*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
10364*da0073e9SAndroid Build Coastguard Worker
10365*da0073e9SAndroid Build Coastguard Worker        def fn(x, d):
10366*da0073e9SAndroid Build Coastguard Worker            for key, value in d.items():
10367*da0073e9SAndroid Build Coastguard Worker                x = x * key + value
10368*da0073e9SAndroid Build Coastguard Worker            return x
10369*da0073e9SAndroid Build Coastguard Worker
10370*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend=cnts)
10371*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.randn(4), d)
10372*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.randn(4), d)
10373*da0073e9SAndroid Build Coastguard Worker        # No recompilation
10374*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
10375*da0073e9SAndroid Build Coastguard Worker
10376*da0073e9SAndroid Build Coastguard Worker        # move 2 to the end
10377*da0073e9SAndroid Build Coastguard Worker        d[2] = d.pop(2)
10378*da0073e9SAndroid Build Coastguard Worker
10379*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
10380*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, d)
10381*da0073e9SAndroid Build Coastguard Worker        # Check recompilation
10382*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
10383*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, fn(x, d))
10384*da0073e9SAndroid Build Coastguard Worker
10385*da0073e9SAndroid Build Coastguard Worker    def test_dict_guard_on_keys_order2(self):
10386*da0073e9SAndroid Build Coastguard Worker        d = {
10387*da0073e9SAndroid Build Coastguard Worker            2: 4,
10388*da0073e9SAndroid Build Coastguard Worker            3: 5,
10389*da0073e9SAndroid Build Coastguard Worker        }
10390*da0073e9SAndroid Build Coastguard Worker
10391*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
10392*da0073e9SAndroid Build Coastguard Worker
10393*da0073e9SAndroid Build Coastguard Worker        def fn(x, d):
10394*da0073e9SAndroid Build Coastguard Worker            for key in d:
10395*da0073e9SAndroid Build Coastguard Worker                value = d[key]
10396*da0073e9SAndroid Build Coastguard Worker                x = x * key + value
10397*da0073e9SAndroid Build Coastguard Worker            return x
10398*da0073e9SAndroid Build Coastguard Worker
10399*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend=cnts)
10400*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.randn(4), d)
10401*da0073e9SAndroid Build Coastguard Worker        opt_fn(torch.randn(4), d)
10402*da0073e9SAndroid Build Coastguard Worker        # No recompilation
10403*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
10404*da0073e9SAndroid Build Coastguard Worker
10405*da0073e9SAndroid Build Coastguard Worker        # move 2 to the end
10406*da0073e9SAndroid Build Coastguard Worker        d[2] = d.pop(2)
10407*da0073e9SAndroid Build Coastguard Worker
10408*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
10409*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, d)
10410*da0073e9SAndroid Build Coastguard Worker        # Check recompilation
10411*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
10412*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, fn(x, d))
10413*da0073e9SAndroid Build Coastguard Worker
10414*da0073e9SAndroid Build Coastguard Worker    def test_contains_dunder_dict(self):
10415*da0073e9SAndroid Build Coastguard Worker        class UserDefined:
10416*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
10417*da0073e9SAndroid Build Coastguard Worker                self.a = 3
10418*da0073e9SAndroid Build Coastguard Worker                self.b = 5
10419*da0073e9SAndroid Build Coastguard Worker
10420*da0073e9SAndroid Build Coastguard Worker            def run(self, x):
10421*da0073e9SAndroid Build Coastguard Worker                if "a" in self.__dict__:
10422*da0073e9SAndroid Build Coastguard Worker                    x = x * self.a
10423*da0073e9SAndroid Build Coastguard Worker                if "b" in self.__dict__:
10424*da0073e9SAndroid Build Coastguard Worker                    x = x * self.b
10425*da0073e9SAndroid Build Coastguard Worker                self.c = 7
10426*da0073e9SAndroid Build Coastguard Worker                if "c" in self.__dict__:
10427*da0073e9SAndroid Build Coastguard Worker                    x = x * self.c
10428*da0073e9SAndroid Build Coastguard Worker                return x * self.__dict__.get("a") * self.__dict__.get("z", 2)
10429*da0073e9SAndroid Build Coastguard Worker
10430*da0073e9SAndroid Build Coastguard Worker        obj = UserDefined()
10431*da0073e9SAndroid Build Coastguard Worker
10432*da0073e9SAndroid Build Coastguard Worker        def fn(x):
10433*da0073e9SAndroid Build Coastguard Worker            return obj.run(x)
10434*da0073e9SAndroid Build Coastguard Worker
10435*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
10436*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
10437*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
10438*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
10439*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
10440*da0073e9SAndroid Build Coastguard Worker
10441*da0073e9SAndroid Build Coastguard Worker    def test_module_dunder_dict(self):
10442*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
10443*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
10444*da0073e9SAndroid Build Coastguard Worker                super().__init__()
10445*da0073e9SAndroid Build Coastguard Worker                self.foo = 1
10446*da0073e9SAndroid Build Coastguard Worker                self.bar = 2
10447*da0073e9SAndroid Build Coastguard Worker                self.baz = 3
10448*da0073e9SAndroid Build Coastguard Worker
10449*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
10450*da0073e9SAndroid Build Coastguard Worker                if "foo" in self.__dict__:
10451*da0073e9SAndroid Build Coastguard Worker                    return x * self.bar
10452*da0073e9SAndroid Build Coastguard Worker                return x * self.baz
10453*da0073e9SAndroid Build Coastguard Worker
10454*da0073e9SAndroid Build Coastguard Worker        mod = MyModule()
10455*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10)
10456*da0073e9SAndroid Build Coastguard Worker        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
10457*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod(x), opt_mod(x))
10458*da0073e9SAndroid Build Coastguard Worker
10459*da0073e9SAndroid Build Coastguard Worker
class TestTracer(JitTestCase):
    """Tests that call ``torch.jit.trace`` from a Dynamo-optimized function."""

    def test_jit_save(self):
        """Trace a module with exported state hooks, eagerly and under Dynamo.

        The test makes no assertions on the result; it passes as long as
        neither the plain call nor the Dynamo-optimized call raises.
        """

        def fn():
            class Foo(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.a = 3

                @torch.jit.export
                def __getstate__(self):
                    # State is a (value, training flag) pair.
                    return (3, self.training)

                @torch.jit.export
                def __setstate__(self, state):
                    # Unpack the pair produced by __getstate__.
                    self.a = state[0]
                    self.training = state[1]

                def forward(self, x):
                    return x + self.a

            module = Foo()
            example_inputs = (torch.rand(3, 4),)
            return torch.jit.trace(module, example_inputs)

        # Run once without Dynamo, then once through the "eager" backend.
        fn()
        optimized_fn = torch._dynamo.optimize("eager")(fn)
        optimized_fn()
10487*da0073e9SAndroid Build Coastguard Worker
10488*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Imported only when the file is executed as a script.
    import torch._dynamo.test_case as dynamo_test_case

    dynamo_test_case.run_tests()
10493