# Owner(s): ["module: dynamo"]
# flake8: noqa: E731, C405, F811, C418, C417
import collections
import functools
import inspect
import itertools
import math
import operator
import random
import sys
import unittest
from dataclasses import dataclass, field
from typing import Any, Dict, List, NamedTuple
from unittest.mock import patch

import numpy as np

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch import sub
from torch._dynamo.testing import (
    CompileCounterWithBackend,
    EagerAndRecordGraphs,
    normalize_gm,
)
from torch._dynamo.utils import ifdynstaticdefault, same
from torch._dynamo.variables import ConstantVariable
from torch._dynamo.variables.lists import RangeVariable
from torch.nn import functional as F
from torch.testing._internal.common_utils import (
    disable_translation_validation_if_dynamic_shapes,
    instantiate_parametrized_tests,
    parametrize,
)

# Defines all the kernels for tests
from torch.testing._internal.triton_utils import *  # noqa: F403


d = torch.ones(10, 10)
e = torch.nn.Linear(10, 10)
flag = True


class CustomDictSubclass(collections.OrderedDict):
    pass


clip01 = functools.partial(torch.clip, min=0.0, max=1.0)


def constant3(a, b):
    return a - b + (1.0 + 2)


_variable = 0


def update_global(x):
    global _variable
    _variable += 1
    # Check that updated global variable value is picked up
    return x * _variable


def func_with_default(a, b, some_default_arg=True):
    if some_default_arg:
        return a - b


def make_test(fn=None, expected_frame_count=1):
    if fn is None:
        return lambda fn: make_test(fn, expected_frame_count=expected_frame_count)

    nargs = len(inspect.signature(fn).parameters)

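    # standard_test compiles `fn` and compares it against eager on
    # synthesized tensor inputs; `nargs` tells it how many tensor args to
    # generate (a rough description of the torch._dynamo.testing helper,
    # not its full contract).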
    def test_fn(self):
        return torch._dynamo.testing.standard_test(
            self,
            fn=fn,
            nargs=nargs,
            expected_frame_count=expected_frame_count,
        )

    return test_fn


class MyCls:
    a = 1


@torch.jit.script_if_tracing
def inline_script_if_tracing(x):
    return x + 1.2


@torch.jit.ignore
def inline_ignore(x):
    return x + 3.4


@torch.jit.unused
def inline_unused(x):
    return x + 5.6


@functools.lru_cache
def inline_lru_cache_fn_with_default_args(x, y, _=None):
    return torch.sin(x * y)


@torch.jit.script_if_tracing
def inline_script_if_tracing_fn_with_default_args(x, y, c=1.2):
    return torch.cos(x * y) + c


class FunctionTests(torch._dynamo.test_case.TestCase):
    @make_test
    def test_inline_jit_annotations(x):
        x = inline_script_if_tracing(x)
        x = inline_ignore(x)
        x = inline_unused(x)
        return

    @make_test
    def test_inline_script_if_tracing_fn_with_default_args(a, b):
        return inline_script_if_tracing_fn_with_default_args(a, b)

    @make_test
    def test_inline_lru_cache_fn_with_default_args(a, b):
        return inline_lru_cache_fn_with_default_args(a, 2, b)

    @make_test
    def test_add(a, b):
        return a + b

    @make_test
    def test_add_(a, b):
        a_copy = torch.tensor(a)
        return a_copy.add_(b, alpha=5.0)

    @make_test
    def test_addcdiv(a, b, c):
        # dynamo decomposes this to avoid a graph break when
        # the value kwarg is populated
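        # (addcdiv(a, b, c, value=v) computes a + v * b / c)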
        return torch.addcdiv(a, b, c, value=5.0)

    @make_test
    def test_addcdiv_(a, b, c):
        a_copy = torch.tensor(a)
        return a_copy.addcdiv_(b, c, value=5.0)

    @make_test
    def test_is_not_null(a, b):
        if a is not None and b is not None:
            return a + b

    def test_foreach_lerp_(self):
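        # _foreach_lerp_ mutates the tensors in its first list argument in
        # place, so the eager and compiled runs below each build fresh inputs.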
        def fn(x, y, s):
            return torch._foreach_lerp_(x, y, s)

        cnt = torch._dynamo.testing.CompileCounter()

        fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn)
        expected = fn(
            [torch.ones(2, 2) * 4.26, torch.ones(2, 2) * 3.14],
            [torch.ones(2, 2), torch.ones(2, 2)],
            torch.tensor(0.5),
        )

        actual = fn_opt(
            [torch.ones(2, 2) * 4.26, torch.ones(2, 2) * 3.14],
            [torch.ones(2, 2), torch.ones(2, 2)],
            torch.tensor(0.5),
        )
        self.assertTrue(same(expected, actual))

    def test_broadcast_foreach_pow(self):
        from torch._dynamo.utils import same

        def fn(x, y):
            return torch._foreach_pow(x, y)

        cnt = torch._dynamo.testing.CompileCounter()

        fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn)
        inps = (torch.tensor(0.80), [torch.tensor(3.4), torch.tensor(7.8)])

        actual = fn_opt(*inps)
        expected = fn(*inps)
        self.assertTrue(same(actual, expected))
        self.assertEqual(cnt.frame_count, 1)

    def test_addcmul_(self):
        from copy import deepcopy

        from torch._dynamo.utils import same

        def fn(x, y, z, s):
            return x.addcmul_(y, z, value=s)

        cnt = torch._dynamo.testing.CompileCounter()
        fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn)
        inps = (
            torch.ones(2, 2),
            torch.ones(2, 2) + 1,
            torch.rand(2, 2),
            torch.tensor(0.3),
        )
        inps_2 = deepcopy(inps)
        actual = fn_opt(*inps)
        expected = fn(*inps_2)
        self.assertTrue(same(actual, expected))
        self.assertEqual(cnt.frame_count, 1)

    @make_test
    def test_functools_partial(a, b):
        return clip01(a + b)

    @make_test
    def test_itertools_product(a, b):
        v = a
        for x, i in itertools.product([a, b], [1, 2]):
            v = v + x * i
        return v

    @make_test
    def test_itertools_chain(a, b):
        v = a
        for x in itertools.chain([a, b], [1, 2]):
            v = v + x
        return v

    @make_test
    def test_itertools_chain_from_iterable(a, b):
        v = a
        for x in itertools.chain.from_iterable([[a, b], [1, 2]]):
            v = v + x
        return v

    def test_itertools_reconstruct(self):
        def fn(a):
            it1 = itertools.repeat(1)
            it2 = itertools.count(2)
            for _ in range(3):
                a += next(it1)
                a += next(it2)
            return it1, it2, a

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        i1, i2, a = fn(torch.ones(3, 3))
        it1, it2, b = opt_fn(torch.ones(3, 3))
        self.assertEqual(next(i1), next(it1))
        self.assertEqual(next(i2), next(it2))
        self.assertEqual(a, b)

    @make_test
    def test_obj_eq(a, b):
        v = a + b
        if MyCls() == None:  # noqa: E711
            return -1
        if MyCls() != None:  # noqa: E711
            v = v.sin()
        if MyCls() == MyCls():
            return -2
        if MyCls() != MyCls():
            return v + 1
        return -3

    @make_test
    def test_cls_eq(a, b):
        v = a + b
        if MyCls == None:  # noqa: E711
            return -1
        if MyCls != None:  # noqa: E711
            v = v.sin()
        if MyCls != MyCls:
            return -2
        if MyCls == MyCls:
            return v + 1
        return -3

    @make_test
    def test_obj_is(a, b):
        v = a + b
        if MyCls() is None:  # noqa: E711
            return -1
        if MyCls() is not None:  # noqa: E711
            v = v.sin()
        if MyCls() is MyCls():
            return -2
        if MyCls() is not MyCls():
            return v + 1
        return -3

    @make_test
    def test_cls_is(a, b):
        v = a + b
        if MyCls is None:  # noqa: E711
            return -1
        if MyCls is not None:  # noqa: E711
            v = v.sin()
        if MyCls is not MyCls:
            return -2
        if MyCls is MyCls:
            return v + 1
        return -3

    @make_test
    def test_itertools_combinations(a, b):
        combs = []
        for size in itertools.combinations((1, 2, 3, 4), 2):
            combs.append(torch.ones(size))
        return combs

    @make_test
    def test_np_iinfo(a):
        max_dim = np.iinfo(np.int16).max
        return a + max_dim

    @make_test
    def test_np_finfo(a):
        min_dim = np.finfo(np.float32).min
        return a + min_dim

    @make_test
    def test_constant1(a, b, c):
        return a - b * c + 1.0

    @make_test
    def test_constant2(a, b, c):
        return a - b * c + 1

    @make_test
    def test_constant3(a):
        b = 1
        c = 2
        d = 3
        return b + c - d + a

    @make_test
    def test_constant4(a, b):
        c = 2
        d = 3
        if c > d:
            return a - b
        return b - a

    @make_test
    def test_cls_hasattr(self, x):
        if hasattr(MyCls, "a"):
            x = x + 1
        if hasattr(MyCls, "b"):
            x = x + 2
        return x

    @make_test
    def test_finfo(a, b):
        if torch.iinfo(torch.int32).bits == 32:
            return torch.finfo(a.dtype).min * b

    @make_test
    def test_globalfn(a, b):
        return sub(a, b)

    @make_test
    def test_viatorch(a, b):
        return torch.sub(a, b)

    @make_test
    def test_viamethod(a, b):
        return a.sub(b)

    @make_test
    def test_indirect1(a, b):
        t = a.sub
        return t(b)

    @make_test
    def test_indirect2(a, b):
        t = a.sub
        args = (b,)
        return t(*args)

    @make_test
    def test_indirect3(a, b):
        t = a.sub
        args = (b,)
        kwargs = {}
        return t(*args, **kwargs)

    @make_test
    def test_methodcall1(a, b, c):
        return constant3(a, b) * c

    @make_test
    def test_methodcall2(a, b):
        return constant3(a=b, b=a) + 1

    @make_test
    def test_methodcall3(a, b):
        return constant3(a, b=1.0) + b

    def test_is_integer(self):
        @torch.compile(backend="eager", fullgraph=True)
        def forward(t, m):
            return 2 * t if m.is_integer() else t

        t = torch.tensor([1])
        self.assertEqual(forward(t, 1.0).item(), 2)
        self.assertEqual(forward(t, 1.5).item(), 1)

    @parametrize(
        "method, num_type",
        (
            ("as_integer_ratio", int),
            ("bit_length", int),
            ("conjugate", int),
            ("as_integer_ratio", float),
            ("conjugate", float),
            ("hex", float),
            ("is_integer", float),
        ),
    )
    def test_number_method(self, method, num_type):
        def forward(t, m):
            return 2 * t if getattr(m, method)() else t

        wrapped = torch.compile(backend="eager", fullgraph=True)(forward)

        for i in (0, 1, 2.5):
            m = num_type(i)
            t = torch.tensor([1])
            actual = wrapped(t, m)
            expected = forward(t, m)
            self.assertEqual(actual, expected)

    @make_test
    def test_device_constant(a):
        return a + torch.ones(1, device=torch.device("cpu"))

    @make_test
    def test_tuple1(a, b):
        args = (a, b)
        return sub(*args)

    @make_test
    def test_tuple2(a, b):
        args = [a, b]
        return sub(*args)

    @make_test
    def test_is_in_onnx_export(x, y):
        if torch.onnx.is_in_onnx_export():
            return x - 1
        else:
            return y + 1

    @make_test
    def test_is_fx_tracing(x, y):
        if torch.fx._symbolic_trace.is_fx_tracing():
            return x - 1
        else:
            return y + 1

    @make_test
    def test_listarg1(a, b):
        return torch.cat([a, b])

    @make_test
    def test_listarg2(a, b):
        return torch.cat((a, b), dim=0)

    @make_test
    def test_listarg3(a, b):
        kwargs = {"tensors": (a, b), "dim": 0}
        return torch.cat(**kwargs)

    @make_test
    def test_listarg4(a, b):
        return torch.cat(tensors=[a, b], dim=0)

    @make_test
    def test_listarg5(a, b):
        args = [(a, b)]
        kwargs = {"dim": 0}
        return torch.cat(*args, **kwargs)

    def test_list_slice(self):
        class Mock:
            def __init__(self):
                self.ets = []
                self.counter = 0

            @torch.compile(backend="eager")
            def run(self, x):
                self.ets = self.ets[-3:]
                self.ets.append(x)
                return torch.sin(x)

        mock = Mock()
        mock.run(torch.randn(4))
        self.assertEqual(len(mock.ets), 1)

    @make_test
    def test_deque(a, b):
        d = collections.deque([a, b])
        d.append(a + 1)
        d.extend([a, b])
        d.insert(0, "foo")
        tmp = d.pop()

        another_deque = collections.deque([tmp])
        d.extendleft(another_deque)
        another_deque.clear()
        d.extend(another_deque)

        d[2] = "setitem"
        d = d.copy()
        d.append(d.popleft())

        empty = collections.deque()
        d.extend(empty)

        return d

    @make_test
    def test_slice1(a):
        return a[5]

    @make_test
    def test_slice2(a):
        return a[:5]

    @make_test
    def test_slice3(a):
        return a[5:]

    @make_test
    def test_slice4(a):
        return a[2:5]

    @make_test
    def test_slice5(a):
        return a[::2]

    @make_test
    def test_slice6(a):
        return torch.unsqueeze(a, 0)[:, 2:]

    @make_test
    def test_range1(a):
        return torch.tensor(range(a.size(0)))

    @make_test
    def test_range2(x, y):
        r = x + y
        for i in range(x.size(0) + 2):
            r = r / y
        return r

    @make_test
    def test_unpack1(a):
        a, b = a[:5], a[5:]
        return a - b

    @make_test
    def test_unpack2(a):
        packed = [a[:5], a[5:]]
        a, b = packed
        return a - b

    @make_test
    def test_unpack3(a):
        packed = (a[:5], a[5:])
        a, b = packed
        return a - b

    @make_test
    def test_fn_with_self_set(a, b):
        # avg_pool2d is an odd one with __self__ set
        return F.avg_pool2d(
            torch.unsqueeze(a, 0) * torch.unsqueeze(b, 1), kernel_size=2, padding=1
        )

    @make_test
    def test_return_tuple1(a, b):
        return (a - b, b - a, a, b)

    @make_test
    def test_globalvar(a, b):
        return a - b + d

    @make_test
    def test_globalmodule(x):
        return e(x)

    @make_test
    def test_inline_with_default(a, b, c):
        return func_with_default(a, b) * c

    @make_test
    def test_inner_function(x):
        def fn(x):
            return torch.add(x, x)

        return fn(x)

    @make_test
    def test_transpose_for_scores(x):
        new_x_shape = x.size()[:-1] + (2, 5)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1)

    @make_test
    def test_return_tuple2(x):
        return (torch.add(x, x), x)

    @make_test
    def test_load_global_bool(x):
        if flag:
            return torch.add(x, x)
        else:
            return x

    @make_test
    def test_len_tensor(x):
        z = len(x)
        return torch.add(x, z)

    @make_test
    def test_len_constant_list(x):
        z = len([1, 2, 3])
        return torch.add(x, z)

    @make_test
    def test_len_constant_dict(x):
        z = len({"foo": "bar"})
        return torch.add(x, z)

    @make_test
    def test_dict_copy(x):
        z = dict({"foo": x + 1})
        return z

    @make_test
    def test_dict_keys(x):
        d = {3: x}
        keys = d.keys()
        d[4] = x + 1
        d2 = {3: 2, 4: "aa"}
        return 3 in keys, 4 in keys, 5 in keys, d2.keys() == keys

    @make_test
    def test_dict_values(x):
        d = {3: x}
        values = d.values()
        d[3] = x + 1
        d[4] = x + 2
        return len(values)

    @make_test
    def test_dict_setdefault1(x):
        d = {"a": 1, "b": 2}
        d.setdefault("a", 10)
        if d["a"] == 1:
            return x + 1
        else:
            return x - 1

    @make_test
    def test_dict_setdefault2(x):
        d = {"a": 1, "b": 2}
        d.setdefault("c", 10)
        if d["c"] == 10:
            return x + 1
        else:
            return x - 1

    @make_test
    def test_dict_setdefault3(x):
        d = {"a": 1, "b": 2}
        d.setdefault("c")
        if d["c"] is None:
            return x + 1
        else:
            return x - 1

    @make_test
    def test_defaultdict_setdefault1(x):
        d = collections.defaultdict.fromkeys("a", "b")
        d["a"] = 1
        d["b"] = 2
        d.setdefault("a", 10)
        if d["a"] == 1:
            return x + 1
        else:
            return x - 1

    @make_test
    def test_defaultdict_setdefault2(x):
        d = collections.defaultdict.fromkeys("a", "b")
        d["a"] = 1
        d["b"] = 2
        d.setdefault("c", 10)
        if d["c"] == 10:
            return x + 1
        else:
            return x - 1

    @make_test
    def test_defaultdict_setdefault3(x):
        d = collections.defaultdict.fromkeys("a", "b")
        d["a"] = 1
        d["b"] = 2
        d.setdefault("c")
        if d["c"] is None:
            return x + 1
        else:
            return x - 1

    def test_dict_id_guard(self):
        d1 = collections.OrderedDict({"a": 2})
        d2 = d1

        def fn(x):
            # Iteration forces DictGuardManager
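            # (d1 and d2 alias the same dict object, exercising identity
            # guards alongside the per-key guards)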
            for k in d1:
                x = x * d1[k] * d2[k]
            return x

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(fn(x), opt_fn(x))

    @make_test
    def test_callable_lambda(x):
        if callable(lambda x: True):
            return x + 1
        else:
            return x - 1

    @make_test
    def test_callable_torch(x):
        if callable(torch.abs):
            return x + 1
        else:
            return x - 1

    @make_test
    def test_callable_builtin(x):
        if callable(sum):
            return x + 1
        else:
            return x - 1

    def test_callable_class(self):
        class CallableClass:
            def __call__():
                pass

        class NotCallableClass:
            pass

        @torch.compile(backend="eager", fullgraph=True)
        def fn1(x, arg):
            if callable(arg):
                return x
            return x + 1

        @torch.compile(backend="eager", fullgraph=True)
        def fn2(x, arg):
            if callable(arg):
                return x * 2
            return x + 1

        input = torch.randn(4)

        for f in [fn1, fn2]:
            self.assertEqual(f(input, NotCallableClass()), input + 1)
            self.assertEqual(
                f(input, CallableClass()), input if f is fn1 else input * 2
            )

            # passing tensor and scalars
            self.assertEqual(f(input, 1), input + 1)
            self.assertEqual(f(input, 1.1), input + 1)
            self.assertEqual(f(input, True), input + 1)
            self.assertEqual(f(input, input), input + 1)

    def test_callable_list(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x, arg):
            if callable(arg):
                return x
            return x + 1

        input = torch.randn(4)
        self.assertEqual(fn(input, [1, 2, 3]), input + 1)
        self.assertEqual(fn(input, (1, 2, 3)), input + 1)

    @make_test
    def test_len_constant_misc_iterables(x):
        a = len((1, 2, 3))
        b = len("test str")
        c = a + b
        return torch.add(x, c)

    @make_test
    def test_dict_kwargs(x):
        z = dict(text_embed=x + 1, other=x + 2)
        return z

    @make_test
    def test_ordered_dict_kwargs(x):
        z = collections.OrderedDict(sample=torch.ones(10))
        return z

    @make_test
    def test_custom_dict_kwargs(x):
        z = CustomDictSubclass(sample=torch.ones(10))
        return z

    @make_test
    def test_float(x):
        y = float(1.2)  # noqa: UP018
        y += float("1.2")
        return torch.add(x, y)

    @make_test
    def test_is_floating_point(x):
        y = x + 1
        return torch.is_floating_point(y), torch.is_floating_point(input=y)

    @make_test
    def test_dtype(x):
        if x.dtype == torch.float32:
            return x + 1

    @make_test
    def test_get_default_dtype(x):
        if x.dtype == torch.get_default_dtype():
            return x + 1
        else:
            return x - 1

    @make_test
    def test_get_autocast_gpu_dtype(x):
        dtype = torch.get_autocast_gpu_dtype()
        return x.type(dtype)

    @make_test
    def test_is_any_autocast_enabled(x):
        if torch._C._is_any_autocast_enabled():
            return x + 1
        else:
            return x - 1

    @make_test
    def test_is_checkpoint_valid(x):
        if torch.autograd._is_checkpoint_valid():
            return x + 1
        else:
            return x - 1

    @make_test
    def test_list_compare_polyfill(x):
        for a, b, c in [
            [(1, 2, 3), (1, 2, 3), 7.77],
            [(1, 4, 3), (1, 2, 3), 3.33],
            [(1, 2), (1, 2, 3), 5.55],
            [(1, 2, 3), (1, 2), 11.11],
            [(1, -1, 3), (1, 2, 3), 13.33],
        ]:
            if a != b:
                x += 1 * c
            if a == b:
                x += 2 * c
            if a < b:
                x += 4 * c
            if a > b:
                x += 8 * c
            if a <= b:
                x += 16 * c
            if a >= b:
                x += 32 * c
        return x

    @make_test
    def test_promote_types(x):
        if x.dtype == torch.promote_types(torch.int32, torch.float32):
            return x + 1
        else:
            return x - 1

    @make_test
    def test_cublas_allow_tf32(x):
        if torch.backends.cuda.matmul.allow_tf32:
            return x.sin() + 1

        return x.cos() - 1

    @make_test
    def test_get_calculate_correct_fan(x):
        fan_in = torch.nn.init._calculate_correct_fan(x, "fan_in")
        return x + fan_in

    @make_test
    def test_is_complex(x):
        if torch.is_complex(x):
            return x + 1
        else:
            return x - 1

    @make_test
    def test_tensor_is_complex(x):
        if x.is_complex():
            return x + 1
        else:
            return x - 1

    @make_test
    def test_get_privateuse1_name(x):
        if torch._C._get_privateuse1_backend_name() == "privateuseone":
            return x + 1
        else:
            return x - 1

    @make_test
    def test_device(x):
        if not x.is_cuda:
            return x + 1

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    @make_test
    def test_get_device_properties_tensor_device(a):
        x = a.to("cuda")
        prop = torch.cuda.get_device_properties(x.device)
        if prop.major == 8:
            return x + prop.multi_processor_count
        return x + prop.max_threads_per_multi_processor

    @make_test
    def test_tensor_type(a, b):
        m = a.to(torch.float16)
        return b.type(m.type())

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    @make_test
    def test_tensor_type2(a, b):
        m = a.to("cuda")
        return m + b.type(m.type())

    @make_test
    def test_tensor_type3(a, b):
        m = a.type(torch.HalfTensor)
        return b.type(m.type())

    @make_test
    def test_tensor_type4(a, b):
        m = a.type("torch.HalfTensor")
        return b.type(m.type())

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    @make_test
    def test_tensor_type5(a, b):
        m = a.type(torch.cuda.HalfTensor)
        return b.type(m.type())

    @make_test
    def test_tensor_element_size(a):
        if a.element_size() > 1:
            return (a + a.element_size(), a - a.element_size())
        return (a - a.element_size(), a + a.element_size())

    @make_test
    def test_ndim(x):
        if x.ndim == 2 and x.ndimension() == 2 and x.dim() == 2:
            return x + 1

    @make_test
    def test_T(x):
        return torch.ones_like(x.T)

    @make_test
    def test_mT(x):
        return torch.ones_like(x.mT)

    @make_test
    def test_is_sparse(x):
        if not x.is_sparse:
            return x + 1

    @make_test
    def test_shape1(x):
        if x.shape[0] == 10:
            return x + 1

    @make_test
    def test_shape2(x):
        if x.size(1) == 10:
            return x + 1

    @make_test
    def test_del(a, b):
        c = a + 1
        d = c + 2
        del c, a
        return b + d

    @make_test
    def test_chunks1(x):
        chunk_size = 5
        assert x.shape[0] % chunk_size == 0
        assert x.shape[0] // chunk_size == 2
        return x[:chunk_size] - x[chunk_size:]

    @make_test
    def test_import1(x, y):
        import torch
        from torch import sub

        return sub(torch.add(x, y), y)

    @make_test
    def test_return_dict(x, y):
        z = [x + y, y, False]
        return {"x": x, "z": z, "a": x, "b": z, "c": x}

    @make_test
    def test_return_dict2(x, y):
        tmp = {"x": x}
        tmp["z"] = [x + y, y]
        tmp["y"] = y
        tmp["z"].append(False)
        return tmp

    @make_test
    def test_funcdef_closure(x, y):
        x = x + y + 1.0

        def inner(z):
            nonlocal x, y
            y = x + z + 20.0
            x = y + z + 10.0

        inner(2.0)
        inner(3.0)

        return x, y

    @make_test
    def test_module_constant(x, y):
        r = x + y
        for i in range(torch._dynamo.testing.three):
            r = r / y
        return r

    @make_test
    def test_inline_softmax(x, y):
        # This is common in some huggingface models
        return torch.nn.Softmax(dim=-1)(x + y * 2)

    @make_test
    def test_dtype_compare(a, b):
        if a.dtype == torch.float16:
            return a + 10
        if a.dtype == torch.float32:
            return a - b * 32

    @make_test
    def test_build_list_unpack(a, b):
        it1 = (x + 1 for x in (a, b))
        it2 = (x - 1 for x in (a, b))
        return torch.cat([*it1, *it2], dim=-1)

    @make_test
    def test_tensor_len(a, b):
        return a + b + len(a) + b.__len__()

    @make_test
    def test_pop(a, b):
        ll = [a, b]
        ll.append(a + 1)
        ll.extend(
            [
                b + 2,
                a + b,
            ]
        )
        ll.pop(-1)
        ll.pop(0)
        ll.pop()
        v1, v2 = ll
        return v1 - v2

    @make_test
    def test_list_convert(a, b):
        ll = [a + 2, b]
        ll = tuple(ll)
        tmp = b + 3
        ll = list(ll)
        v1, v2 = ll
        return v1 - v2 + tmp

    @make_test
    def test_list_add(a, b):
        l1 = (a, b)
        l2 = ()  # ends up as a LOAD_CONST in the bytecode
        l3 = l1 + l2
        return l3[0] + l3[1]

    @make_test
    def test_list_index_with_constant_tensor(a, b):
        l1 = [a, b, a + 1, b + 1]
        return l1[torch.as_tensor(2)]

    @make_test
    def test_startswith(a, b):
        x = a + b
        if "foobar".startswith("foo") and "test" in constant3.__module__:
            x = x + 1
        return x

    @make_test
    def test_dict_ops(a, b):
        tmp = {"a": a + 1, "b": b + 2}
        assert tmp.get("zzz") is None
        v = tmp.pop("b") + tmp.get("a") + tmp.get("missing", 3) + tmp.pop("missing", 4)
        tmp.update({"d": 3})
        tmp["c"] = v + tmp["d"]
        if "c" in tmp and "missing" not in tmp:
            return tmp["c"] - tmp["a"] + len(tmp)

    @make_test
    def test_inline_jit__unwrap_optional(x):
        if torch.jit._unwrap_optional(x) is None:
            return torch.ones(2, 2)
        return x.sin()

    @make_test
    def test_zip_longest(x):
        list1 = [1, 2, 3]
        list2 = ["a", "b"]
        list3 = [True, False, True, False]
        return torch.sin(x + 1), list(
            itertools.zip_longest(list1, list2, list3, fillvalue=None)
        )

    def test_torch_size_as_dict_key(self):
        def fn(x, cached):
            if x.shape not in cached:
                cached[x.shape] = x
            return x + cached[x.shape]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x1 = torch.randn(2, 3)
        x2 = torch.randn(2, 3)
        cached = {}
        ref1 = fn(x1, cached)
        ref2 = fn(x2, cached)
        cached = {}
        res1 = opt_fn(x1, cached)
        res2 = opt_fn(x2, cached)
        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)

    def test_dict_param_keys(self):
        a_param = torch.nn.Parameter(torch.ones([4, 4]))

        def fn(a):
            tmp = {"a": a, a_param: 3}
            return tmp["a"] + tmp[a_param]

        test = make_test(fn)
        test(self)

    def test_dict_mutable_map(self):
        from collections.abc import MutableMapping

        class TensorDict(MutableMapping):
            def __init__(self) -> None:
                self._dict = {}

            def add(self, key, value):
                self._dict[key] = value

            def items(self):
                return self._dict.items()

            def __delitem__(self, key):
                del self._dict[key]

            def __getitem__(self, key):
                return self._dict[key]

            def __iter__(self):
                return iter(self._dict)

            def __len__(self):
                return len(self._dict)

            def __setitem__(self, key, value):
                self._dict[key] = value

        tensor_dict = TensorDict()
        tensor_dict.add("a", torch.ones(4) * 2)

        def fn(x):
            copy_tensordict = dict(tensor_dict)
            return x * copy_tensordict["a"]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)

        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_unpack_mutable_map(self):
        from collections.abc import MutableMapping

        class TensorDict(MutableMapping):
            def __init__(self) -> None:
                self._dict = {}

            def add(self, key, value):
                self._dict[key] = value

            def items(self):
                return self._dict.items()

            def __delitem__(self, key):
                del self._dict[key]

            def __getitem__(self, key):
                return self._dict[key]

            def __iter__(self):
                return iter(self._dict)

            def __len__(self):
                return len(self._dict)

            def __setitem__(self, key, value):
                self._dict[key] = value

        tensor_dict = TensorDict()
        tensor_dict.add("a", torch.ones(4) * 2)

        def gn(x, a=1):
            return x * a

        def fn(x):
            return gn(x, **tensor_dict)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = torch.randn(4)

        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def _test_default_dict_helper(self, factory):
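        # Shared helper: `factory` is the default_factory handed to
        # collections.defaultdict by the test variants below.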
        dd = collections.defaultdict(factory)
        param = torch.nn.Parameter(torch.ones([2, 2]))

        def fn(x):
            dd["a"] = x + 1
            dd[param] = 123
            dd["c"] = x * 2
            return dd["b"], dd

        x = torch.randn(10, 10)
        ref = fn(x)
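        # optimize_assert is roughly torch.compile with fullgraph semantics:
        # a graph break raises instead of falling back to eager.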
        opt_fn = torch._dynamo.optimize_assert("eager")(fn)
        res = opt_fn(x)

        self.assertTrue(same(ref[0], res[0]))
        self.assertTrue(same(ref[1]["a"], res[1]["a"]))
        self.assertTrue(same(ref[1]["c"], res[1]["c"]))
        self.assertTrue(same(ref[1][param], res[1][param]))

    def test_default_dict_dict(self):
        self._test_default_dict_helper(dict)

    def test_default_dict_list(self):
        self._test_default_dict_helper(list)

    def test_default_dict_tuple(self):
        self._test_default_dict_helper(tuple)

    def test_default_dict_set(self):
        self._test_default_dict_helper(set)

    def test_default_dict_lambda(self):
        self._test_default_dict_helper(lambda: dict())  # noqa: C408

    def test_default_dict_closure(self):
        def factory():
            return dict()  # noqa: C408

        self._test_default_dict_helper(factory)

    def test_class_dict(self):
        class A:
            x = 4
            y = 5

            def __init__(self) -> None:
                self.a = 6

        a = A()

        def fn(x):
            if "x" in type(a).__dict__:
                return x + 1
            return x + 2

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(fn(x), opt_fn(x))

    def test_default_dict_constr(self):
        param = torch.nn.Parameter(torch.ones([2, 2]))

        def fn(x):
            dd = collections.defaultdict(lambda: dict())  # noqa: C408
            dd["a"] = x + 1
            dd[param] = 123
            dd["c"] = x * 2
            dd.update({"b": x * 3})
            dd.update([["d", x - 2], ("e", x + 2)])
            dd.update(zip("ab", [x + 3, x + 4]))
            return dd["b"], dd

        x = torch.randn(10, 10)
        ref = fn(x)
        opt_fn = torch._dynamo.optimize_assert("eager")(fn)
        res = opt_fn(x)

        self.assertTrue(same(ref[0], res[0]))
        self.assertTrue(same(ref[1]["a"], res[1]["a"]))
        self.assertTrue(same(ref[1]["b"], res[1]["b"]))
        self.assertTrue(same(ref[1]["c"], res[1]["c"]))
        self.assertTrue(same(ref[1]["d"], res[1]["d"]))
        self.assertTrue(same(ref[1]["e"], res[1]["e"]))
        self.assertTrue(same(ref[1][param], res[1][param]))

    def test_dict_tuple_lazy_guard(self):
        @torch.compile(backend="eager")
        def fn(x, y):
            return torch.sin(x) * y[1]

        fn(torch.randn(3), {1: 1, 2: 2})
        # Changing the value of another key should not cause recompilation
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(3), {1: 1, 2: 3})

        fn(torch.randn(3), (1, 2, 3))
        # Changing the value of index 0, 2 (not 1) should not cause recompilation
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(3), (11, 2, 13))

    @make_test
    def test_call_dict1(x):
        d1 = dict()  # noqa: C408
        d1["x"] = x + 1
        d2 = collections.OrderedDict()
        d2["x"] = x + 2
        return d1["x"] + d2["x"] + 1

    @make_test
    def test_call_dict2(x):
        d1 = dict()  # noqa: C408
        d1["x"] = x
        d2 = collections.OrderedDict(d1)
        if isinstance(d2, collections.OrderedDict):
            return x + 1
        else:
            return x - 1

    @make_test
    def test_call_dict3(x):
        my_list = [("a", x), ("b", x + 1), ("c", x + 2)]
        d1 = dict(my_list)
        d1["a"] = x + 10
        d2 = collections.OrderedDict(my_list)
        d2["c"] = x + 20
        return d1["a"] + d2["c"] + 1

    @make_test
    def test_call_dict4(x):
        my_list = (("a", x), ("b", x + 1), ("c", x + 2))
        d1 = dict(my_list)
        d1["a"] = x + 10
        d2 = collections.OrderedDict(my_list)
        d2["c"] = x + 20
        return d1["a"] + d2["c"] + 1

    @make_test
    def test_call_dict5(x):
        my_list = iter([("a", x), ("b", x + 1), ("c", x + 2)])
        d1 = dict(my_list)
        d1["a"] = x + 10
        d2 = collections.OrderedDict(my_list)
        d2["c"] = x + 20
        return d1["a"] + d2["c"] + 1

    @make_test
    def test_dict_fromkeys(x, y):
        lst = ["a", "b"]
        d = dict.fromkeys(lst)
        d1 = dict.fromkeys(d, x + 1)
        d2 = collections.defaultdict.fromkeys(iter(d1), x - 2)
        d3 = collections.OrderedDict.fromkeys(tuple(lst), value=y)
        return d1["a"] * d2["b"] + d2["a"] + d1["b"] + d3["a"] + d3["b"] + 1

    @make_test
    def test_dict_copy(x):
        my_list = [("a", x), ("b", x + 1), ("c", x + 2)]
        d1 = dict(my_list)
        d1["a"] = x + 10
        d2 = d1.copy()
        d2["a"] = x - 5
        d2["b"] = x + 3
        d3 = collections.OrderedDict(my_list)
        d3["c"] = x + 20
        d4 = d3.copy()
        d4["c"] = x - 10
        return d1["a"] * d2["a"] + d2["b"] + d3["c"] * d4["c"] + 1

    @make_test
    def test_dict_update(x, y, z):
        d = {"a": x, "b": y}
        d.update({"a": y - 1})
        d.update([("b", z + 1), ["c", z]])
        d.update(zip("ab", [z + 3, y + 2]))

        od = collections.OrderedDict(a=x * 3, b=y + 2)
        od.update({"a": y + 5})
        od.update([["b", z + 6], ("c", z - 7)])
        od.update(zip("ab", [z - 3, x + 2]))
        return d["a"] * od["a"] + od["c"] + d["b"] + od["b"] * d["c"]

    @make_test
    def test_min_max(a, b):
        c = a + b
        a = a.sum()
        b = b.sum()
        a = min(max(a, 0), 1)
        b = max(0, min(1, b))
        return max(a, b) - min(a, b) + c

    @make_test
    def test_symbool_to_int(x):
        # this is roughly the pattern found in einops.unpack()
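        # (summing the SymBool results of `s == -1` has to coerce them to
        # ints before the comparison with 0)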
        if sum(s == -1 for s in x.size()) == 0:
            return x + 1
        else:
            return x - 1

    @make_test
    def test_map_sum(a, b, c, d):
        return sum(map(lambda x: x + 1, [a, b, c, d]))

    @make_test
    def test_sum(a, b, c, d):
        return sum([a, b, c, d])

    @make_test
    def test_sum_with_start_arg(a, b, c, d):
        return sum([b, c, d], a)

    @make_test
    def test_sum_with_start_kwarg(a, b, c, d):
        return sum([b, c, d], start=a)

    @make_test(expected_frame_count=0)
    def test_sum_shortcut():
        return sum([0, 1.0, 2, 3.0])

    @make_test(expected_frame_count=0)
    def test_sum_shortcut_with_start_arg():
        return sum([0, 1.0, 2, 3.0], -10)

    @make_test(expected_frame_count=0)
    def test_sum_shortcut_with_start_kwarg():
        return sum([0, 1.0, 2, 3.0], start=-10)

    @make_test
    def test_reduce(a, b, c, d):
        return functools.reduce(operator.add, [a, b, c, d])

    @make_test
    def test_reduce_with_initial(a, b, c, d):
        return functools.reduce(operator.add, [b, c, d], a)

    @make_test(expected_frame_count=0)
    def test_reduce_with_single(x):
        return functools.reduce(lambda a, b: (a, b), [x])

    @make_test(expected_frame_count=0)
    def test_reduce_with_single_with_initial(x, y):
        return functools.reduce(lambda a, b: (a, b), [y], x)

    @make_test(expected_frame_count=0)
    def test_reduce_with_none_initial(x):
        return functools.reduce(lambda a, b: (a, b), [x], None)

    @make_test
    def test_tuple_contains(a, b):
        v1 = "a"
        v2 = "b"
        v3 = "c"
        vals1 = (v1, v2, v3)
        vals2 = ("d", "e", "f")
        if "a" in vals1 and "b" not in vals2:
            return a + b
        return a - b

    @unittest.skipIf(
        sys.version_info < (3, 9),
        "SET_UPDATE was added at Python 3.9",
    )
    @make_test
    def test_set_update_bytecode(x):
        # This produces bytecode SET_UPDATE since python 3.9
        var = {"apple", "banana", "cherry"}
        if isinstance(var, set):
            return x + 1
        else:
            return x - 1

    @unittest.skipIf(
        sys.version_info < (3, 9),
        "SET_UPDATE was added at Python 3.9",
    )
    @make_test
    def test_set_update_list_with_duplicated_items(x):
        list1 = ["apple", "banana", "apple"]
        list2 = ["orange", "banana"]
        if len({*list1, *list2}) == 3:
            return x + 1
        else:
            return x - 1

    @make_test
    def test_set_contains(a, b):
        vals = set(["a", "b", "c"])
        if "a" in vals:
            x = a + b
        else:
            x = a - b
        if "d" in vals:
            y = a + b
        else:
            y = a - b
        return x, y

    def test_set_isdisjoint(self):
        x = {"apple", "banana", "cherry"}
        y = {"google", "microsoft", "apple"}

        def fn(a):
            if x.isdisjoint(y):
                return a + 1
            else:
                return a - 1

        test = make_test(fn)
        test(self)

    @make_test
    def test_set_intersection(a, b):
        set1 = {"apple", "banana", "cherry"}
        set2 = {"google", "microsoft", "apple"}
        intersection_set = set1.intersection(set2)
        if "apple" in intersection_set:
            x = a + b
        else:
            x = a - b
        if "banana" in intersection_set:
            y = a + b
        else:
            y = a - b
        return x, y

    @make_test
    def test_set_union(a, b):
        set1 = {"apple", "banana", "cherry"}
        set2 = {"google", "microsoft", "apple"}
        union_set = set1.union(set2)
        if "apple" in union_set:
            x = a + b
        else:
            x = a - b
        if "banana" in union_set:
            y = a + b
        else:
            y = a - b
        return x, y

    @make_test
    def test_set_difference(a, b):
        set1 = {"apple", "banana", "cherry"}
        set2 = {"google", "microsoft", "apple"}
        difference_set = set1.difference(set2)
        if "apple" in difference_set:
            x = a + b
        else:
            x = a - b
        if "banana" in difference_set:
            y = a + b
        else:
            y = a - b
        return x, y

    def test_set_keys_view(self):
        from collections.abc import KeysView

        class StringKeys(KeysView):
            def __init__(self, keys):
                self.keys = keys

            def __getitem__(self, key):
                return self.keys.__getitem__(key)

            def __iter__(self):
                yield from self.keys

            def __repr__(self):
                return f"{type(self).__name__}({self.keys})"

            def __len__(self):
                return len(self.keys)

            def __contains__(self, item):
                return self.keys.__contains__(item)

        a = StringKeys([1, 2, 3, 3])

        def fn(x):
            set_a = set(a)
            return len(set_a) * x

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.rand(4)
        self.assertEqual(fn(x), opt_fn(x))

    def test_constant_set(self):
        s = set([1, 2])

        def fn(x):
            return torch.cos(x) * len(s)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = torch.rand(4)
        self.assertEqual(fn(x), opt_fn(x))

        # This should cause recompilation
        s.add(3)
        self.assertEqual(fn(x), opt_fn(x))

    def test_set_add(self):
        s = set([1, 2])

        def fn(x):
            s.add(3)
            return torch.cos(x) * len(x)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = torch.rand(4)
        self.assertEqual(fn(x), opt_fn(x))
        self.assertEqual(len(s), 3)

    @make_test
    def test_tuple_iadd(a, b):
        output = (a, b)
        output += (a + b, a - b)
        return output

    @make_test
    def test_unpack_ex1(x):
        output = (x, x + 1, x + 2, x + 3)
        a, b, *cd = output
        return a - b / cd[0]

    @make_test
    def test_unpack_ex2(x):
        output = (x, x + 1, x + 2, x + 3)
        *ab, c, d = output
        return c - d / ab[0]

    @make_test
    def test_unpack_ex3(x):
        output = (x, x + 1, x + 2, x + 3)
        a, *bc, d = output
        return a - d / bc[0]

    @make_test
    def test_const_tuple_add1(x):
        output = (x, x + 1, x + 2, x + 3)
        output = () + output + ()
        return output[2] + output[3]

    @make_test
    def test_const_tuple_add2(x):
        output = (x, x + 1, x + 2, x + 3)
        output = (None,) + output + (None,)
        return output[2] + output[3]

    @make_test
    def test_list_truth(a, b):
        tmp = [1, 2, 3]
        if tmp:
            return a + b
        else:
            return a - b

    @make_test
    def test_list_reversed(a, b):
        tmp = [a + 1, a + 2, a + 3]
        return a + b + next(iter(reversed(tmp)))

    @make_test
    def test_list_sorted1(x):
        tmp = [1, 10, 3, 0]
        return x + 1, sorted(tmp), sorted(tmp, reverse=True)

    @make_test
    def test_list_sorted2(x):
        y = [
            ("john", "A", 8),
            ("jane", "B", 5),
            ("dave", "B", 10),
        ]
        return (
            x + 1,
            sorted(y),
            sorted(y, key=lambda student: student[2]),
            sorted(y, key=lambda student: student[2], reverse=True),
        )

    @make_test
    def test_tuple_sorted(x):
        tmp = (1, 10, 3, 0)
        return x + 1, sorted(tmp), sorted(tmp, reverse=True)

    @make_test
    def test_dict_sorted(x):
        tmp = {1: "D", 10: "B", 3: "E", 0: "F"}
        return x + 1, sorted(tmp), sorted(tmp, reverse=True)

    def test_dict_hasattr(self):
        def fn(x):
            if hasattr(x, "to"):
                return x.to("cpu")
            if hasattr(x, "items"):
                return torch.cos(x["a"])
            return x

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = dict(a=torch.randn(3))
        self.assertEqual(fn(x), opt_fn(x))

        x = torch.randn(4)
        self.assertEqual(fn(x), opt_fn(x))

    @make_test
    def test_list_clear(a, b):
        tmp = [a + 1, a + 2]
        tmp.clear()
        tmp.append(a + b)
        return tmp

    @make_test
    def test_not_list(a):
        return not [a + 1]

    @make_test
    def test_islice_chain(a, b):
        tmp1 = [a + 1, a + 2]
        tmp2 = [a + 3, a + 4]
        a, b = list(itertools.islice(itertools.chain(tmp1, tmp2), 1, 3))
        c = next(itertools.islice(tmp1, 1, None))
        return a - b / c

    @make_test
    def test_namedtuple(a, b):
        mytuple = collections.namedtuple("mytuple", ["x", "y", "xy"])
        tmp = mytuple(a, b, a + b)
        return mytuple(tmp.x, tmp[1], tmp.xy + b)

    @make_test
    def test_namedtuple_defaults(a, b):
        mytuple = collections.namedtuple(
            "mytuple", ["x", "y", "xy"], defaults=(None, 1, None)
        )
        tmp = mytuple(a, xy=b)
        return mytuple(tmp.x, tmp[1], tmp.xy + b)

    class MyNamedTuple(NamedTuple):
        first: torch.Tensor
        second: torch.Tensor

        def add(self) -> torch.Tensor:
            return self.first + self.second

        @staticmethod
        def static_method() -> int:
            return 1

        @classmethod
        def class_method(cls) -> str:
            return cls.__name__

    @make_test
    def test_namedtuple_user_methods(a, b):
        mytuple = FunctionTests.MyNamedTuple(a, b)
        return mytuple.add(), mytuple.static_method(), mytuple.class_method()

    @make_test
    def test_namedtuple_hasattr(a, b):
        mytuple = FunctionTests.MyNamedTuple(a, b)

        def isinstance_namedtuple(obj) -> bool:
            return (
                isinstance(obj, tuple)
                and hasattr(obj, "_asdict")
                and hasattr(obj, "_fields")
            )

        if isinstance_namedtuple(mytuple):
            return a + b
        else:
            return a - b

    @make_test
    def test_torch_size_hasattr(x):
        if hasattr(x.shape, "_fields"):
            return x + 1
        else:
            return x - 1

    @make_test
    def test_is_quantized(a, b):
        if not a.is_quantized:
            return a + b

    @make_test
    def test_fstrings1(a, b):
        x = 1.229
        tmp = f"{x:.2f} bar"
        if tmp.startswith("1.23"):
            return a + b

    @make_test
    def test_fstrings2(x):
        tmp = f"{x.shape[0]} bar"
        if tmp.startswith("10"):
            return x + 1

    @make_test
    def test_fstrings3(x):
        tmp = f"{x.__class__.__name__} foo"
        if tmp.startswith("Tensor"):
            return x + 1

    @make_test
    def test_fstrings4(x):
        tmp = f"{x.shape[0]} bar"
        if "10" in tmp:
            return x + 1

    @make_test
    def test_fstrings5(x):
        tmp = f"{x.shape[0]} bar"
        if "10" in (tmp + "haha"):
            return x + 1

    @make_test
    def test_fstrings6(x):
        tmp = f"{x.shape[0] + x.shape[1]}"
        if "20" in tmp:
            return x + 1

    @make_test
    def test_tensor_new_with_size(x):
        y = torch.rand(5, 8)
        z = x.new(y.size())
        assert z.size() == y.size()

    @make_test
    def test_tensor_new_with_shape(x):
        y = torch.rand(5, 8)
        z = x.new(y.shape)
        assert z.size() == y.size()

    @make_test
    def test_jit_annotate(x):
        y = torch.jit.annotate(Any, x + 1)
        return y + 2

    @make_test
    def test_is_contiguous_memory_format(tensor):
        if torch.jit.is_scripting():
            return None
        elif tensor.is_contiguous(memory_format=torch.contiguous_format):
            return tensor + 1

    def test_is_contiguous_frame_counts(self):
        data = [
            torch.rand(10),
            torch.rand(2, 3, 32, 32),
            torch.rand(2, 3, 32, 32).contiguous(memory_format=torch.channels_last),
            torch.rand(10)[::2],
            torch.rand(12),
            torch.rand(2, 3, 24, 24).contiguous(memory_format=torch.channels_last),
            torch.rand(50)[::2],
            torch.rand(2, 3, 32, 32)[:, :, 2:-2, 3:-3],
        ]
        # dynamo should recompile for all inputs in static shapes mode
        expected_frame_counts_static = [1, 2, 3, 4, 5, 6, 7, 8]
        # dynamo should recompile for items 0, 1, 2, 6 in dynamic shapes mode
        expected_frame_counts_dynamic = [1, 2, 3, 4, 4, 4, 4, 5]
        expected_frame_counts = ifdynstaticdefault(
            expected_frame_counts_static, expected_frame_counts_dynamic
        )
        dynamic = ifdynstaticdefault(False, True)

        def func(x):
            if x.is_contiguous():
                return x + 1
            elif x.is_contiguous(memory_format=torch.channels_last):
                return x + 2
            else:
                return x + 3

        cnt = torch._dynamo.testing.CompileCounter()
        cfunc = torch._dynamo.optimize_assert(cnt, dynamic=dynamic)(func)

        assert cnt.frame_count == 0
        for i, x in enumerate(data):
            expected = func(x)
            output = cfunc(x)
            self.assertTrue(same(output, expected))
            assert cnt.frame_count == expected_frame_counts[i]
1954
1955    @make_test
1956    def test_list_slice_assignment(x):
1957        m = [1, 2, 3, 4]
1958        m[1:] = [6] * (len(m) - 1)
1959        return x + 1
1960
1961    @make_test
1962    def test_distributed_is_available(x):
1963        if torch.distributed.is_available():
1964            return x + 1
1965        else:
1966            return x - 1
1967
1968    @unittest.skipIf(
1969        not torch.distributed.is_available(), "requires distributed package"
1970    )
1971    @make_test
1972    def test_distributed_is_initialized(x):
1973        if torch.distributed.is_initialized():
1974            return x + 1
1975        else:
1976            return x - 1
1977
1978    @disable_translation_validation_if_dynamic_shapes
1979    @make_test
1980    def test_torch_distributions_functions(x):
1981        normal = torch.distributions.Normal(x, torch.tensor(1))
1982        independent = torch.distributions.Independent(normal, 1)
1983        return independent.log_prob(x)
1984
1985    @make_test
1986    def test_context_wrapping_nested_functions_no_closure(x):
1987        @torch.no_grad()
1988        def augment(x: torch.Tensor) -> torch.Tensor:
1989            return (x + 1) * 2
1990
1991        return augment(x)
1992
1993    # # These tests exercise the pattern-matching syntax
1994    # # ("match ... case ...") added in Python 3.10.
1995    # # Uncomment them if you run on Python 3.10+.
1996    # @make_test
1997    # def test_match_sequence(a):
1998    #     point = (5, 8)
1999    #     match point:
2000    #         case (0, 0):
2001    #             return a
2002    #         case (0, y):
2003    #             return a - y
2004    #         case (x, 0):
2005    #             return a + x
2006    #         case (x, y):
2007    #             return a + x - y
2008
2009    # @make_test
2010    # def test_match_mapping_and_match_keys(x):
2011    #     param = {"a": 0.5}
2012    #     match param:
2013    #         case {"a": param}:
2014    #             return x * param
2015    #         case {"b": param}:
2016    #             return x / param
2017
2018    def test_math_radians(self):
2019        def func(x, a):
2020            return x + math.radians(a)
2021
2022        cnt = torch._dynamo.testing.CompileCounter()
2023        cfunc = torch._dynamo.optimize_assert(cnt)(func)
2024
2025        assert cnt.frame_count == 0
2026        x = torch.rand(10)
2027        expected = func(x, 12)
2028        output = cfunc(x, 12)
2029        self.assertTrue(same(output, expected))
2030        assert cnt.frame_count == 1
2031
2032    @make_test
2033    def test_numpy_meshgrid(x, y):
2034        r1, r2 = np.meshgrid(x.numpy(), y.numpy())
2035        return torch.from_numpy(r1), torch.from_numpy(r2)
2036
2037    @make_test
2038    def test_torch_from_numpy(x):
2039        a = x.numpy()
2040        b = torch.from_numpy(a)
2041        if b.size(0) == 1:
2042            return torch.tensor(True)
2043        else:
2044            return torch.tensor(False)
2045
2046    @make_test
2047    def test_numpy_size(x):
2048        a = x.numpy()
2049        return a.size
2050
2051    @make_test
2052    def test_numpy_attributes(x):
2053        a = x.numpy()
2054        return (
2055            a.itemsize,
2056            a.strides,
2057            a.shape,
2058            a.ndim,
2059            a.size,
2060            torch.from_numpy(a.T),
2061            torch.from_numpy(a.real),
2062            torch.from_numpy(a.imag),
2063        )
2064
2065    @make_test
2066    def test_mean_sum_np(x: torch.Tensor):
2067        x_mean = np.mean(x.numpy(), 1)
2068        x_sum = np.sum(x_mean)
2069        x_sum_array = np.asarray(x_sum)
2070        return torch.from_numpy(x_sum_array)
2071
2072    @make_test
2073    def test_return_numpy_ndarray(x):
2074        a = x.numpy()
2075        return a.T
2076
2077    @make_test
2078    def test_return_multiple_numpy_ndarray(x):
2079        a = x.numpy()
2080        return a.T, a.imag, a.real
2081
2082    @make_test
2083    def test_ndarray_method(x):
2084        a = x.numpy()
2085        return a.copy()
2086
2087    @make_test
2088    def test_ndarray_transpose(x):
2089        a = x.numpy()
2090        return a.transpose(0, 1)
2091
2092    @make_test
2093    def test_ndarray_reshape(x):
2094        a = x.numpy()
2095        return a.reshape([1, a.size])
2096
2097    @make_test
2098    def test_ndarray_methods_returning_scalar(x):
2099        a = x.numpy()
2100        return a.max(axis=0), a.all(axis=0)
2101
2102    @make_test
2103    def test_ndarray_builtin_functions(x):
2104        a = x.numpy()
2105        return a + a, a - a
2106
2107    @make_test
2108    def test_numpy_dtype_argument_to_function(x):
2109        return np.ones_like(x, dtype=np.float64)
2110
2111    @make_test
2112    def test_numpy_dtype_call_in_function(x):
2113        dt = np.dtype("float")
2114        return np.full_like(x, 2.4, dtype=dt)
2115
2116    @make_test
2117    def test_numpy_linalg(x):
2118        return np.linalg.norm(x.numpy(), axis=0)
2119
2120    @make_test
2121    def test_numpy_fft(x):
2122        return np.fft.fftshift(x.numpy())
2123
2124    @make_test
2125    def test_numpy_random():
2126        x = np.random.randn(2, 2)
2127        return x - x
2128
2129    @make_test
2130    def test_partials_torch_op_kwarg(x):
2131        par_mul = functools.partial(torch.mul, other=torch.ones(10, 10))
2132        return par_mul(x)
2133
2134    @make_test
2135    def test_partials_torch_op_arg(x):
2136        par_mul = functools.partial(torch.mul, torch.ones(10, 10))
2137        return par_mul(x)
2138
2139    @make_test
2140    def test_partials_udf_arg(x):
2141        par_mul = functools.partial(udf_mul, torch.ones(10, 10))
2142        return par_mul(x)
2143
2144    @make_test
2145    def test_list_add_then_mutate(x):
2146        my_list = [1, x]
2147        y = x / 4.0
2148        my_list = my_list + [x / 2.0, 4]
2149        my_list.append(y)
2150        return sum(my_list)
2151
2152    @make_test
2153    def test_list_expand_lhs(x):
2154        return sum(4 * [x])
2155
2156    @make_test
2157    def test_in_not_in(x):
2158        mylist = [1, 2, 3, 4, 5, x]
2159        myotherlist = [1, 2, 3, 4, 5]
2160        assert 3 in mylist
2161        assert 6 not in myotherlist
2162        return sum(mylist)
2163
2164    @make_test
2165    def test_are_functorch_transforms_active(x):
2166        if torch._C._are_functorch_transforms_active():
2167            return x + 1
2168        else:
2169            return x - 1
2170
2171    @make_test
2172    def test_partials_udf_kwarg(x):
2173        par_mul = functools.partial(udf_mul, y=torch.ones(10, 10))
2174        return par_mul(x)
2175
2176    @make_test
2177    def test_partials_udf_kwarg_module(x, y):
2178        par_mod = functools.partial(udf_module, mod=SmallNN())
2179        return par_mod(x=x, y=y)
2180
2181    @make_test
2182    def test_partials_udf_kwarg_method(x, y):
2183        par_mod = functools.partial(udf_module, mod=SmallNN().forward)
2184        return par_mod(x=x, y=y)
2185
2186    @make_test
2187    def test_partials_lambda(x):
2188        multiply = lambda x, y: x * y
2189        triple = functools.partial(multiply, y=3)
2190        return triple(x)
2191
2192    @unittest.skipUnless(torch.distributed.is_available(), "requires torch.distributed")
2193    @make_test
2194    def test_flat_param_same_storage_size(x, y):
2195        import torch.distributed.fsdp._flat_param as flat_param
2196
2197        if flat_param._same_storage_size(x, 100):
2198            x = x + 1
2199        else:
2200            x = x - 1
2201        if flat_param._same_storage_size(y, 123):
2202            y = y + 1
2203        else:
2204            y = y - 1
2205        return x, y
2206
2207    @parametrize(
2208        "attr",
2209        (
2210            # attrs for which hasattr(partial_obj, attr) is True
2211            "__subclasshook__",
2212            "__lt__",
2213            "__hash__",
2214            "__ge__",
2215            "__le__",
2216            "__gt__",
2217            "__dict__",
2218            "__getattribute__",
2219            "__setattr__",
2220            "__doc__",
2221            "__repr__",
2222            "__dir__",
2223            "__init__",
2224            "__new__",
2225            "__class__",
2226            "__eq__",
2227            "__delattr__",
2228            "__reduce__",
2229            "__module__",
2230            "__format__",
2231            "__str__",
2232            "__sizeof__",
2233            "__ne__",
2234            "__call__",
2235            "__reduce_ex__",
2236            "__init_subclass__",
2237            "args",
2238            "keywords",
2239            "func",
2240            # attrs for which hasattr(partial_obj, attr) is False
2241            "__code__",
2242            "__kwdefaults__",
2243            "__defaults__",
2244            "__name__",
2245            "__annotations__",
2246            "__get__",
2247            "__builtins__",
2248            "__qualname__",
2249            "__globals__",
2250            "__closure__",
2251        ),
2252    )
2253    def test_partials_hasattr(self, attr):
2254        def fn(t):
2255            f = lambda x, y: torch.sin(x) + torch.cos(y)
2256            p = functools.partial(f, y=t)
2257            if hasattr(p, attr):
2258                return p(t)
2259            else:
2260                return torch.zeros_like(t)
2261
2262        t = torch.randn(3, 4)
2263        counter = torch._dynamo.testing.CompileCounter()
2264        opt_fn = torch.compile(fullgraph=True, backend=counter)(fn)
2265        self.assertEqual(opt_fn(t), fn(t))
2266        self.assertGreater(counter.frame_count, 0)
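
    # A hedged sketch of why the attrs above split as they do: a
    # functools.partial exposes func/args/keywords plus the generic object
    # protocol, while function-only attributes such as __code__ are absent.
    # `_partial_attr_example` is a hypothetical helper, not collected by
    # unittest.
    def _partial_attr_example(self):
        p = functools.partial(torch.mul, other=torch.ones(2))
        assert hasattr(p, "func") and hasattr(p, "keywords")
        assert not hasattr(p, "__code__")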
2267
2268    @unittest.expectedFailure
2269    def test_partials_hasattr_set_attr(self):
2270        def fn(t):
2271            f = lambda x, y: torch.sin(x) + torch.cos(y)
2272            p = functools.partial(f, y=t)
2273            p.__name__ = "test"
2274            if hasattr(p, "__name__"):
2275                return p(t)
2276            else:
2277                return torch.zeros_like(t)
2278
2279        t = torch.randn(3, 4)
2280        counter = torch._dynamo.testing.CompileCounter()
2281        opt_fn = torch.compile(fullgraph=True, backend=counter)(fn)
2282        self.assertEqual(opt_fn(t), fn(t))
2283
2284    def test_filter(self):
2285        def fn(inputs):
2286            out = inputs[0]
2287            for inp in filter(lambda x: x.requires_grad, inputs):
2288                out = out * inp
2289            return out
2290
2291        input1 = torch.arange(2, dtype=torch.bfloat16)
2292        input2 = torch.arange(2, dtype=torch.bfloat16).requires_grad_(True)
2293        inputs = [input1, input2]
2294
2295        opt_fn = torch.compile(fullgraph=True)(fn)
2296        self.assertEqual(opt_fn(inputs), fn(inputs))
2297
2298    def test_filter_fallback(self):
2299        def fn(inputs):
2300            out = inputs[0]
2301            for inp in filter(lambda x: x[0] == 1, inputs):
2302                out = out * inp
2303            return out
2304
2305        input1 = torch.ones(2, dtype=torch.bfloat16)
2306        input2 = torch.arange(2, dtype=torch.bfloat16)
2307        inputs = [input1, input2]
2308
2309        opt_fn = torch.compile()(fn)
2310        self.assertEqual(opt_fn(inputs), fn(inputs))
2311
2312        torch._dynamo.reset()
2313
2314        with self.assertRaises(torch._dynamo.exc.Unsupported):
2315            opt_fn = torch.compile(fullgraph=True)(fn)
2316            opt_fn(inputs)
2317
2318    def test_pow_int(self):
2319        def fn(a, b):
2320            return torch.pow(a, b)
2321
2322        x = torch.ones(2, 2)
2323        opt_fn = torch.compile(fullgraph=True, backend="eager", dynamic=True)(fn)
2324        self.assertEqual(opt_fn(x, 2), fn(x, 2))
2325
2326    def test_tensor_size_indexed_by_symint(self):
2327        def fn(x, y):
2328            index = x.shape[-1]
2329            return x + y.shape[index]
2330
2331        x = torch.rand(10, 2)
2332        y = torch.rand(10, 8, 6)
2333        opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
2334        self.assertEqual(opt_fn(x, y), fn(x, y))
2335
2336    def test_partials_as_input_partials_lambda(self):
2337        def fn(f0, f1, x):
2338            return f0(x) * f1(x)
2339
2340        multiply = lambda x, y: x * y
2341        lambda0 = functools.partial(multiply, y=3)
2342        lambda1 = functools.partial(multiply, y=2)
2343
2344        cnts = torch._dynamo.testing.CompileCounter()
2345        torch._dynamo.optimize(cnts, nopython=True)(fn)(
2346            lambda0, lambda1, torch.randn(2, 2)
2347        )
2348        self.assertEqual(cnts.frame_count, 1)
2349
2350    def test_partials_as_input_partials_mod(self):
2351        def fn(f0, f1, x):
2352            return f0(x) * f1(x)
2353
2354        lambda0 = functools.partial(SmallNN(), y=torch.randn(2, 2))
2355        lambda1 = functools.partial(SmallNN(), y=torch.randn(2, 2))
2356
2357        cnts = torch._dynamo.testing.CompileCounter()
2358        x = torch.randn(2, 2)
2359        dynamo_result = torch._dynamo.optimize(cnts, nopython=True)(fn)(
2360            lambda0, lambda1, x
2361        )
2362        self.assertEqual(cnts.frame_count, 1)
2363
2364        eager_result = fn(lambda0, lambda1, x)
2365        self.assertEqual(eager_result, dynamo_result)
2366
2367    def test_partials_as_input_UDF(self):
2368        def fn(f0, f1, x):
2369            return f0(x) * f1(x)
2370
2371        lambda0 = functools.partial(udf_mul, y=torch.randn(2, 2))
2372        lambda1 = functools.partial(udf_mul, y=torch.randn(2, 2))
2373
2374        cnts = torch._dynamo.testing.CompileCounter()
2375        x = torch.randn(2, 2)
2376        dynamo_result = torch._dynamo.optimize(cnts, nopython=True)(fn)(
2377            lambda0, lambda1, x
2378        )
2379        self.assertEqual(cnts.frame_count, 1)
2380
2381        eager_result = fn(lambda0, lambda1, x)
2382        self.assertEqual(eager_result, dynamo_result)
2383
2384    def test_partials_graph_break_reconstruct(self):
2385        def fn(udf_mul_0, udf_mul_1, x):
2386            lambda0 = functools.partial(udf_mul_0, y=x)
2387            lambda1 = functools.partial(udf_mul_1, y=x)
2388
2389            print("break")
2390            return torch.mul(lambda0(x), lambda1(x))
2391
2392        backend = EagerAndRecordGraphs()
2393        cnts = CompileCounterWithBackend(backend)
2394        x = torch.randn(2, 2)
2395        dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, udf_mul, x)
2396
2397        eager_result = fn(udf_mul, udf_mul, x)
2398        gm = backend.graphs[0]
2399        self.assertEqual(eager_result, dynamo_result)
2400        if torch._dynamo.config.assume_static_by_default:
2401            self.assertExpectedInline(
2402                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
2403                """\
2404class GraphModule(torch.nn.Module):
2405    def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
2406        l_lambda0_keywords_y_ = L_lambda0_keywords_y_
2407
2408        mul: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
2409        mul_1: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_;  l_lambda0_keywords_y_ = None
2410
2411        mul_2: "f32[2, 2]" = torch.mul(mul, mul_1);  mul = mul_1 = None
2412        return (mul_2,)
2413""",
2414            )
2415        else:
2416            self.assertExpectedInline(
2417                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
2418                """\
2419class GraphModule(torch.nn.Module):
2420    def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
2421        l_lambda0_keywords_y_ = L_lambda0_keywords_y_
2422
2423        mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
2424        mul_1: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_;  l_lambda0_keywords_y_ = None
2425
2426        mul_2: "f32[s0, s0]" = torch.mul(mul, mul_1);  mul = mul_1 = None
2427        return (mul_2,)
2428""",
2429            )
2430
2431    def test_partials_graph_break_reconstruct_mix(self):
2432        def fn(udf_mul_0, udf_add_1, x):
2433            lambda0 = functools.partial(udf_mul_0, y=x)
2434            lambda1 = functools.partial(udf_add_1, x)
2435
2436            print("break")
2437            return torch.mul(lambda0(x), lambda1(x))
2438
2439        backend = EagerAndRecordGraphs()
2440        cnts = CompileCounterWithBackend(backend)
2441        x = torch.randn(2, 2)
2442        dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, udf_add, x)
2443
2444        eager_result = fn(udf_mul, udf_add, x)
2445        gm = backend.graphs[0]
2446        self.assertEqual(eager_result, dynamo_result)
2447        if torch._dynamo.config.assume_static_by_default:
2448            self.assertExpectedInline(
2449                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
2450                """\
2451class GraphModule(torch.nn.Module):
2452    def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
2453        l_lambda0_keywords_y_ = L_lambda0_keywords_y_
2454
2455        mul: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
2456
2457        add: "f32[2, 2]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_;  l_lambda0_keywords_y_ = None
2458
2459        mul_1: "f32[2, 2]" = torch.mul(mul, add);  mul = add = None
2460        return (mul_1,)
2461""",
2462            )
2463        else:
2464            self.assertExpectedInline(
2465                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
2466                """\
2467class GraphModule(torch.nn.Module):
2468    def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
2469        l_lambda0_keywords_y_ = L_lambda0_keywords_y_
2470
2471        mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
2472
2473        add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_;  l_lambda0_keywords_y_ = None
2474
2475        mul_1: "f32[s0, s0]" = torch.mul(mul, add);  mul = add = None
2476        return (mul_1,)
2477""",
2478            )
2479
2480    def test_partials_graph_break_reconstruct_mix_no_source(self):
2481        def fn(udf_mul_0, x):
2482            udf_add_1 = lambda x, y: x + y
2483
2484            lambda0 = functools.partial(udf_mul_0, y=x)
2485            lambda1 = functools.partial(udf_add_1, x)
2486
2487            print("break")
2488            return torch.mul(lambda0(x), lambda1(x))
2489
2490        backend = EagerAndRecordGraphs()
2491        cnts = CompileCounterWithBackend(backend)
2492        x = torch.randn(2, 2)
2493        dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, x)
2494
2495        eager_result = fn(udf_mul, x)
2496        gm = backend.graphs[0]
2497        self.assertEqual(eager_result, dynamo_result)
2498        if torch._dynamo.config.assume_static_by_default:
2499            self.assertExpectedInline(
2500                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
2501                """\
2502class GraphModule(torch.nn.Module):
2503    def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
2504        l_lambda0_keywords_y_ = L_lambda0_keywords_y_
2505
2506        mul: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
2507
2508        add: "f32[2, 2]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_;  l_lambda0_keywords_y_ = None
2509
2510        mul_1: "f32[2, 2]" = torch.mul(mul, add);  mul = add = None
2511        return (mul_1,)
2512""",
2513            )
2514        else:
2515            self.assertExpectedInline(
2516                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
2517                """\
2518class GraphModule(torch.nn.Module):
2519    def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
2520        l_lambda0_keywords_y_ = L_lambda0_keywords_y_
2521
2522        mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
2523
2524        add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_;  l_lambda0_keywords_y_ = None
2525
2526        mul_1: "f32[s0, s0]" = torch.mul(mul, add);  mul = add = None
2527        return (mul_1,)
2528""",
2529            )
2530
2531    def test_partials_graph_break_reconstruct_args_and_kwargs(self):
2532        def fn(udf_mul_0, x):
2533            lambda0 = functools.partial(udf_mul_0, x, 4, z=x)
2534            lambda1 = functools.partial(udf_mul_0, 4, z=x)
2535
2536            return torch.mul(lambda0(), lambda1(5))
2537
2538        backend = EagerAndRecordGraphs()
2539        cnts = CompileCounterWithBackend(backend)
2540        x = torch.randn(2, 2)
2541        dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul2, x)
2542
2543        eager_result = fn(udf_mul2, x)
2544        gm = backend.graphs[0]
2545        self.assertEqual(eager_result, dynamo_result)
2546        if torch._dynamo.config.assume_static_by_default:
2547            self.assertExpectedInline(
2548                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
2549                """\
2550class GraphModule(torch.nn.Module):
2551    def forward(self, L_x_: "f32[2, 2]"):
2552        l_x_ = L_x_
2553
2554        mul: "f32[2, 2]" = l_x_ * 4
2555        mul_1: "f32[2, 2]" = mul * l_x_;  mul = None
2556        mul_2: "f32[2, 2]" = 20 * l_x_;  l_x_ = None
2557
2558        mul_3: "f32[2, 2]" = torch.mul(mul_1, mul_2);  mul_1 = mul_2 = None
2559        return (mul_3,)
2560""",
2561            )
2562        else:
2563            self.assertExpectedInline(
2564                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
2565                """\
2566class GraphModule(torch.nn.Module):
2567    def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, s0]"):
2568        l_x_ = L_x_
2569
2570        mul: "f32[s0, s0]" = l_x_ * 4
2571        mul_1: "f32[s0, s0]" = mul * l_x_;  mul = None
2572        mul_2: "f32[s0, s0]" = 20 * l_x_;  l_x_ = None
2573
2574        mul_3: "f32[s0, s0]" = torch.mul(mul_1, mul_2);  mul_1 = mul_2 = None
2575        return (mul_3,)
2576""",
2577            )
2578
2579    def test_partials_recompilation(self):
2580        def fn(f0, f1, x):
2581            return f0(x) * f1(x)
2582
2583        lambda0 = functools.partial(udf_mul, y=torch.randn(2, 2))
2584        lambda1 = functools.partial(udf_mul, y=torch.randn(2, 2))
2585
2586        cnts = torch._dynamo.testing.CompileCounter()
2587
2588        x = torch.randn(2, 2)
2589        fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
2590        dynamo_result = fn(lambda0, lambda1, x)
2591        self.assertEqual(cnts.frame_count, 1)
2592
2593        fn(lambda1, lambda0, x)
2594        self.assertEqual(
2595            cnts.frame_count, 1
2596        )  # No recompile! Tensor and udf_mul guarded
2597
2598        lambda2 = functools.partial(udf_mul, y=torch.randn(3, 3))
2599        x = torch.randn(3, 3)
2600        fn(lambda2, lambda2, x)
2601        self.assertEqual(cnts.frame_count, 2)  # Recompile! Tensor size changed
2602
2603        multiply = lambda x, y: x * y
2604        lambda3 = functools.partial(multiply, y=torch.randn(3, 3))
2605        x = torch.randn(3, 3)
2606        fn(lambda3, lambda3, x)
2607
2608        self.assertEqual(cnts.frame_count, 3)  # Recompile! func id changed
2609
2610        def fn2(f0, f1, args):
2611            return f0(*args) * f1(*args)
2612
2613        cnts = torch._dynamo.testing.CompileCounter()
2614
2615        x = torch.randn(2, 2)
2616        fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2)
2617        dynamo_result = fn2(lambda0, lambda1, [x])
2618        self.assertEqual(cnts.frame_count, 1)  # start over
2619
2620        lambda4 = functools.partial(multiply, y=3, x=torch.randn(3, 3))
2621        fn2(lambda4, lambda4, [])
2622
2623        self.assertEqual(cnts.frame_count, 2)  # Recompile! Different kwarg keys
2624
2625        lambda5 = functools.partial(multiply, 1)
2626        x = torch.randn(3, 3)
2627        fn2(lambda5, lambda5, [x])
2628
2629        self.assertEqual(cnts.frame_count, 3)  # Recompile! Different positional args
2630
2631        lambda6 = lambda x: x + x
2632        fn2(lambda6, lambda6, [x])
2633        self.assertEqual(
2634            cnts.frame_count, 4
2635        )  # Recompile! input is no longer a functools partial
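
    # A minimal sketch (assuming the guard behavior asserted above) of the
    # happy path: partials over the same function id with same-shaped tensor
    # kwargs should share a single compiled frame. `_partial_guard_example`
    # is a hypothetical helper, not collected by unittest.
    def _partial_guard_example(self):
        cnts = torch._dynamo.testing.CompileCounter()
        call = torch._dynamo.optimize(cnts, nopython=True)(lambda f, x: f(x))
        call(functools.partial(udf_mul, y=torch.randn(2, 2)), torch.randn(2, 2))
        call(functools.partial(udf_mul, y=torch.randn(2, 2)), torch.randn(2, 2))
        return cnts.frame_count  # expected to remain 1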
2636
2637    def test_manual_seed(self):
2638        @torch.compile
2639        def foo():
2640            torch.manual_seed(3)
2641            return torch.randint(0, 5, (5,))
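
        # Reseeding inside the compiled region must take effect, so repeated
        # calls should return identical samples.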
2642
2643        self.assertEqual(foo(), foo())
2644        self.assertEqual(foo(), foo())
2645
2646    def test_partial_across_graph_break_uninvoked(self):
2647        from functools import partial
2648
2649        def bar(x, **kwargs):
2650            return x + x
2651
2652        @torch.compile(backend="eager", dynamic=True)
2653        def foo(x, i):
2654            def inner():
2655                print("this is a graph_break")
2656                return op(x)
2657
2658            op = partial(bar, dim=10)
2659            x = inner()
2660            op = partial(bar, other=10)
2661            return inner() + x
2662
2663        foo(torch.rand(1), 10)
2664
2665    def test_no_recompile_inner_function(self):
2666        def forward(inp):
2667            def g(y):
2668                return inp + y
2669
2670            print("graph break")
2671            return g(torch.rand([1]))
2672
2673        cnts = torch._dynamo.testing.CompileCounter()
2674        opt_fn = torch._dynamo.optimize(cnts)(forward)
2675
2676        input = torch.rand([2])
2677        _ = opt_fn(input)
2678        _ = opt_fn(input)
2679        _ = opt_fn(input)
2680        # Should not have recompiled
2681        self.assertEqual(cnts.frame_count, 1)
2682
2683    def test_no_recompile_inner_lambda(self):
2684        def forward(inp):
2685            g = lambda y: inp + y
2686            print("graph break")
2687            return g(torch.rand([1]))
2688
2689        cnts = torch._dynamo.testing.CompileCounter()
2690        opt_fn = torch._dynamo.optimize(cnts)(forward)
2691
2692        input = torch.rand([2])
2693        _ = opt_fn(input)
2694        _ = opt_fn(input)
2695        _ = opt_fn(input)
2696        # Should not have recompiled
2697        self.assertEqual(cnts.frame_count, 1)
2698
2699    def test_complex_closure(self):
2700        @torch.compile
2701        def forward(y):
2702            def a():
2703                def x(z):
2704                    return y + z
2705
2706                return x
2707
2708            return a()
2709
2710        input1 = torch.rand([2])
2711        input2 = torch.rand([2])
2712        res = forward(input1)(input2)
2713        self.assertTrue(same(res, input1 + input2))
2714
2715    def test_non_inlined_closure(self):
2716        @torch.compile()
2717        def program(x, y):
2718            one = lambda x, y: x + y
2719
2720            def inner():
2721                # Force no inlining
2722                torch._dynamo.graph_break()
2723                return one(x, y)
2724
2725            res = inner()
2726            one = lambda x, y: x - y
2727            res += inner()
2728            return res
2729
2730        input1 = torch.randn(1)
2731        input2 = torch.randn(1)
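
        # The first inner() computes x + y; the second, after `one` is
        # rebound, computes x - y.  Their sum is 2 * x, i.e. input1 + input1.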
2732
2733        self.assertTrue(same(program(input1, input2), input1 + input1))
2734
2735    @parametrize("int_or_float", ("int", "float"))
2736    def test_np_constant_collections_as_input(self, int_or_float):
2737        info_func = getattr(np, f"{int_or_float[0]}info")
2738        dt_string_arg = f"{int_or_float}16"
2739        np_dt_attr = getattr(np, dt_string_arg)
2740
2741        dt_args = [dt_string_arg, np_dt_attr]
2742        arg_variants_iter = itertools.chain(
2743            dt_args, map(np.dtype, dt_args), map(info_func, dt_args)
2744        )
2745
2746        def func(a, b, info_or_dt):
2747            return a + info_func(info_or_dt).max
2748
2749        opt_fn = torch.compile(func)
2750
2751        a = torch.randn(2)
2752        b = torch.randn(2)
2753        eager_result = func(a, b, dt_args[0])
2754
2755        for arg in arg_variants_iter:
2756            opt_result = opt_fn(a, b, arg)
2757            self.assertTrue(same(opt_result, eager_result))
2758
2759    @parametrize(
2760        "typ, info_func",
2761        [
2762            (int, np.iinfo),
2763            (float, np.finfo),
2764        ],
2765        name_fn=lambda t, _: t.__name__,
2766    )
2767    def test_np_constant_collections_guards(self, typ, info_func):
2768        def func_info(a, info):
2769            return a + info.max
2770
2771        def func_dtype(a, dt):
2772            return a + info_func(dt).max
2773
2774        dt_args = [
2775            np.dtype(typ),
2776            np.ones((1,), dtype=typ).dtype,
2777            np.dtype(np.dtype(typ).name),
2778            np.dtype(typ.__name__),
2779        ]
2780        cnts_1 = torch._dynamo.testing.CompileCounter()
2781        opt_fn_dtype = torch._dynamo.optimize(cnts_1)(func_dtype)
2782        a = torch.zeros(3, dtype=typ)
2783        for arg in dt_args:
2784            r = opt_fn_dtype(a, arg)
2785        # every dtype variant should hit the same guards, so a single frame
2786        self.assertEqual(cnts_1.frame_count, 1)
2787
2788        cnts_2 = torch._dynamo.testing.CompileCounter()
2789        opt_fn_info = torch._dynamo.optimize(cnts_2)(func_info)
2790        info_args = [info_func(dt) for dt in dt_args]
2791        for arg in info_args:
2792            r = opt_fn_info(a, arg)
2793
2794        # every info variant should hit the same guards, so a single frame
2795        self.assertEqual(cnts_2.frame_count, 1)
2796
2797        if typ is float:
2798            dt_extra = np.dtype(np.float16)
2799        else:
2800            dt_extra = np.dtype(np.int16)
2801        info_extra = info_func(dt_extra)
2802
2803        eager_result_dtype = func_dtype(a, dt_extra)
2804        compile_result_dtype = opt_fn_dtype(a, dt_extra)
2805        self.assertEqual(cnts_1.frame_count, 2)
2806        self.assertEqual(eager_result_dtype, compile_result_dtype)
2807
2808        eager_result_info = func_info(a, info_extra)
2809        compile_result_info = opt_fn_info(a, info_extra)
2810        self.assertEqual(cnts_2.frame_count, 2)
2811        self.assertEqual(eager_result_info, compile_result_info)
2812
2813    def test_compare_constant_and_tensor(self):
2814        for op in [
2815            operator.lt,
2816            operator.le,
2817            operator.gt,
2818            operator.ge,
2819            operator.ne,
2820            operator.eq,
2821            operator.is_,
2822            operator.is_not,
2823        ]:
2824            with self.subTest(op=op):
2825
2826                def fn(x):
2827                    return op(-10, x)
2828
2829                opt_fn = torch.compile(fullgraph=True)(fn)
2830
2831                x = torch.randn(10)
2832                self.assertEqual(opt_fn(x), fn(x))
2833
2834    def test_pos(self):
2835        def fn(x, y):
2836            return operator.pos(x) * +y
2837
2838        opt_fn = torch.compile(fullgraph=True, dynamic=True)(fn)
2839
2840        def test(x, y):
2841            self.assertEqual(opt_fn(x, y), fn(x, y))
2842
2843        test(torch.ones(4), 1)
2844        test(1, torch.ones(4))
2845        test(-1, -1)
2846        test(-1.1, 1.1)
2847        test(True, False)
2848        test(torch.ones(4, dtype=torch.float32), 1.1)
2849
2850    def test_index(self):
2851        def fn(x, t):
2852            v = operator.index(x)
2853            torch.mul(t, v)
2854
2855        def test(a, b):
2856            self.assertEqual(opt_fn(a, b), fn(a, b))
2857
2858        for dynamic in [True, False]:
2859            torch._dynamo.reset()
2860            opt_fn = torch._dynamo.optimize(dynamic=dynamic)(fn)
2861            t = torch.ones(1)
2862            test(10, t)
2863            test(-100, t)
2864            test(10, t)
2865            test(False, t)
2866            test(True, t)
2867
2868    def test_truth(self):
2869        def fn(x, y):
2870            return operator.truth(x) and bool(y)
2871
2872        opt_fn = torch.compile(fullgraph=True, dynamic=False)(fn)
2873
2874        def test(x, y):
2875            self.assertEqual(opt_fn(x, y), fn(x, y))
2876
2877        test(1, 100)
2878        test(-1.1, True)
2879        test(-1.1, 1.1)
2880        test(True, False)
2881        test(torch.ones(1), 1)
2882        test(torch.zeros(1), 1)
2883        test(torch.ones(1), torch.ones(1))
2884
2885    def test_unary_fold_op(self):
2886        for op in (operator.abs, abs, operator.neg, operator.pos, operator.truth):
2887            with self.subTest(op=op):
2888
2889                def fn():
2890                    a = range(-10, 10)
2891                    return list(map(op, a))
2892
2893                opt_fn = torch._dynamo.optimize(nopython=True)(fn)
2894                self.assertEqual(opt_fn(), fn())
2895
2896    def test_unary_fold_op_seq(self):
2897        for op in (operator.length_hint,):
2898            with self.subTest(op=op):
2899
2900                def fn():
2901                    a = [tuple(range(-10, i)) for i in range(10)]
2902                    return tuple(map(op, a))
2903
2904                opt_fn = torch._dynamo.optimize(nopython=True)(fn)
2905                self.assertEqual(opt_fn(), fn())
2906
2907    def gen_random_range_args(self):
2908        args_count = random.randint(1, 3)
2909        args = [random.randint(-10, 10) for _ in range(args_count)]
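        # range() rejects a zero step, so nudge it to 1 in the 3-arg form.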
2910        if args_count == 3 and args[2] == 0:
2911            args[2] = 1
2912        return args
2913
2914    def test_range_length(self):
2915        def test(*args, expected=None):
2916            r = range(*args)
2917            range_variable = RangeVariable([ConstantVariable.create(v) for v in args])
2918
2919            self.assertEqual(len(r), range_variable.range_length())
2920
2921            if expected is not None:
2922                self.assertEqual(len(r), expected)
2923
2924        test(1, 1, 1, expected=0)
2925        test(1, 0, expected=0)
2926        test(-10, expected=0)
2927
2928        test(4, expected=4)
2929        test(10, expected=10)
2930
2931        # step > 1
2932        test(1, 10, 2, expected=5)
2933
2934        # negative step
2935        test(10, 1, -1, expected=9)
2936        test(10, 1, -3)
2937
2938        # Fuzz testing
2939        for i in range(100):
2940            args = self.gen_random_range_args()
2941            print("testing:", args)
2942            test(*args)
2943
2944    def test_indexed_range(self):
2945        def test(range, index, expected=None):
2946            range_variable = RangeVariable(
2947                [
2948                    ConstantVariable.create(v)
2949                    for v in [range.start, range.stop, range.step]
2950                ]
2951            )
2952
2953            self.assertEqual(
2954                range[index],
2955                range_variable.apply_index(index).as_python_constant(),
2956            )
2957
2958            if expected is not None:
2959                self.assertEqual(range[index], expected)
2960
2961        test(range(10), 1, expected=1)
2962        test(range(10, 20, 2), 1, expected=12)
2963
2964        # Fuzz testing
2965        for i in range(100):
2966            range_args = self.gen_random_range_args()
2967            r = range(*range_args)
2968
2969            if len(r) == 0:
2970                continue
2971
2972            index = random.randint(0, len(r) - 1)
2973
2974            print("testing:", r, index)
2975            test(r, index)
2976
2977    def test_sliced_range(self):
2978        def test(range, slice, expected=None):
2979            range_variable = RangeVariable(
2980                [
2981                    ConstantVariable.create(v)
2982                    for v in [range.start, range.stop, range.step]
2983                ]
2984            )
2985
2986            self.assertEqual(
2987                range[slice],
2988                range_variable.apply_slice(slice).as_python_constant(),
2989            )
2990
2991            if expected is not None:
2992                self.assertEqual(
2993                    range[slice],
2994                    expected,
2995                )
2996
2997        test(range(10), slice(1, 10, 2), expected=range(1, 10, 2))
2998        test(range(10), slice(None, 10, None), expected=range(0, 10))
2999        test(range(10), slice(-1, 7, None), expected=range(9, 7))
3000        test(range(10), slice(-1, 7, 2), expected=range(9, 7, 2))
3001        test(range(1, 10, 2), slice(3, 7, 2), expected=range(7, 11, 4))
3002        test(range(1, 10, 2), slice(-3, 7, 2), expected=range(5, 11, 4))
3003        test(range(-1, -5, -3), slice(5, None, -3), expected=range(-4, 2, 9))
3004
3005        def rand_slice():
3006            def flip_coin():
3007                # 1 out of 10
3008                return random.randint(1, 10) == 5
3009
3010            def r_item(allow_zero=True):
3011                i = random.randint(-10, 10)
3012                if not allow_zero and i == 0:
3013                    i = 1
3014                if flip_coin():
3015                    i = None
3016                return i
3017
3018            arg_count = random.randint(1, 3)
3019
3020            if arg_count == 1:
3021                return slice(r_item())
3022            elif arg_count == 2:
3023                return slice(r_item(), r_item())
3024            else:
3025                return slice(r_item(), r_item(), r_item(False))
3026
3027        # Fuzz testing
3028        for i in range(100):
3029            range_args = self.gen_random_range_args()
3030            r = range(*range_args)
3031            # generate random slice
3032            s = rand_slice()
3033
3034            print("testing:", r, s)
3035            test(r, s)
3036
3037    def test_range_with_slice_index(self):
3038        def fn(x):
3039            acc = 1
3040            for k in range(2)[1::2]:
3041                acc *= acc * k
3042            return x * acc
3043
3044        opt_fn = torch.compile(fullgraph=True)(fn)
3045        x = torch.ones(1)
3046        self.assertEqual(opt_fn(x), fn(x))
3047
3048    def test_range_with_index(self):
3049        def fn(x):
3050            acc = 1
3051            acc *= acc * range(10, 20, 2)[2]
3052            return x * acc
3053
3054        opt_fn = torch.compile(fullgraph=True)(fn)
3055        x = torch.ones(1)
3056        self.assertEqual(opt_fn(x), fn(x))
3057
3058    def test_rand_inlined(self):
3059        @torch.compile(backend="eager", dynamic=True)
3060        def fn():
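            # The computed values are unused; this only checks that Dynamo
            # inlines the random.* calls under dynamic shapes without error.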
3061            idx_size = [10]
3062            idx_size[random.randint(0, 0)] = random.randint(1, 8)
3063            t = tuple(idx_size)
3064            src_size = [random.randint(1, 5) + s for s in idx_size]
3065            idx = torch.empty(t)
3066
3067        fn()
3068
3069    def test_rand_tensor_partial(self):
3070        from collections import namedtuple
3071        from functools import partial
3072
3073        SdpaShape = namedtuple(
3074            "SdpaShape", ["batch", "num_heads", "seq_len", "head_dim"]
3075        )
3076
3077        @torch.compile(backend="eager")
3078        def func():
3079            make_tensor = partial(
3080                torch.rand, device="cpu", dtype=torch.float16, requires_grad=True
3081            )
3082
3083            bsz, num_heads, seq_len_q, seq_len_kv, head_dim = (16, 16, 128, 128, 16)
3084            make_q_tensor = partial(
3085                make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim)
3086            )
3087            make_kv_tensor = partial(
3088                make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim)
3089            )
3090            t1 = make_q_tensor()
3091            t2 = make_kv_tensor()
3092            t3 = t1 + t2
3093
3094        func()
3095
3096    def test_to(self):
3097        @torch.compile(backend="eager")
3098        def fn():
3099            t = torch.ones(2)
3100            y = t.to("meta")
3101
3102        fn()
3103
3104    def test_ellipsis(self):
3105        @torch.compile(backend="eager", fullgraph=True)
3106        def fn(a, ind, val):
3107            a[ind] = val
3108            return a
3109
3110        arr = np.zeros(4)
3111        self.assertEqual(fn(arr, np.s_[...], np.ones(4)), np.ones(4))
3112
3113        arr = np.array([[1, 1], [2, 2]])
3114        self.assertEqual(
3115            fn(arr, np.s_[0, ...], np.zeros(2)), np.array([[0, 0], [2, 2]])
3116        )
3117
3118        arr = np.array([[1, 1], [2, 2]])
3119        self.assertEqual(
3120            fn(arr, np.s_[1, ...], np.zeros(2)), np.array([[1, 1], [0, 0]])
3121        )
3122
3123        arr = np.array([[1, 1], [2, 2]])
3124        self.assertEqual(
3125            fn(arr, np.s_[..., 0], np.array([3, 3])), np.array([[3, 1], [3, 2]])
3126        )
3127
3128        arr = np.array([[1, 1], [2, 2]])
3129        self.assertEqual(
3130            fn(arr, np.s_[..., 1], np.array([3, 3])), np.array([[1, 3], [2, 3]])
3131        )
3132
3133    def test_map_return(self):
3134        def fn(a, b):
3135            return map(lambda x: x + 1, [a, b])
3136
3137        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
3138        m = opt_fn(torch.randn(3, 3), torch.randn(3, 3))
3139        self.assertIsInstance(m, map)
3140
3141    @make_test
3142    def test_map_max(a, b):
3143        return max(map(lambda x: x.sum(), [a, b]))
3144
3145    # max(map(...)) over constants graph breaks
3146    @unittest.expectedFailure
3147    @make_test
3148    def test_map_max_const(a):
3149        return max(map(lambda x: x, [1, 2, 3])), a + 1
3150
3151    @make_test
3152    def test_map_list(a, b):
3153        return list(map(lambda x: x + 1, [a, b]))
3154
3155    @make_test
3156    def test_map_tuple(a, b):
3157        return tuple(map(lambda x: x + 1, [a, b]))
3158
3159    @make_test
3160    def test_map_iter(a, b):
3161        it = iter(map(lambda x: x + 1, [a, b]))
3162        return next(it)
3163
3164    @make_test
3165    def test_map_zip_dict(a):
3166        d = dict(
3167            zip(
3168                map(lambda x: x + 1, [0, 1, 2]),
3169                [map(lambda x: x - 1, [y]) for y in [3, 4, 5]],
3170            )
3171        )
3172        return list(d[3])[0], a + 1  # noqa: RUF015
3173
3174    @make_test
3175    def test_map_dict_fromkeys(a):
3176        return dict.fromkeys(map(lambda x: x + 1, [0, 1])), a + 1
3177
3178    @make_test
3179    def test_map_set(a):
3180        return set(map(lambda x: x + 1, [0, 1])), a + 1
3181
3182    # test_map_sum defined earlier
3183
3184    @make_test
3185    def test_map_reduce(a, b):
3186        return functools.reduce(lambda x, y: x + y, map(lambda x: x + 1, [a, b]))
3187
3188    @make_test
3189    def test_map_sorted(a):
3190        return sorted(map(lambda x: x + 1, [0, 4, 3, 1, 2])), a + 1
3191
3192    @make_test
3193    def test_map_list_extend(a, b, c):
3194        l = [a]
3195        l.extend(map(lambda x: x + 1, [b, c]))
3196        return l
3197
3198    @make_test
3199    def test_map_list_slice_assign(a, b, c, d, e):
3200        l = [a, b, c]
3201        l[1:2] = map(lambda x: x + 1, [d, e])
3202        return l
3203
3204    @make_test
3205    def test_map_deque_extendleft(a, b, c):
3206        d = collections.deque([a])
3207        d.extendleft(map(lambda x: x + 1, [b, c]))
3208        return d
3209
3210    @make_test
3211    def test_map_str_join(a):
3212        return "".join(map(lambda x: x, ["a", "b", "c"])), a + 1
3213
3214    def test_map_with_graph_break(self):
3215        def f(a):
3216            a += 1
3217
3218            def g(x):
3219                nonlocal a
3220                a += 1
3221                return x + 1
3222
3223            m = map(g, [1, 2, 3, 4, 5])
3224            a += next(m)  # won't graph break
3225            torch._dynamo.graph_break()
3226            a += next(m)  # will graph break
3227            return a
3228
3229        cnts = torch._dynamo.testing.CompileCounter()
3230        opt_f = torch.compile(f, backend=cnts)
3231        self.assertEqual(f(torch.ones(3, 3)), opt_f(torch.ones(3, 3)))
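        # The explicit graph_break and the next(m) call after it split f
        # into three compiled frames.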
3232        self.assertEqual(cnts.frame_count, 3)
3233
3234    def test_map_reconstruct(self):
3235        def fn(a):
3236            return map(lambda x: x[0] + x[1], zip([1, 2, 3], [1, 2, 3])), a + 1
3237
3238        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
3239        m = opt_fn(torch.ones(3, 3))[0]
3240        self.assertIsInstance(m, map)
3241        self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0]))
3242
3243    def test_zip_reconstruct(self):
3244        def fn(a):
3245            return zip([1, 2, 3], map(lambda x: x + 1, [1, 2, 3])), a + 1
3246
3247        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
3248        m = opt_fn(torch.ones(3, 3))[0]
3249        self.assertIsInstance(m, zip)
3250        self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0]))
3251
3252    @make_test
3253    def test_map_partial_unpack(a, b):
3254        y = 1
3255
3256        def f(x):
3257            nonlocal y
3258            y += 1
3259            return x
3260
3261        l = list(zip([a, b], map(f, [1, 2, 3, 4])))
3262        return a + y
3263
3264    @make_test
3265    def test_map_call_function_ex(a, b):
3266        def f(x, y):
3267            return x + y
3268
3269        return f(*map(lambda x: x + 1, [a, b]))
3270
3271    @make_test
3272    def test_map_unpack_twice(a, b):
3273        m = map(lambda x: x + 1, [a, b])
3274        l1 = list(m)
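        # The map iterator is exhausted by the first list(), so l2 is empty.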
3275        l2 = list(m)
3276        return l1, l2
3277
3278    @make_test
3279    def test_enumerate(a, b):
3280        return list(enumerate([a, b], start=1)), a + 1
3281
3282    @make_test
3283    def test_map_enumerate(a, b):
3284        return list(enumerate(map(lambda x: x + 1, [a, b]), start=1)), a + 1
3285
3286    @make_test
3287    def test_map_infinite(a, b):
3288        return list(map(lambda x, y: x + y, [a, b], itertools.count(3)))
3289
3290    @make_test
3291    def test_map_unpack_vars(a, b):
3292        x, y = map(lambda x: x + 1, [a, b])
3293        return x + y
3294
3295    def test_enumerate_custom(self):
3296        class MyClass:
3297            def __iter__(self):
3298                self.a = 1
3299                return self
3300
3301            def __next__(self):
3302                if self.a > 3:
3303                    raise StopIteration
3304                self.a += 1
3305                return self.a
3306
3307        def fn(x):
3308            for i, it in enumerate(MyClass()):
3309                x += i + it
3310            return x
3311
3312        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
3313        self.assertEqual(fn(torch.ones(3, 3)), opt_fn(torch.ones(3, 3)))
3314
3315    def test_enumerate_reconstruct(self):
3316        def fn(a, b):
3317            return enumerate([a, b], start=1)
3318
3319        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
3320        inps = (torch.randn(3, 3), torch.randn(3, 3))
3321        it1 = fn(*inps)
3322        it2 = opt_fn(*inps)
3323        self.assertIsInstance(it2, enumerate)
3324        self.assertEqual(list(it1), list(it2))
3325
3326
3327def udf_mul(x, y):
3328    return x * y
3329
3330
3331def udf_mul2(x, y, z):
3332    return x * y * z
3333
3334
3335def udf_add(x, y):
3336    return x + y
3337
3338
3339class SmallNN(torch.nn.Module):
3340    def forward(self, x, y):
3341        combined = torch.cat((x, y), dim=1)
3342        out = torch.nn.ReLU()(combined)
3343        out = torch.nn.ReLU()(out)
3344        return out
3345
3346
3347def udf_module(mod, x, y):
3348    return mod(x, y)
3349
3350
3351def global_func_with_default_tensor_args(
3352    x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))
3353):
3354    x.add_(1)
3355    kw_x.add_(1)
3356    return x, kw_x
3357
3358
3359class ModuleWithDefaultTensorArgsMethod(torch.nn.Module):
3360    def forward(self, x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))):
3361        x.add_(1)
3362        kw_x.add_(1)
3363        return x, kw_x
3364
3365
3366class WrapperModule(torch.nn.Module):
3367    def __init__(self) -> None:
3368        super().__init__()
3369        self.m = ModuleWithDefaultTensorArgsMethod()
3370
3371    def forward(self):
3372        return self.m()
3373
3374
3375class DefaultsTests(torch._dynamo.test_case.TestCase):
3376    def test_func_default_tensor_args(self):
3377        """
3378        Tests that we indeed reference (and mutate) "the one" default tensor arg
3379        stored on the globally allocated function object, both from the original
3380        and the compiled function.
3381        """
3382
3383        def func():
3384            return global_func_with_default_tensor_args()
3385
3386        cnts = torch._dynamo.testing.CompileCounter()
3387        compiled_func = torch.compile(func, backend=cnts)
3388        for i in range(4):
3389            if i % 2 == 0:
3390                x, kw_x = func()
3391            else:
3392                x, kw_x = compiled_func()
3393            # the inner func mutates the default tensors in place (+= 1) on each call
3394            self.assertTrue(same(x, torch.ones_like(x) + i))
3395            self.assertTrue(same(kw_x, torch.ones_like(kw_x) + i))
3396        # Calling compiled_func twice does not recompile
3397        self.assertEqual(cnts.frame_count, 1)
3398        self.assertEqual(cnts.op_count, 2)
3399
3400        # But with a change to the guarded default tensor, we do recompile
3401        with patch.object(
3402            global_func_with_default_tensor_args,
3403            "__defaults__",
3404            (torch.ones((3, 4, 5)),),
3405        ):
3406            x, kw_x = compiled_func()
3407        self.assertEqual(cnts.frame_count, 2)
3408        self.assertEqual(cnts.op_count, 4)
3409
3410        with patch.object(
3411            global_func_with_default_tensor_args,
3412            "__kwdefaults__",
3413            {"kw_x": torch.ones((3, 4, 5))},
3414        ):
3415            x, kw_x = compiled_func()
3416        self.assertEqual(cnts.frame_count, 3)
3417        self.assertEqual(cnts.op_count, 6)
3418
3419    def test_meth_default_tensor_args(self):
3420        """
3421        Tests that we indeed reference (and mutate) "the one" default tensor arg
3422        stored on the globally allocated function object, both from the original
3423        and the compiled module.
3424        """
3425        mod = WrapperModule()
3426        cnts = torch._dynamo.testing.CompileCounter()
3427        compiled_mod = torch.compile(mod, backend=cnts)
3428        for i in range(4):
3429            if i % 2 == 0:
3430                x, kw_x = mod()
3431            else:
3432                x, kw_x = compiled_mod()
3433            # the inner func mutates the default tensors in place (+= 1) on each call
3434            self.assertTrue(same(x, torch.ones_like(x) + i))
3435            self.assertTrue(same(kw_x, torch.ones_like(kw_x) + i))
3436        # Calling compiled_mod twice does not recompile
3437        self.assertEqual(cnts.frame_count, 1)
3438        self.assertEqual(cnts.op_count, 2)
3439
3440        # But with a change to the guarded default tensor, we do recompile
3441        with patch.object(
3442            ModuleWithDefaultTensorArgsMethod.forward,
3443            "__defaults__",
3444            (torch.ones((3, 4, 5)),),
3445        ):
3446            x, kw_x = compiled_mod()
3447        self.assertEqual(cnts.frame_count, 2)
3448        self.assertEqual(cnts.op_count, 4)
3449
3450        with patch.object(
3451            ModuleWithDefaultTensorArgsMethod.forward,
3452            "__kwdefaults__",
3453            {"kw_x": torch.ones((3, 4, 5))},
3454        ):
3455            x, kw_x = compiled_mod()
3456        self.assertEqual(cnts.frame_count, 3)
3457        self.assertEqual(cnts.op_count, 6)
3458
3459    def test_func_default_torch_args(self):
3460        """
3461        Tests other types of torch types as function default (size, dtype, device)
3462        """
3463
3464        def func_with_default_torch_args(
3465            dt=torch.float16, ds=torch.Size((1, 2, 3)), dd=torch.device("cpu")
3466        ):
3467            return torch.ones(ds, dtype=dt, device=dd)
3468
3469        def func():
3470            return func_with_default_torch_args()
3471
3472        cnts = torch._dynamo.testing.CompileCounter()
3473        compiled_func = torch.compile(func, backend=cnts)
3474        out = func()
3475        compiled_out = compiled_func()
3476        self.assertEqual(out.dtype, compiled_out.dtype)
3477        self.assertEqual(out.device, compiled_out.device)
3478        self.assertEqual(out.size(), compiled_out.size())
3479        self.assertEqual(cnts.frame_count, 1)
3480        self.assertEqual(cnts.op_count, 1)
3481
3482    def test_dataclass_factory(self):
3483        @dataclass
3484        class Output:
3485            scalar: int = 2
3486            named_tensors: Dict[str, torch.Tensor] = field(default_factory=dict)
3487            lists: List[torch.Tensor] = field(default_factory=list)
3488
3489            def scale(self):
3490                return self.scalar * 2
3491
3492        def fn(x):
3493            # Check default dict assignment
3494            a = Output(1)
3495            # Check that dataclass methods can be inlined
3496            scaled_value = a.scale()
3497
3498            # Check that normal assignment works
3499            b = Output(5, named_tensors={"x": x})
3500
3501            # Check default int assignment
3502            c = Output()
3503
3504            # Check that the default members are properly initialized
3505            if isinstance(a.named_tensors, dict):
3506                x = torch.sin(x)
3507
3508            # Change dataclass
3509            c.scalar = 6
3510            c.named_tensors["x"] = x
3511
3512            # Return the dataclass as well to check reconstruction
3513            return c, torch.cos(x) * scaled_value + b.named_tensors["x"] + c.scalar
3514
3515        cnts = torch._dynamo.testing.CompileCounter()
3516        compiled_fn = torch.compile(fn, backend=cnts, fullgraph=True)
3517        x = torch.randn(4)
3518        eager_dataclass, out = fn(x)
3519        compiled_dataclass, compiled_out = compiled_fn(x)
3520        self.assertEqual(eager_dataclass.scalar, compiled_dataclass.scalar)
3521        self.assertEqual(
3522            eager_dataclass.named_tensors["x"], compiled_dataclass.named_tensors["x"]
3523        )
3524        self.assertTrue(same(out, compiled_out))
3525        self.assertEqual(cnts.frame_count, 1)
3526        self.assertEqual(cnts.op_count, 5)
3527
3528    def test_dataclass_nested(self):
3529        @dataclass
3530        class Base:
3531            outer_a: int
3532            outer_b: int
3533
3534        @dataclass
3535        class Derived(Base):
3536            inner_a: Any = field(default_factory=list)
3537
3538        def fn(x):
3539            l = Derived(1, 2)
3540            return l.outer_a * x
3541
3542        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
3543        x = torch.randn(4)
3544        res = fn(x)
3545        ref = opt_fn(x)
3546        self.assertEqual(ref, res)
3547
3548    def test_listlike_of_tensors_contains_constant(self):
3549        for listlike in [set, list]:
3550
3551            def fn(x):
3552                x.add_(1)
3553                s = listlike([x])
3554                res = 1 in s
3555                return res
3556
3557            opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
3558            x = torch.randn(1)
3559            ref = opt_fn(x)
3560            res = fn(x)
3561            self.assertEqual(ref, res)
3562
3563    def test_cast_tensor_single_elem(self):
3564        with torch._dynamo.config.patch({"capture_scalar_outputs": True}):
3565            for t, val in [
3566                (float, 1.0),
3567                (float, 1),
3568                (float, True),
3569                (int, 1),
3570                (int, False),
3571                # (int, 1.0), # fails due to a >= 0 comparison in sym_int
3572            ]:  # bool and complex omitted: no casting for sym_bool, no sym_complex
3573
3574                def fn(x):
3575                    x = x + 1
3576                    return t(x)
3577
3578                opt_fn = torch.compile(
3579                    fn, backend="eager", fullgraph=True, dynamic=False
3580                )
3581                x = torch.tensor([val])
3582                res = fn(x)
3583                ref = opt_fn(x)
3584                self.assertEqual(ref, res)
3585
3586                # Cannot handle non single-elem
3587                with self.assertRaises(ValueError):
3588                    fn(torch.tensor([val] * 2))
3589                with self.assertRaises(torch._dynamo.exc.TorchRuntimeError):
3590                    opt_fn(torch.tensor([val] * 2))
3591
3592    def test_set_construction(self):
        def fn(x):
            y = x.add_(1)
            s = set({x})
            s.add(y)
            return len(s)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        res = fn(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)

    def test_frozenset_construction(self):
        def fn(x):
            s = frozenset({x})
            t = frozenset(s)
            return len(t)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        res = fn(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)

    def test_frozenset_reconstruction(self):
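        # Empty frozensets compare and hash equal, so the frozenset rebuilt
        # around the graph break must still index into d.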
        d = {}
        f = frozenset()
        d[f] = torch.randn(4)

        def fn(x):
            k = frozenset()
            torch._dynamo.graph_break()
            return d[k] * x

        opt_fn = torch.compile(fn, backend="eager")
        x = torch.randn(4)
        res = fn(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)

    def test_frozenset_illegal_call_method(self):
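        # frozenset has no mutating methods; each call below should raise
        # inside the compiled region instead of being traced.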
        def fn_add():
            s = frozenset((1, 2, 3))
            s.add({2})
            return len(s)

        def fn_pop():
            s = frozenset((1, 2, 3))
            s.pop()
            return len(s)

        def fn_update():
            s = frozenset((1, 2, 3))
            s.update({4, 5, 6})
            return len(s)

        def fn_remove():
            s = frozenset((1, 2, 3))
            s.remove(2)
            return len(s)

        def fn_discard():
            s = frozenset((1, 2, 3))
            s.discard(2)
            return len(s)

        def fn_clear():
            s = frozenset((1, 2, 3))
            s.clear()
            return len(s)

        for fn in [fn_add, fn_pop, fn_update, fn_remove, fn_discard, fn_clear]:
            torch._dynamo.reset()
            opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
            with self.assertRaises(torch._dynamo.exc.InternalTorchDynamoError):
                opt_fn()

    def test_is_tensor_tensor(self):
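        # The data-dependent `is` branch must take the same arm as eager for
        # both distinct and aliased inputs.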
        def fn(x, y):
            if x is y:
                return x * 2
            else:
                return x + y

        fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn)

        x = torch.zeros(2)
        y = torch.ones(2)

        self.assertEqual(fn(x, y), fn_opt(x, y))
        self.assertEqual(fn(x, x), fn_opt(x, x))

    def test_is_not_tensor_tensor(self):
        def fn(x, y):
            if x is not y:
                return x * 2
            else:
                return x + y

        fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn)

        x = torch.zeros(2)
        y = torch.ones(2)

        self.assertEqual(fn(x, y), fn_opt(x, y))
        self.assertEqual(fn(x, x), fn_opt(x, x))

    def test_is_mutated_tensor_tensor(self):
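        # In-place ops return their input, so `x is y` should hold under compile.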
        def fn(x):
            y = x.add_(1)
            return x is y

        fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn)

        z = torch.ones(4)

        self.assertEqual(fn(z), fn_opt(z))

    def test_is_mutated_tensor_tensor_across_graph_break(self):
        def fn(x):
            y = x.add_(1)
            cond = x is y
            x.add_(1)
            # Graph breaking materializes the real tensor values, so the
            # aliasing invariant is restored afterwards.
            torch._dynamo.graph_break()
            x.add_(1)
            return x is y, cond

        fn_opt = torch.compile(backend="eager", dynamic=True)(fn)

        z = torch.ones(4)

        self.assertEqual(fn(z), fn_opt(z))

    def test_is_mutated_tensor_tensor_reversed(self):
        def fn(x):
            y = x.add_(1)
            return y is x

        fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn)

        z = torch.ones(4, 1)

        self.assertEqual(fn(z), fn_opt(z))

    def test_is_init_in_compile_mutated_tensor_tensor(self):
        def fn(x):
            z = x.clone()
            y = z.add_(1)
            return y is z

        fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn)

        z = torch.ones(4, 1)

        self.assertEqual(fn(z), fn_opt(z))

    def test_is_init_in_compile_vmapped_mutated_tensor_tensor(self):
        def fn(z):
            x = z.clone()
            y = torch.vmap(torch.Tensor.acos_)(x)
            _ = y is z
            return y is x

        fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn)

        z = torch.ones(4, 1)

        self.assertEqual(fn(z), fn_opt(z))

    def test_is_vmapped_mutated_tensor_tensor(self):
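        # Identity of the tensor returned from a vmapped in-place op must match
        # eager behavior.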
        def fn(x):
            y = torch.vmap(torch.Tensor.acos_)(x)
            return y is x

        fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn)

        z = torch.ones(4, 1)

        self.assertEqual(fn(z), fn_opt(z))

    def test_is_init_in_compile_vmapped_mutated_tensor_tensor_multi_arg(self):
        def fn(y, z):
            a = y.clone()
            b = z.clone()

            def g(a, b):
                return a.acos_(), b.acos_()

            c, d = torch.vmap(g)(a, b)
            return a is c is b is d

        fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn)

        y = torch.ones(4, 2)
        z = torch.ones(4, 10)

        self.assertEqual(fn(y, z), fn_opt(y, z))
        self.assertEqual(fn(y, y), fn_opt(y, y))

    def test_in_set_would_fail_broadcast(self):
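        # Shapes (5,) and (5, 10) do not broadcast, so an equality-based
        # membership check would error; set containment must stay identity-based.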
        param = torch.zeros(5)
        param2 = torch.zeros(5, 10)

        tensor_list = set()
        tensor_list.add(param2)
        assert param not in tensor_list

        def fn(param, param2):
            param.add_(1)
            tensor_list = set([param2])
            return param in tensor_list

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        self.assertEqual(opt_fn(param, param2), fn(param, param2))
        self.assertEqual(cnts.frame_count, 1)
        # Test aliased
        self.assertEqual(opt_fn(param, param), fn(param, param))
        self.assertEqual(cnts.frame_count, 2)  # Recompiles

    def test_in_set_inplace(self):
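        # Both the method and the torch-function spellings of add_ return the
        # mutated input, so the membership checks reduce to tensor identity.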
        param = torch.zeros(5)
        param2 = torch.zeros(5, 10)

        tensor_list = set()
        tensor_list.add(param2)
        assert param not in tensor_list

        def fn(param, param2):
            y = param.add_(1)  # Tensor method
            z = torch.Tensor.add_(y, 1)  # torch function
            tensor_list = set([param2])
            return y in tensor_list and z in tensor_list

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        self.assertEqual(opt_fn(param, param2), fn(param, param2))
        self.assertEqual(cnts.frame_count, 1)
        # Test aliased
        self.assertEqual(opt_fn(param, param), fn(param, param))
        self.assertEqual(cnts.frame_count, 2)  # Recompiles

    def test_reconstructed_name(self):
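        # The inner function g is reconstructed when passed to the disabled
        # function; its __name__ must survive reconstruction.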
        lst = []

        @torch._dynamo.disable
        def disallowed(g):
            lst.append(g.__name__)

        def f():
            def g():
                return ()

            disallowed(g)

        opt_f = torch._dynamo.optimize(backend="eager")(f)
        opt_f()
        f()
        self.assertEqual(len(lst), 2)
        self.assertEqual(lst[0], lst[1])

    @unittest.skipIf(
        sys.version_info < (3, 10),
        "zip's strict kwarg is not available before Python 3.10",
    )
    def test_zip_strict(self):
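        # zip(..., strict=True) raises ValueError on length mismatch; the
        # compiled variants below must surface that failure appropriately.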
        def fn(x, ys, zs):
            x = x.clone()
            for y, z in zip(ys, zs, strict=True):
                x += y * z
            return x

        opt_fn = torch._dynamo.optimize(backend="eager")(fn)
        nopython_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)

        x = torch.ones(3)
        ys = [1.0, 2.0, 3.0]
        zs = [2.0, 5.0, 8.0]

        self.assertEqual(opt_fn(x, ys, zs), fn(x, ys, zs))

        # With nopython, the length mismatch should raise a UserError
        with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"):
            nopython_fn(x, ys[:1], zs)

        with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"):
            nopython_fn(x, ys, zs[:1])

        # With graph breaks allowed, it should fall back and raise ValueError
        with self.assertRaisesRegex(ValueError, "zip()"):
            opt_fn(x, ys[:1], zs)

        with self.assertRaisesRegex(ValueError, "zip()"):
            opt_fn(x, ys, zs[:1])

    def test_fn_with_attr(self):
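        # fn.pred is read at trace time, so each value of the attribute
        # compiles a separate frame.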
        def fn(x):
            if fn.pred:
                return torch.relu(x * 2)
            else:
                return torch.abs(x + 3)

        t = torch.ones(3)
        counter = torch._dynamo.testing.CompileCounter()
        fn.pred = True
        opt_fn_0 = torch.compile(fullgraph=True, backend=counter)(fn)
        self.assertEqual(opt_fn_0(t), fn(t))
        self.assertEqual(counter.frame_count, 1)
        fn.pred = False
        opt_fn_1 = torch.compile(fullgraph=True, backend=counter)(fn)
        self.assertEqual(opt_fn_1(t), fn(t))
        self.assertEqual(counter.frame_count, 2)

    def test_str_handler_for_user_defined_object(self):
        """
        Confirms that the `str` handler behaves the same in eager and dynamo.
        Compares a user-defined object with a custom `__str__` method against
        one without.
        """

        class CustomStr:
            def __str__(self):
                return "ok"

        def foo_custom_str(x):
            a = CustomStr()
            return x, str(a)

        eager_custom_str = foo_custom_str(torch.ones(4))
        dynamo_custom_str = torch.compile(foo_custom_str, fullgraph=True)(
            torch.ones(4)
        )

        self.assertEqual(eager_custom_str[1], dynamo_custom_str[1])
        self.assertEqual(eager_custom_str[1], "ok")

        class DefaultStr:
            pass

        def foo_default_str(x):
            a = DefaultStr()
            return x, str(a)

        eager_default_str = foo_default_str(torch.ones(4))
        dynamo_default_str = torch.compile(foo_default_str, fullgraph=True)(
            torch.ones(4)
        )

        # Check that the tensor output from eager and dynamo modes is the same
        self.assertEqual(eager_default_str[0], dynamo_default_str[0])

        # Check that the class name (without memory address) is the same in both modes
        eager_class_name = eager_default_str[1].split(" object at")[0]
        dynamo_class_name = dynamo_default_str[1].split(" object at")[0]
        self.assertEqual(eager_class_name, dynamo_class_name)

    def test_pybind_object(self):
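        # Attribute reads on a pybind11-defined object (GuardDebugInfo) should
        # be traceable and guarded like any other user object.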
        def fn(x, pybind_obj):
            if pybind_obj.result:
                return torch.cos(x)
            return torch.sin(x)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        pybind_obj = torch._C._dynamo.guards.GuardDebugInfo(True, ["a==1"], 0)
        x = torch.randn(4)
        self.assertEqual(opt_fn(x, pybind_obj), fn(x, pybind_obj))

        pybind_obj = torch._C._dynamo.guards.GuardDebugInfo(False, ["a==1"], 1)
        x = torch.randn(4)
        self.assertEqual(opt_fn(x, pybind_obj), fn(x, pybind_obj))


instantiate_parametrized_tests(FunctionTests)

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()